Skip to content

Commit

Permalink
Fully generalize snest and sunNest
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 1, 2024
1 parent 62ce22f commit 4007eda
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 32 deletions.
12 changes: 6 additions & 6 deletions src/HordeAd/Core/Ast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -484,14 +484,14 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType
AstProjectS :: (GoodScalar r, KnownShS sh)
=> AstTensor ms s TKUntyped -> Int -> AstTensor ms s (TKS sh r)
AstNestS :: forall r sh1 sh2 ms s.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> AstTensor ms s (TKS (sh1 ++ sh2) r)
-> AstTensor ms s (TKS2 sh1 (TKS sh2 r))
(TensorKind2 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> AstTensor ms s (TKS2 (sh1 ++ sh2) r)
-> AstTensor ms s (TKS2 sh1 (TKS2 sh2 r))
AstUnNestS :: forall r sh1 sh2 ms s.
( GoodScalar r, KnownShS sh1, KnownShS sh2
( TensorKind2 r, KnownShS sh1, KnownShS sh2
, KnownShS (sh1 ++ sh2) )
=> AstTensor ms s (TKS2 sh1 (TKS sh2 r))
-> AstTensor ms s (TKS (sh1 ++ sh2) r)
=> AstTensor ms s (TKS2 sh1 (TKS2 sh2 r))
-> AstTensor ms s (TKS2 (sh1 ++ sh2) r)
AstSFromR :: (KnownShS sh, KnownNat (Rank sh), GoodScalar r)
=> AstTensor ms s (TKR (Rank sh) r) -> AstTensor ms s (TKS sh r)
AstSFromX :: ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh'
Expand Down
28 changes: 16 additions & 12 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -609,11 +609,12 @@ astIndexKnobsS knobs (Ast.AstIndexS v ix) ZIS = astIndexKnobsS knobs v ix
astIndexKnobsS _ v0 ZIS = v0
astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) | Dict <- sixKnown rest1 =
let astIndexRec, astIndex
:: forall shm' shn' s'.
( KnownShS shm', KnownShS shn', KnownShS (shm' ++ shn')
:: forall shm' shn' s' r'.
( TensorKind2 r', KnownShS shm', KnownShS shn', KnownShS (shm' ++ shn')
, AstSpan s' )
=> AstTensor AstMethodLet s' (TKS2 (shm' ++ shn') r) -> AstIxS AstMethodLet shm'
-> AstTensor AstMethodLet s' (TKS2 shn' r)
=> AstTensor AstMethodLet s' (TKS2 (shm' ++ shn') r')
-> AstIxS AstMethodLet shm'
-> AstTensor AstMethodLet s' (TKS2 shn' r')
astIndexRec v2 ZIS = v2
astIndexRec v2 ix2 = if knobStepOnly knobs
then Ast.AstIndexS v2 ix2
Expand Down Expand Up @@ -824,8 +825,11 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |
Ast.AstProjectS{} -> Ast.AstIndexS v0 ix
Ast.AstLetHVectorIn vars l v ->
astLetHVectorIn vars l (astIndexRec v ix)
Ast.AstNestS _ -> Ast.AstIndexS v0 ix
-- TODO: why no work? maybe AstNestS needs to be even more general? Ast.AstNestS v -> astNestS (astIndexRec v ix)
Ast.AstNestS @_ @_ @sh2 v ->
withKnownShS (Nested.Internal.Shape.shsAppend (knownShS @shn) (knownShS @sh2)) $
gcastWith (unsafeCoerce Refl
:: (shm ++ shn) ++ sh2 :~: shm ++ (shn ++ sh2)) $
astNestS (astIndexRec v ix)
-- TODO: hard: Ast.AstUnNestS v -> astUnNestS (astIndexRec v ix)
Ast.AstUnNestS _ -> Ast.AstIndexS v0 ix
Ast.AstSFromR t ->
Expand Down Expand Up @@ -2159,9 +2163,9 @@ astProjectS l p = case l of

astNestS
:: forall r sh1 sh2 ms s.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2), AstSpan s)
=> AstTensor ms s (TKS (sh1 ++ sh2) r)
-> AstTensor ms s (TKS2 sh1 (TKS sh2 r))
(TensorKind2 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2), AstSpan s)
=> AstTensor ms s (TKS2 (sh1 ++ sh2) r)
-> AstTensor ms s (TKS2 sh1 (TKS2 sh2 r))
astNestS t = case t of
Ast.AstLet var u2 d2 -> -- TODO: good idea?
astLet var u2 (astNestS d2)
Expand All @@ -2172,9 +2176,9 @@ astNestS t = case t of

astUnNestS
:: forall r sh1 sh2 ms s.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2), AstSpan s)
=> AstTensor ms s (TKS2 sh1 (TKS sh2 r))
-> AstTensor ms s (TKS (sh1 ++ sh2) r)
(TensorKind2 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2), AstSpan s)
=> AstTensor ms s (TKS2 sh1 (TKS2 sh2 r))
-> AstTensor ms s (TKS2 (sh1 ++ sh2) r)
astUnNestS t = case t of
Ast.AstLet var u2 d2 -> -- TODO: good idea?
astLet var u2 (astUnNestS d2)
Expand Down
6 changes: 4 additions & 2 deletions src/HordeAd/Core/AstTools.hs
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,10 @@ ftkAst t = case t of
AstCastS{} -> FTKS knownShS FTKScalar
AstFromIntegralS{} -> FTKS knownShS FTKScalar
AstProjectS{} -> FTKS knownShS FTKScalar
AstNestS{} -> FTKS knownShS (FTKS knownShS FTKScalar)
AstUnNestS{} -> FTKS knownShS FTKScalar
AstNestS v -> case ftkAst v of
FTKS _ x -> FTKS knownShS (FTKS knownShS x)
AstUnNestS v -> case ftkAst v of
FTKS _ (FTKS _ x) -> FTKS knownShS x
AstSFromR{} -> FTKS knownShS FTKScalar
AstSFromX{} -> FTKS knownShS FTKScalar
AstXFromS{} -> error "TODO"
Expand Down
18 changes: 10 additions & 8 deletions src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -574,12 +574,12 @@ data Delta :: Target -> TensorKindType -> Type where
-- TODO: this is a haddock for Gather1; fix.
CastS :: (GoodScalar r1, RealFrac r1, GoodScalar r2, RealFrac r2, KnownShS sh)
=> Delta target (TKS sh r1) -> Delta target (TKS sh r2)
NestS :: (GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> Delta target (TKS (sh1 ++ sh2) r)
-> Delta target (TKS2 sh1 (TKS sh2 r))
UnNestS :: (GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> Delta target (TKS2 sh1 (TKS sh2 r))
-> Delta target (TKS (sh1 ++ sh2) r)
NestS :: (TensorKind2 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> Delta target (TKS2 (sh1 ++ sh2) r)
-> Delta target (TKS2 sh1 (TKS2 sh2 r))
UnNestS :: (TensorKind2 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> Delta target (TKS2 sh1 (TKS2 sh2 r))
-> Delta target (TKS2 (sh1 ++ sh2) r)
SFromR :: forall sh r target. (KnownShS sh, KnownNat (Rank sh), GoodScalar r)
=> Delta target (TKR (Rank sh) r)
-> Delta target (TKS sh r)
Expand Down Expand Up @@ -715,8 +715,10 @@ shapeDeltaFull = \case
ReshapeS{} -> FTKS knownShS FTKScalar
GatherS{} -> FTKS knownShS FTKScalar
CastS{} -> FTKS knownShS FTKScalar
NestS{} -> FTKS knownShS (FTKS knownShS FTKScalar)
UnNestS{} -> FTKS knownShS FTKScalar
NestS d -> case shapeDeltaFull d of
FTKS _ x -> FTKS knownShS (FTKS knownShS x)
UnNestS d -> case shapeDeltaFull d of
FTKS _ (FTKS _ x) -> FTKS knownShS x
SFromR{} -> FTKS knownShS FTKScalar
SFromX{} -> FTKS knownShS FTKScalar
XFromS{} -> error "TODO"
Expand Down
9 changes: 5 additions & 4 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -775,11 +775,12 @@ class ( Num (IntOf target)
sconcrete :: (GoodScalar r, KnownShS sh) => Nested.Shaped sh r -> target (TKS sh r)
sconcrete a = tconcrete (FTKS (Nested.sshape a) FTKScalar) (RepN a)
snest :: forall sh1 sh2 r.
(GoodScalar r, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> ShS sh1 -> target (TKS (sh1 ++ sh2) r) -> target (TKS2 sh1 (TKS sh2 r))
(TensorKind2 r, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> ShS sh1 -> target (TKS2 (sh1 ++ sh2) r)
-> target (TKS2 sh1 (TKS2 sh2 r))
sunNest :: forall sh1 sh2 r.
(GoodScalar r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> target (TKS2 sh1 (TKS sh2 r)) -> target (TKS (sh1 ++ sh2) r)
(TensorKind2 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> target (TKS2 sh1 (TKS2 sh2 r)) -> target (TKS2 (sh1 ++ sh2) r)
sfromR :: (GoodScalar r, KnownShS sh, KnownNat (Rank sh))
=> target (TKR (Rank sh) r) -> target (TKS sh r)
sfromX :: ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh'
Expand Down

0 comments on commit 4007eda

Please sign in to comment.