Skip to content

Commit

Permalink
Clean up and complete xfromR and friends
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 11, 2024
1 parent 269c804 commit d082e2b
Show file tree
Hide file tree
Showing 12 changed files with 178 additions and 69 deletions.
9 changes: 5 additions & 4 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
-> AstTensor AstMethodLet s2 z
AstRFromS :: (KnownShS sh, TensorKind1 r)
=> AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKR2 (Rank sh) r)
AstRFromX :: (KnownShX sh, TensorKind1 r)
=> AstTensor ms s (TKX2 sh r) -> AstTensor ms s (TKR2 (Rank sh) r)

-- Here starts the shaped part.
AstFromScalar :: GoodScalar r
Expand Down Expand Up @@ -495,11 +497,8 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstSFromR :: (KnownShS sh, KnownNat (Rank sh), TensorKind1 r)
=> AstTensor ms s (TKR2 (Rank sh) r) -> AstTensor ms s (TKS2 sh r)
AstSFromX :: ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh'
, KnownShX (Nested.MapJust sh), TensorKind1 r )
=> AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r)
AstXFromS :: ( KnownShS sh, KnownShX sh', sh' ~ Nested.MapJust sh
, TensorKind1 r )
=> AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKX2 sh' r)
=> AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r)

-- Here starts the mixed part.
AstN1X :: (GoodScalar r, KnownShX sh)
Expand Down Expand Up @@ -590,6 +589,8 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
=> AstTensor ms s TKUntyped -> Int -> AstTensor ms s (TKX sh r)
AstXFromR :: (KnownShX sh, KnownNat (Rank sh), TensorKind1 r)
=> AstTensor ms s (TKR2 (Rank sh) r) -> AstTensor ms s (TKX2 sh r)
AstXFromS :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r)
=> AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKX2 sh' r)

-- Here starts the misc part.
AstMkHVector :: HVector (AstTensor ms s) -> AstTensor ms s TKUntyped
Expand Down
2 changes: 2 additions & 0 deletions src/HordeAd/Core/AstInline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ inlineAst memo v0 = case v0 of
(memo2, v2) = inlineAst memo1 v
in (memo2, Ast.AstLetHVectorIn vars l2 v2)
Ast.AstRFromS v -> second Ast.AstRFromS $ inlineAst memo v
Ast.AstRFromX v -> second Ast.AstRFromX $ inlineAst memo v

Ast.AstMinIndexS a -> second Ast.AstMinIndexS $ inlineAst memo a
Ast.AstMaxIndexS a -> second Ast.AstMaxIndexS $ inlineAst memo a
Expand Down Expand Up @@ -582,6 +583,7 @@ unshareAst memo = \case
let (memo1, l2) = unshareAst memo l
in (memo1, Ast.AstProjectR l2 p)
Ast.AstRFromS v -> second Ast.AstRFromS $ unshareAst memo v
Ast.AstRFromX v -> second Ast.AstRFromX $ unshareAst memo v

Ast.AstMinIndexS a -> second Ast.AstMinIndexS $ unshareAst memo a
Ast.AstMaxIndexS a -> second Ast.AstMaxIndexS $ unshareAst memo a
Expand Down
5 changes: 3 additions & 2 deletions src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,7 @@ interpretAst !env = \case
(\lw -> interpretAst (env2 (dunHVector lw)) v)
_ -> error "TODO"
AstRFromS v -> rfromS $ interpretAst env v
AstRFromX v -> rfromX $ interpretAst env v

AstMinIndexS v ->
sminIndex $ sfromPrimal $ interpretAstPrimalSRuntimeSpecialized env v
Expand Down Expand Up @@ -829,7 +830,6 @@ interpretAst !env = \case
AstUnNestS v -> sunNest $ interpretAst env v
AstSFromR v -> sfromR $ interpretAst env v
AstSFromX v -> sfromX $ interpretAst env v
AstXFromS v -> xfromS $ interpretAst env v

AstMinIndexX _v -> error "TODO"
AstMaxIndexX _v -> error "TODO"
Expand Down Expand Up @@ -881,7 +881,8 @@ interpretAst !env = \case
AstCastX _v -> error "TODO"
AstFromIntegralX _v -> error "TODO"
AstProjectX _l _p -> error "TODO"
AstXFromR _v -> error "TODO"
AstXFromR v -> xfromR $ interpretAst env v
AstXFromS v -> xfromS $ interpretAst env v

AstMkHVector l -> dmkHVector $ interpretAstDynamic env <$> l
AstApply t ll ->
Expand Down
67 changes: 54 additions & 13 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ module HordeAd.Core.AstSimplify
, astCast, astCastR, astCastS
, astFromIntegral, astFromIntegralR, astFromIntegralS
, astProject1, astProject2, astProjectR, astProjectS, astNestS, astUnNestS
, astRFromS, astSFromR, astSFromX, astXFromS
, astRFromS, astRFromX, astSFromR, astSFromX, astXFromR, astXFromS
, astPrimalPart, astDualPart
, astLetHVectorIn, astHApply, astLetFun
-- * The simplifying bottom-up pass
Expand Down Expand Up @@ -338,6 +338,7 @@ astNonIndexStep t = case t of
Ast.AstProjectR l p -> astProjectR l p
Ast.AstLetHVectorIn vars u v -> astLetHVectorIn vars u v
Ast.AstRFromS v -> astRFromS v
Ast.AstRFromX v -> astRFromX v

Ast.AstMinIndexS{} -> t
Ast.AstMaxIndexS{} -> t
Expand Down Expand Up @@ -372,6 +373,7 @@ astNonIndexStep t = case t of
Ast.AstUnNestS v -> astUnNestS v
Ast.AstSFromR v -> astSFromR v
Ast.AstSFromX v -> astSFromX v
Ast.AstXFromR v -> astXFromR v
Ast.AstXFromS v -> astXFromS v
_ -> t -- TODO

Expand Down Expand Up @@ -608,6 +610,7 @@ astIndexKnobsR knobs v0 ix@(i1 :.: (rest1 :: AstIxR AstMethodLet m1)) =
gcastWith (unsafeCoerce Refl :: Rank p_drop :~: n) $
astRFromS $ astIndexKnobsS @p_take @p_drop knobs
t (ShapedList.listToIndex $ indexToList ix)
Ast.AstRFromX{} -> error "TODO"

Ast.AstApply{} -> Ast.AstIndex v0 ix

Expand Down Expand Up @@ -845,7 +848,9 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |
Ast.AstProjectS{} -> Ast.AstIndexS v0 ix
Ast.AstLetHVectorIn vars l v ->
astLetHVectorIn vars l (astIndexRec v ix)
{- TODO: Ast.AstNestS @_ @_ @sh2 v ->
Ast.AstNestS{} -> Ast.AstIndexS v0 ix
{- TODO:
Ast.AstNestS @_ @_ @sh2 v ->
withKnownShS (Nested.Internal.Shape.shsAppend (knownShS @shn) (knownShS @sh2)) $
gcastWith (unsafeCoerce Refl
:: (shm ++ shn) ++ sh2 :~: shm ++ (shn ++ sh2)) $
Expand Down Expand Up @@ -1213,6 +1218,7 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) =
astRFromS $ astGatherStepS @_ @p' @sh v
( ShapedList.listToSized $ sizedToList vars4
, ShapedList.listToSized $ indexToList ix4 ) -}
Ast.AstRFromX{} -> error "TODO"

Ast.AstApply{} -> Ast.AstGather sh4 v4 (vars4, ix4)

Expand Down Expand Up @@ -2218,22 +2224,36 @@ astUnNestS t = case t of
_ -> Ast.AstUnNestS t

astRFromS :: forall sh s r. (TensorKind1 r, KnownShS sh)
=> AstTensor AstMethodLet s (TKS2 sh r) -> AstTensor AstMethodLet s (TKR2 (Rank sh) r)
astRFromS (AstConcrete ftk t) = case ftk of
=> AstTensor AstMethodLet s (TKS2 sh r)
-> AstTensor AstMethodLet s (TKR2 (Rank sh) r)
astRFromS (AstConcrete ftk t)
| Dict <- lemKnownNatRankS (knownShS @sh) = case ftk of
FTKS _ x ->
withListSh (Proxy @sh) $ \(_ :: IShR p) ->
gcastWith (unsafeCoerce Refl :: Rank sh :~: p) $
let u = Nested.stoRanked (unRepN t)
in AstConcrete (FTKR (Nested.rshape u) x) (RepN u)
astRFromS (Ast.AstFromPrimal v) =
withListSh (Proxy @sh) $ \(_ :: IShR p) ->
gcastWith (unsafeCoerce Refl :: Rank sh :~: p) $
astRFromS (Ast.AstFromPrimal v)
| Dict <- lemKnownNatRankS (knownShS @sh) =
Ast.AstFromPrimal $ astRFromS v
astRFromS (Ast.AstSFromR v) = v -- no information lost, so no checks
astRFromS v = Ast.AstRFromS v

astRFromX :: forall sh s r. (TensorKind1 r, KnownShX sh)
=> AstTensor AstMethodLet s (TKX2 sh r)
-> AstTensor AstMethodLet s (TKR2 (Rank sh) r)
astRFromX (AstConcrete ftk t)
| Dict <- lemKnownNatRankX (knownShX @sh) = case ftk of
FTKX _ x ->
let u = Nested.mtoRanked (unRepN t)
in AstConcrete (FTKR (Nested.rshape u) x) (RepN u)
astRFromX (Ast.AstFromPrimal v)
| Dict <- lemKnownNatRankX (knownShX @sh) =
Ast.AstFromPrimal $ astRFromX v
astRFromX (Ast.AstXFromR v) = v -- no information lost, so no checks
astRFromX v = Ast.AstRFromX v

astSFromR :: forall sh s r. (TensorKind1 r, KnownShS sh, KnownNat (Rank sh))
=> AstTensor AstMethodLet s (TKR2 (Rank sh) r) -> AstTensor AstMethodLet s (TKS2 sh r)
=> AstTensor AstMethodLet s (TKR2 (Rank sh) r)
-> AstTensor AstMethodLet s (TKS2 sh r)
astSFromR (AstConcrete ftk t) = case ftk of
FTKR _ x ->
AstConcrete (FTKS knownShS x) $ RepN
Expand All @@ -2246,7 +2266,7 @@ astSFromR (Ast.AstRFromS @sh1 v) =
astSFromR v = Ast.AstSFromR v

astSFromX :: forall sh sh' s r.
(KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', KnownShX (Nested.MapJust sh), TensorKind1 r)
(KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r)
=> AstTensor AstMethodLet s (TKX2 sh' r)
-> AstTensor AstMethodLet s (TKS2 sh r)
astSFromX (AstConcrete ftk t) = case ftk of
Expand All @@ -2260,13 +2280,24 @@ astSFromX (Ast.AstXFromS @sh1 v) =
_ -> error "astSFromX: different shapes in SFromX(XFromS)"
astSFromX v = Ast.AstSFromX v

astXFromR :: forall sh s r.
(KnownShX sh, KnownNat (Rank sh), TensorKind1 r)
=> AstTensor AstMethodLet s (TKR2 (Rank sh) r)
-> AstTensor AstMethodLet s (TKX2 sh r)
astXFromR (AstConcrete ftk t) = case ftk of
FTKR _ x ->
let u = Nested.rcastToMixed (knownShX @sh) (unRepN t)
in AstConcrete (FTKX (Nested.mshape u) x) (RepN u)
astXFromR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astXFromR v
astXFromR v = Ast.AstXFromR v

astXFromS :: forall sh sh' s r.
(KnownShS sh, KnownShX sh', sh' ~ Nested.MapJust sh, TensorKind1 r)
(KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r)
=> AstTensor AstMethodLet s (TKS2 sh r)
-> AstTensor AstMethodLet s (TKX2 sh' r)
astXFromS (AstConcrete ftk t) = case ftk of
FTKS _ x ->
let u = Nested.stoMixed (unRepN t)
let u = Nested.scastToMixed (knownShX @sh') (unRepN t)
in AstConcrete (FTKX (Nested.mshape u) x) (RepN u)
astXFromS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astXFromS v
-- impossible, shapes may differ: astXFromS (Ast.AstSFromX v) = v
Expand Down Expand Up @@ -2316,6 +2347,7 @@ astPrimalPart t = case t of
Ast.AstProjectR l p -> astProjectR (astPrimalPart l) p
Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars l (astPrimalPart v)
Ast.AstRFromS v -> astRFromS $ astPrimalPart v
Ast.AstRFromX v -> astRFromX $ astPrimalPart v

AstN1S opCode u -> AstN1S opCode (astPrimalPart u)
AstN2S opCode u v -> AstN2S opCode (astPrimalPart u) (astPrimalPart v)
Expand All @@ -2341,6 +2373,7 @@ astPrimalPart t = case t of
Ast.AstUnNestS v -> astUnNestS $ astPrimalPart v
Ast.AstSFromR v -> astSFromR $ astPrimalPart v
Ast.AstSFromX v -> astSFromX $ astPrimalPart v
Ast.AstXFromR v -> astXFromR $ astPrimalPart v
Ast.AstXFromS v -> astXFromS $ astPrimalPart v

Ast.AstMkHVector{} -> Ast.AstPrimalPart t -- TODO
Expand Down Expand Up @@ -2401,6 +2434,7 @@ astDualPart t = case t of
Ast.AstProjectR l p -> astProjectR (astDualPart l) p
Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars l (astDualPart v)
Ast.AstRFromS v -> astRFromS $ astDualPart v
Ast.AstRFromX v -> astRFromX $ astDualPart v

AstN1S{} -> Ast.AstDualPart t
AstN2S{} -> Ast.AstDualPart t
Expand All @@ -2424,6 +2458,7 @@ astDualPart t = case t of
Ast.AstUnNestS v -> astUnNestS $ astDualPart v
Ast.AstSFromR v -> astSFromR $ astDualPart v
Ast.AstSFromX v -> astSFromX $ astDualPart v
Ast.AstXFromR v -> astXFromR $ astDualPart v
Ast.AstXFromS v -> astXFromS $ astDualPart v

Ast.AstMkHVector{} -> Ast.AstDualPart t -- TODO
Expand Down Expand Up @@ -2686,6 +2721,7 @@ simplifyAst t = case t of
Ast.AstLetHVectorIn vars l v ->
astLetHVectorIn vars (simplifyAst l) (simplifyAst v)
Ast.AstRFromS v -> astRFromS $ simplifyAst v
Ast.AstRFromX v -> astRFromX $ simplifyAst v

Ast.AstMinIndexS a -> Ast.AstMinIndexS (simplifyAst a)
Ast.AstMaxIndexS a -> Ast.AstMaxIndexS (simplifyAst a)
Expand Down Expand Up @@ -2716,6 +2752,7 @@ simplifyAst t = case t of
Ast.AstUnNestS v -> astUnNestS $ simplifyAst v
Ast.AstSFromR v -> astSFromR $ simplifyAst v
Ast.AstSFromX v -> astSFromX $ simplifyAst v
Ast.AstXFromR v -> astXFromR $ simplifyAst v
Ast.AstXFromS v -> astXFromS $ simplifyAst v

Ast.AstMkHVector l -> Ast.AstMkHVector $ V.map simplifyAstDynamic l
Expand Down Expand Up @@ -2916,6 +2953,7 @@ expandAst t = case t of
Ast.AstLetHVectorIn vars l v ->
astLetHVectorIn vars (expandAst l) (expandAst v)
Ast.AstRFromS v -> astRFromS $ expandAst v
Ast.AstRFromX v -> astRFromX $ expandAst v

Ast.AstMinIndexS a -> Ast.AstMinIndexS (expandAst a)
Ast.AstMaxIndexS a -> Ast.AstMaxIndexS (expandAst a)
Expand Down Expand Up @@ -2949,6 +2987,7 @@ expandAst t = case t of
Ast.AstUnNestS v -> astUnNestS $ expandAst v
Ast.AstSFromR v -> astSFromR $ expandAst v
Ast.AstSFromX v -> astSFromX $ expandAst v
Ast.AstXFromR v -> astXFromR $ expandAst v
Ast.AstXFromS v -> astXFromS $ expandAst v

Ast.AstMkHVector l -> Ast.AstMkHVector $ V.map expandAstDynamic l
Expand Down Expand Up @@ -3446,6 +3485,7 @@ substitute1Ast i var v1 = case v1 of
(ml, mv) ->
Just $ astLetHVectorIn vars (fromMaybe l ml) (fromMaybe v mv)
Ast.AstRFromS v -> astRFromS <$> substitute1Ast i var v
Ast.AstRFromX v -> astRFromX <$> substitute1Ast i var v

Ast.AstMinIndexS a -> Ast.AstMinIndexS <$> substitute1Ast i var a
Ast.AstMaxIndexS a -> Ast.AstMaxIndexS <$> substitute1Ast i var a
Expand Down Expand Up @@ -3514,6 +3554,7 @@ substitute1Ast i var v1 = case v1 of
Ast.AstUnNestS v -> astUnNestS <$> substitute1Ast i var v
Ast.AstSFromR v -> astSFromR <$> substitute1Ast i var v
Ast.AstSFromX v -> astSFromX <$> substitute1Ast i var v
Ast.AstXFromR v -> astXFromR <$> substitute1Ast i var v
Ast.AstXFromS v -> astXFromS <$> substitute1Ast i var v

Ast.AstMkHVector args ->
Expand Down
23 changes: 16 additions & 7 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ import Data.List (foldl')
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality (gcastWith, (:~:) (Refl))
import Data.Vector.Generic qualified as V
import GHC.Exts (IsList (..))
import GHC.TypeLits (sameNat, type (+))
import Unsafe.Coerce (unsafeCoerce)

import Data.Array.Mixed.Permutation qualified as Permutation
import Data.Array.Mixed.Shape (shxSize)
import Data.Array.Mixed.Shape (KnownShX (..), shxSize)
import Data.Array.Nested
(IShR, KnownShS (..), ShR (..), pattern (:$:), pattern ZSR)
import Data.Array.Nested.Internal.Shape (shCvtSX, shrSize, shsSize)
import Data.Array.Nested.Internal.Shape (shrSize, shsSize)
import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape


import HordeAd.Core.Ast
import HordeAd.Core.TensorKind
import HordeAd.Core.Types
Expand Down Expand Up @@ -118,7 +118,10 @@ ftkAst t = case t of
AstLetHVectorIn _ _ v -> ftkAst v
AstRFromS @sh v
| Dict <- lemKnownNatRankS (knownShS @sh) -> case ftkAst v of
FTKS _ x -> FTKR (listToShape $ shapeT @sh) x
FTKS _ x -> FTKR (fromList $ shapeT @sh) x
AstRFromX @sh v
| Dict <- lemKnownNatRankX (knownShX @sh) -> case ftkAst v of
FTKX shx x -> FTKR (fromList $ toList shx) x

AstMinIndexS{} -> FTKS knownShS FTKScalar
AstMaxIndexS{} -> FTKS knownShS FTKScalar
Expand Down Expand Up @@ -165,8 +168,11 @@ ftkAst t = case t of
FTKR _ x -> FTKS knownShS x
AstSFromX v -> case ftkAst v of
FTKX _ x -> FTKS knownShS x
AstXFromS @sh v -> case ftkAst v of
FTKS _ x -> FTKX (shCvtSX (knownShS @sh)) x
AstXFromR @sh v
| Dict <- lemKnownNatRankX (knownShX @sh) -> case ftkAst v of
FTKR shr x -> FTKX (fromList $ toList shr) x
AstXFromS v -> case ftkAst v of
FTKS sh x -> FTKX (fromList $ toList sh) x

AstMkHVector v ->
FTKUntyped
Expand Down Expand Up @@ -268,6 +274,7 @@ varInAst var = \case
AstProjectR l _p -> varInAst var l
AstLetHVectorIn _vars l v -> varInAst var l || varInAst var v
AstRFromS v -> varInAst var v
AstRFromX v -> varInAst var v

AstMinIndexS a -> varInAst var a
AstMaxIndexS a -> varInAst var a
Expand Down Expand Up @@ -296,7 +303,6 @@ varInAst var = \case
AstUnNestS v -> varInAst var v
AstSFromR v -> varInAst var v
AstSFromX v -> varInAst var v
AstXFromS v -> varInAst var v

AstMinIndexX a -> varInAst var a
AstMaxIndexX a -> varInAst var a
Expand All @@ -322,6 +328,7 @@ varInAst var = \case
AstFromIntegralX a -> varInAst var a
AstProjectX l _p -> varInAst var l
AstXFromR v -> varInAst var v
AstXFromS v -> varInAst var v

AstMkHVector l -> any (varInAstDynamic var) l
AstApply t ll -> varInAstHFun var t || varInAst var ll
Expand Down Expand Up @@ -390,6 +397,7 @@ astIsSmall relaxed = \case
relaxed && astIsSmall relaxed v -- often cheap and often fuses
AstProjectR t _ -> astIsSmall relaxed t
AstRFromS v -> astIsSmall relaxed v
AstRFromX v -> astIsSmall relaxed v

AstIotaS -> True
AstFromVectorS v | V.length v == 1 -> astIsSmall relaxed $ v V.! 0
Expand All @@ -400,6 +408,7 @@ astIsSmall relaxed = \case
AstProjectS t _ -> astIsSmall relaxed t
AstSFromR v -> astIsSmall relaxed v
AstSFromX v -> astIsSmall relaxed v
AstXFromR v -> astIsSmall relaxed v
AstXFromS v -> astIsSmall relaxed v

AstMkHVector v | V.length v == 1 -> case v V.! 0 of
Expand Down
Loading

0 comments on commit d082e2b

Please sign in to comment.