diff --git a/src/HordeAd/Core/Ast.hs b/src/HordeAd/Core/Ast.hs index 460be549..ed0fcdc2 100644 --- a/src/HordeAd/Core/Ast.hs +++ b/src/HordeAd/Core/Ast.hs @@ -390,6 +390,8 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType -> AstTensor AstMethodLet s2 z AstRFromS :: (KnownShS sh, TensorKind1 r) => AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKR2 (Rank sh) r) + AstRFromX :: (KnownShX sh, TensorKind1 r) + => AstTensor ms s (TKX2 sh r) -> AstTensor ms s (TKR2 (Rank sh) r) -- Here starts the shaped part. AstFromScalar :: GoodScalar r @@ -495,11 +497,8 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType AstSFromR :: (KnownShS sh, KnownNat (Rank sh), TensorKind1 r) => AstTensor ms s (TKR2 (Rank sh) r) -> AstTensor ms s (TKS2 sh r) AstSFromX :: ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh' - , KnownShX (Nested.MapJust sh), TensorKind1 r ) - => AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r) - AstXFromS :: ( KnownShS sh, KnownShX sh', sh' ~ Nested.MapJust sh , TensorKind1 r ) - => AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKX2 sh' r) + => AstTensor ms s (TKX2 sh' r) -> AstTensor ms s (TKS2 sh r) -- Here starts the mixed part. AstN1X :: (GoodScalar r, KnownShX sh) @@ -590,6 +589,8 @@ data AstTensor :: AstMethodOfSharing -> AstSpanType -> TensorKindType => AstTensor ms s TKUntyped -> Int -> AstTensor ms s (TKX sh r) AstXFromR :: (KnownShX sh, KnownNat (Rank sh), TensorKind1 r) => AstTensor ms s (TKR2 (Rank sh) r) -> AstTensor ms s (TKX2 sh r) + AstXFromS :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r) + => AstTensor ms s (TKS2 sh r) -> AstTensor ms s (TKX2 sh' r) -- Here starts the misc part. AstMkHVector :: HVector (AstTensor ms s) -> AstTensor ms s TKUntyped diff --git a/src/HordeAd/Core/AstInline.hs b/src/HordeAd/Core/AstInline.hs index 35ee6484..65aeb4da 100644 --- a/src/HordeAd/Core/AstInline.hs +++ b/src/HordeAd/Core/AstInline.hs @@ -233,6 +233,7 @@ inlineAst memo v0 = case v0 of (memo2, v2) = inlineAst memo1 v in (memo2, Ast.AstLetHVectorIn vars l2 v2) Ast.AstRFromS v -> second Ast.AstRFromS $ inlineAst memo v + Ast.AstRFromX v -> second Ast.AstRFromX $ inlineAst memo v Ast.AstMinIndexS a -> second Ast.AstMinIndexS $ inlineAst memo a Ast.AstMaxIndexS a -> second Ast.AstMaxIndexS $ inlineAst memo a @@ -582,6 +583,7 @@ unshareAst memo = \case let (memo1, l2) = unshareAst memo l in (memo1, Ast.AstProjectR l2 p) Ast.AstRFromS v -> second Ast.AstRFromS $ unshareAst memo v + Ast.AstRFromX v -> second Ast.AstRFromX $ unshareAst memo v Ast.AstMinIndexS a -> second Ast.AstMinIndexS $ unshareAst memo a Ast.AstMaxIndexS a -> second Ast.AstMaxIndexS $ unshareAst memo a diff --git a/src/HordeAd/Core/AstInterpret.hs b/src/HordeAd/Core/AstInterpret.hs index 16efadcd..315e9c46 100644 --- a/src/HordeAd/Core/AstInterpret.hs +++ b/src/HordeAd/Core/AstInterpret.hs @@ -616,6 +616,7 @@ interpretAst !env = \case (\lw -> interpretAst (env2 (dunHVector lw)) v) _ -> error "TODO" AstRFromS v -> rfromS $ interpretAst env v + AstRFromX v -> rfromX $ interpretAst env v AstMinIndexS v -> sminIndex $ sfromPrimal $ interpretAstPrimalSRuntimeSpecialized env v @@ -829,7 +830,6 @@ interpretAst !env = \case AstUnNestS v -> sunNest $ interpretAst env v AstSFromR v -> sfromR $ interpretAst env v AstSFromX v -> sfromX $ interpretAst env v - AstXFromS v -> xfromS $ interpretAst env v AstMinIndexX _v -> error "TODO" AstMaxIndexX _v -> error "TODO" @@ -881,7 +881,8 @@ interpretAst !env = \case AstCastX _v -> error "TODO" AstFromIntegralX _v -> error "TODO" AstProjectX _l _p -> error "TODO" - AstXFromR _v -> error "TODO" + AstXFromR v -> xfromR $ interpretAst env v + AstXFromS v -> xfromS $ interpretAst env v AstMkHVector l -> dmkHVector $ interpretAstDynamic env <$> l AstApply t ll -> diff --git a/src/HordeAd/Core/AstSimplify.hs b/src/HordeAd/Core/AstSimplify.hs index 7a9fc0bf..e8956103 100644 --- a/src/HordeAd/Core/AstSimplify.hs +++ b/src/HordeAd/Core/AstSimplify.hs @@ -30,7 +30,7 @@ module HordeAd.Core.AstSimplify , astCast, astCastR, astCastS , astFromIntegral, astFromIntegralR, astFromIntegralS , astProject1, astProject2, astProjectR, astProjectS, astNestS, astUnNestS - , astRFromS, astSFromR, astSFromX, astXFromS + , astRFromS, astRFromX, astSFromR, astSFromX, astXFromR, astXFromS , astPrimalPart, astDualPart , astLetHVectorIn, astHApply, astLetFun -- * The simplifying bottom-up pass @@ -338,6 +338,7 @@ astNonIndexStep t = case t of Ast.AstProjectR l p -> astProjectR l p Ast.AstLetHVectorIn vars u v -> astLetHVectorIn vars u v Ast.AstRFromS v -> astRFromS v + Ast.AstRFromX v -> astRFromX v Ast.AstMinIndexS{} -> t Ast.AstMaxIndexS{} -> t @@ -372,6 +373,7 @@ astNonIndexStep t = case t of Ast.AstUnNestS v -> astUnNestS v Ast.AstSFromR v -> astSFromR v Ast.AstSFromX v -> astSFromX v + Ast.AstXFromR v -> astXFromR v Ast.AstXFromS v -> astXFromS v _ -> t -- TODO @@ -608,6 +610,7 @@ astIndexKnobsR knobs v0 ix@(i1 :.: (rest1 :: AstIxR AstMethodLet m1)) = gcastWith (unsafeCoerce Refl :: Rank p_drop :~: n) $ astRFromS $ astIndexKnobsS @p_take @p_drop knobs t (ShapedList.listToIndex $ indexToList ix) + Ast.AstRFromX{} -> error "TODO" Ast.AstApply{} -> Ast.AstIndex v0 ix @@ -845,7 +848,9 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) | Ast.AstProjectS{} -> Ast.AstIndexS v0 ix Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars l (astIndexRec v ix) -{- TODO: Ast.AstNestS @_ @_ @sh2 v -> + Ast.AstNestS{} -> Ast.AstIndexS v0 ix +{- TODO: + Ast.AstNestS @_ @_ @sh2 v -> withKnownShS (Nested.Internal.Shape.shsAppend (knownShS @shn) (knownShS @sh2)) $ gcastWith (unsafeCoerce Refl :: (shm ++ shn) ++ sh2 :~: shm ++ (shn ++ sh2)) $ @@ -1213,6 +1218,7 @@ astGatherKnobsR knobs sh0 v0 (vars0, ix0) = astRFromS $ astGatherStepS @_ @p' @sh v ( ShapedList.listToSized $ sizedToList vars4 , ShapedList.listToSized $ indexToList ix4 ) -} + Ast.AstRFromX{} -> error "TODO" Ast.AstApply{} -> Ast.AstGather sh4 v4 (vars4, ix4) @@ -2218,22 +2224,36 @@ astUnNestS t = case t of _ -> Ast.AstUnNestS t astRFromS :: forall sh s r. (TensorKind1 r, KnownShS sh) - => AstTensor AstMethodLet s (TKS2 sh r) -> AstTensor AstMethodLet s (TKR2 (Rank sh) r) -astRFromS (AstConcrete ftk t) = case ftk of + => AstTensor AstMethodLet s (TKS2 sh r) + -> AstTensor AstMethodLet s (TKR2 (Rank sh) r) +astRFromS (AstConcrete ftk t) + | Dict <- lemKnownNatRankS (knownShS @sh) = case ftk of FTKS _ x -> - withListSh (Proxy @sh) $ \(_ :: IShR p) -> - gcastWith (unsafeCoerce Refl :: Rank sh :~: p) $ let u = Nested.stoRanked (unRepN t) in AstConcrete (FTKR (Nested.rshape u) x) (RepN u) -astRFromS (Ast.AstFromPrimal v) = - withListSh (Proxy @sh) $ \(_ :: IShR p) -> - gcastWith (unsafeCoerce Refl :: Rank sh :~: p) $ +astRFromS (Ast.AstFromPrimal v) + | Dict <- lemKnownNatRankS (knownShS @sh) = Ast.AstFromPrimal $ astRFromS v astRFromS (Ast.AstSFromR v) = v -- no information lost, so no checks astRFromS v = Ast.AstRFromS v +astRFromX :: forall sh s r. (TensorKind1 r, KnownShX sh) + => AstTensor AstMethodLet s (TKX2 sh r) + -> AstTensor AstMethodLet s (TKR2 (Rank sh) r) +astRFromX (AstConcrete ftk t) + | Dict <- lemKnownNatRankX (knownShX @sh) = case ftk of + FTKX _ x -> + let u = Nested.mtoRanked (unRepN t) + in AstConcrete (FTKR (Nested.rshape u) x) (RepN u) +astRFromX (Ast.AstFromPrimal v) + | Dict <- lemKnownNatRankX (knownShX @sh) = + Ast.AstFromPrimal $ astRFromX v +astRFromX (Ast.AstXFromR v) = v -- no information lost, so no checks +astRFromX v = Ast.AstRFromX v + astSFromR :: forall sh s r. (TensorKind1 r, KnownShS sh, KnownNat (Rank sh)) - => AstTensor AstMethodLet s (TKR2 (Rank sh) r) -> AstTensor AstMethodLet s (TKS2 sh r) + => AstTensor AstMethodLet s (TKR2 (Rank sh) r) + -> AstTensor AstMethodLet s (TKS2 sh r) astSFromR (AstConcrete ftk t) = case ftk of FTKR _ x -> AstConcrete (FTKS knownShS x) $ RepN @@ -2246,7 +2266,7 @@ astSFromR (Ast.AstRFromS @sh1 v) = astSFromR v = Ast.AstSFromR v astSFromX :: forall sh sh' s r. - (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', KnownShX (Nested.MapJust sh), TensorKind1 r) + (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r) => AstTensor AstMethodLet s (TKX2 sh' r) -> AstTensor AstMethodLet s (TKS2 sh r) astSFromX (AstConcrete ftk t) = case ftk of @@ -2260,13 +2280,24 @@ astSFromX (Ast.AstXFromS @sh1 v) = _ -> error "astSFromX: different shapes in SFromX(XFromS)" astSFromX v = Ast.AstSFromX v +astXFromR :: forall sh s r. + (KnownShX sh, KnownNat (Rank sh), TensorKind1 r) + => AstTensor AstMethodLet s (TKR2 (Rank sh) r) + -> AstTensor AstMethodLet s (TKX2 sh r) +astXFromR (AstConcrete ftk t) = case ftk of + FTKR _ x -> + let u = Nested.rcastToMixed (knownShX @sh) (unRepN t) + in AstConcrete (FTKX (Nested.mshape u) x) (RepN u) +astXFromR (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astXFromR v +astXFromR v = Ast.AstXFromR v + astXFromS :: forall sh sh' s r. - (KnownShS sh, KnownShX sh', sh' ~ Nested.MapJust sh, TensorKind1 r) + (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r) => AstTensor AstMethodLet s (TKS2 sh r) -> AstTensor AstMethodLet s (TKX2 sh' r) astXFromS (AstConcrete ftk t) = case ftk of FTKS _ x -> - let u = Nested.stoMixed (unRepN t) + let u = Nested.scastToMixed (knownShX @sh') (unRepN t) in AstConcrete (FTKX (Nested.mshape u) x) (RepN u) astXFromS (Ast.AstFromPrimal v) = Ast.AstFromPrimal $ astXFromS v -- impossible, shapes may differ: astXFromS (Ast.AstSFromX v) = v @@ -2316,6 +2347,7 @@ astPrimalPart t = case t of Ast.AstProjectR l p -> astProjectR (astPrimalPart l) p Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars l (astPrimalPart v) Ast.AstRFromS v -> astRFromS $ astPrimalPart v + Ast.AstRFromX v -> astRFromX $ astPrimalPart v AstN1S opCode u -> AstN1S opCode (astPrimalPart u) AstN2S opCode u v -> AstN2S opCode (astPrimalPart u) (astPrimalPart v) @@ -2341,6 +2373,7 @@ astPrimalPart t = case t of Ast.AstUnNestS v -> astUnNestS $ astPrimalPart v Ast.AstSFromR v -> astSFromR $ astPrimalPart v Ast.AstSFromX v -> astSFromX $ astPrimalPart v + Ast.AstXFromR v -> astXFromR $ astPrimalPart v Ast.AstXFromS v -> astXFromS $ astPrimalPart v Ast.AstMkHVector{} -> Ast.AstPrimalPart t -- TODO @@ -2401,6 +2434,7 @@ astDualPart t = case t of Ast.AstProjectR l p -> astProjectR (astDualPart l) p Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars l (astDualPart v) Ast.AstRFromS v -> astRFromS $ astDualPart v + Ast.AstRFromX v -> astRFromX $ astDualPart v AstN1S{} -> Ast.AstDualPart t AstN2S{} -> Ast.AstDualPart t @@ -2424,6 +2458,7 @@ astDualPart t = case t of Ast.AstUnNestS v -> astUnNestS $ astDualPart v Ast.AstSFromR v -> astSFromR $ astDualPart v Ast.AstSFromX v -> astSFromX $ astDualPart v + Ast.AstXFromR v -> astXFromR $ astDualPart v Ast.AstXFromS v -> astXFromS $ astDualPart v Ast.AstMkHVector{} -> Ast.AstDualPart t -- TODO @@ -2686,6 +2721,7 @@ simplifyAst t = case t of Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars (simplifyAst l) (simplifyAst v) Ast.AstRFromS v -> astRFromS $ simplifyAst v + Ast.AstRFromX v -> astRFromX $ simplifyAst v Ast.AstMinIndexS a -> Ast.AstMinIndexS (simplifyAst a) Ast.AstMaxIndexS a -> Ast.AstMaxIndexS (simplifyAst a) @@ -2716,6 +2752,7 @@ simplifyAst t = case t of Ast.AstUnNestS v -> astUnNestS $ simplifyAst v Ast.AstSFromR v -> astSFromR $ simplifyAst v Ast.AstSFromX v -> astSFromX $ simplifyAst v + Ast.AstXFromR v -> astXFromR $ simplifyAst v Ast.AstXFromS v -> astXFromS $ simplifyAst v Ast.AstMkHVector l -> Ast.AstMkHVector $ V.map simplifyAstDynamic l @@ -2916,6 +2953,7 @@ expandAst t = case t of Ast.AstLetHVectorIn vars l v -> astLetHVectorIn vars (expandAst l) (expandAst v) Ast.AstRFromS v -> astRFromS $ expandAst v + Ast.AstRFromX v -> astRFromX $ expandAst v Ast.AstMinIndexS a -> Ast.AstMinIndexS (expandAst a) Ast.AstMaxIndexS a -> Ast.AstMaxIndexS (expandAst a) @@ -2949,6 +2987,7 @@ expandAst t = case t of Ast.AstUnNestS v -> astUnNestS $ expandAst v Ast.AstSFromR v -> astSFromR $ expandAst v Ast.AstSFromX v -> astSFromX $ expandAst v + Ast.AstXFromR v -> astXFromR $ expandAst v Ast.AstXFromS v -> astXFromS $ expandAst v Ast.AstMkHVector l -> Ast.AstMkHVector $ V.map expandAstDynamic l @@ -3446,6 +3485,7 @@ substitute1Ast i var v1 = case v1 of (ml, mv) -> Just $ astLetHVectorIn vars (fromMaybe l ml) (fromMaybe v mv) Ast.AstRFromS v -> astRFromS <$> substitute1Ast i var v + Ast.AstRFromX v -> astRFromX <$> substitute1Ast i var v Ast.AstMinIndexS a -> Ast.AstMinIndexS <$> substitute1Ast i var a Ast.AstMaxIndexS a -> Ast.AstMaxIndexS <$> substitute1Ast i var a @@ -3514,6 +3554,7 @@ substitute1Ast i var v1 = case v1 of Ast.AstUnNestS v -> astUnNestS <$> substitute1Ast i var v Ast.AstSFromR v -> astSFromR <$> substitute1Ast i var v Ast.AstSFromX v -> astSFromX <$> substitute1Ast i var v + Ast.AstXFromR v -> astXFromR <$> substitute1Ast i var v Ast.AstXFromS v -> astXFromS <$> substitute1Ast i var v Ast.AstMkHVector args -> diff --git a/src/HordeAd/Core/AstTools.hs b/src/HordeAd/Core/AstTools.hs index a792e319..f801d4fc 100644 --- a/src/HordeAd/Core/AstTools.hs +++ b/src/HordeAd/Core/AstTools.hs @@ -24,17 +24,17 @@ import Data.List (foldl') import Data.Proxy (Proxy (Proxy)) import Data.Type.Equality (gcastWith, (:~:) (Refl)) import Data.Vector.Generic qualified as V +import GHC.Exts (IsList (..)) import GHC.TypeLits (sameNat, type (+)) import Unsafe.Coerce (unsafeCoerce) import Data.Array.Mixed.Permutation qualified as Permutation -import Data.Array.Mixed.Shape (shxSize) +import Data.Array.Mixed.Shape (KnownShX (..), shxSize) import Data.Array.Nested (IShR, KnownShS (..), ShR (..), pattern (:$:), pattern ZSR) -import Data.Array.Nested.Internal.Shape (shCvtSX, shrSize, shsSize) +import Data.Array.Nested.Internal.Shape (shrSize, shsSize) import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape - import HordeAd.Core.Ast import HordeAd.Core.TensorKind import HordeAd.Core.Types @@ -118,7 +118,10 @@ ftkAst t = case t of AstLetHVectorIn _ _ v -> ftkAst v AstRFromS @sh v | Dict <- lemKnownNatRankS (knownShS @sh) -> case ftkAst v of - FTKS _ x -> FTKR (listToShape $ shapeT @sh) x + FTKS _ x -> FTKR (fromList $ shapeT @sh) x + AstRFromX @sh v + | Dict <- lemKnownNatRankX (knownShX @sh) -> case ftkAst v of + FTKX shx x -> FTKR (fromList $ toList shx) x AstMinIndexS{} -> FTKS knownShS FTKScalar AstMaxIndexS{} -> FTKS knownShS FTKScalar @@ -165,8 +168,11 @@ ftkAst t = case t of FTKR _ x -> FTKS knownShS x AstSFromX v -> case ftkAst v of FTKX _ x -> FTKS knownShS x - AstXFromS @sh v -> case ftkAst v of - FTKS _ x -> FTKX (shCvtSX (knownShS @sh)) x + AstXFromR @sh v + | Dict <- lemKnownNatRankX (knownShX @sh) -> case ftkAst v of + FTKR shr x -> FTKX (fromList $ toList shr) x + AstXFromS v -> case ftkAst v of + FTKS sh x -> FTKX (fromList $ toList sh) x AstMkHVector v -> FTKUntyped @@ -268,6 +274,7 @@ varInAst var = \case AstProjectR l _p -> varInAst var l AstLetHVectorIn _vars l v -> varInAst var l || varInAst var v AstRFromS v -> varInAst var v + AstRFromX v -> varInAst var v AstMinIndexS a -> varInAst var a AstMaxIndexS a -> varInAst var a @@ -296,7 +303,6 @@ varInAst var = \case AstUnNestS v -> varInAst var v AstSFromR v -> varInAst var v AstSFromX v -> varInAst var v - AstXFromS v -> varInAst var v AstMinIndexX a -> varInAst var a AstMaxIndexX a -> varInAst var a @@ -322,6 +328,7 @@ varInAst var = \case AstFromIntegralX a -> varInAst var a AstProjectX l _p -> varInAst var l AstXFromR v -> varInAst var v + AstXFromS v -> varInAst var v AstMkHVector l -> any (varInAstDynamic var) l AstApply t ll -> varInAstHFun var t || varInAst var ll @@ -390,6 +397,7 @@ astIsSmall relaxed = \case relaxed && astIsSmall relaxed v -- often cheap and often fuses AstProjectR t _ -> astIsSmall relaxed t AstRFromS v -> astIsSmall relaxed v + AstRFromX v -> astIsSmall relaxed v AstIotaS -> True AstFromVectorS v | V.length v == 1 -> astIsSmall relaxed $ v V.! 0 @@ -400,6 +408,7 @@ astIsSmall relaxed = \case AstProjectS t _ -> astIsSmall relaxed t AstSFromR v -> astIsSmall relaxed v AstSFromX v -> astIsSmall relaxed v + AstXFromR v -> astIsSmall relaxed v AstXFromS v -> astIsSmall relaxed v AstMkHVector v | V.length v == 1 -> case v V.! 0 of diff --git a/src/HordeAd/Core/Delta.hs b/src/HordeAd/Core/Delta.hs index a4e31c83..6ee6b54a 100644 --- a/src/HordeAd/Core/Delta.hs +++ b/src/HordeAd/Core/Delta.hs @@ -63,6 +63,7 @@ import Data.Strict.Vector qualified as Data.Vector import Data.Traversable (mapAccumL) import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl)) import Data.Vector.Generic qualified as V +import GHC.Exts (IsList (..)) import GHC.TypeLits (KnownNat, sameNat, type (+), type (<=)) import Text.Show (showListWith) import Text.Show.Functions () @@ -70,12 +71,12 @@ import Type.Reflection (typeRep) import Unsafe.Coerce (unsafeCoerce) import Data.Array.Mixed.Permutation qualified as Permutation -import Data.Array.Mixed.Shape - (KnownShX (..), pattern (:.%), pattern ZIX, shxEqual) +import Data.Array.Mixed.Shape (pattern (:.%), pattern ZIX) import Data.Array.Nested ( IShR , IxS (..) , KnownShS (..) + , KnownShX (..) , Rank , ShR (..) , ShS (..) @@ -84,7 +85,7 @@ import Data.Array.Nested , type (++) ) import Data.Array.Nested qualified as Nested -import Data.Array.Nested.Internal.Shape (shCvtSX, shrRank) +import Data.Array.Nested.Internal.Shape (shrRank) import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape import HordeAd.Core.HVectorOps @@ -480,6 +481,9 @@ data Delta :: Target -> TensorKindType -> Type where RFromS :: forall sh r target. (TensorKind1 r, KnownShS sh) => Delta target (TKS2 sh r) -> Delta target (TKR2 (Rank sh) r) + RFromX :: forall sh r target. (TensorKind1 r, KnownShX sh) + => Delta target (TKX2 sh r) + -> Delta target (TKR2 (Rank sh) r) RFromH :: (KnownNat n, GoodScalar r) => Delta target TKUntyped -> Int -> Delta target (TKR n r) @@ -577,18 +581,14 @@ data Delta :: Target -> TensorKindType -> Type where UnNestS :: (TensorKind1 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2)) => Delta target (TKS2 sh1 (TKS2 sh2 r)) -> Delta target (TKS2 (sh1 ++ sh2) r) - SFromR :: forall sh r target. (KnownShS sh, KnownNat (Rank sh), TensorKind1 r) + SFromR :: forall sh r target. + (KnownShS sh, KnownNat (Rank sh), TensorKind1 r) => Delta target (TKR2 (Rank sh) r) -> Delta target (TKS2 sh r) SFromX :: forall sh sh' r target. - ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh' - , KnownShX (Nested.MapJust sh), TensorKind1 r ) + (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r) => Delta target (TKX2 sh' r) -> Delta target (TKS2 sh r) - XFromS :: forall sh sh' r target. - (KnownShS sh, KnownShX sh', sh' ~ Nested.MapJust sh, TensorKind1 r) - => Delta target (TKS2 sh r) - -> Delta target (TKX2 sh' r) SFromH :: (KnownShS sh, GoodScalar r) => Delta target TKUntyped -> Int -> Delta target (TKS sh r) @@ -599,6 +599,12 @@ data Delta :: Target -> TensorKindType -> Type where FromVectorX :: (GoodScalar r, KnownShX sh, KnownNat n) => Data.Vector.Vector (Delta target (TKX sh r)) -> Delta target (TKX (Just n ': sh) r) + XFromR :: (KnownShX sh, TensorKind1 r, KnownNat (Rank sh)) + => Delta target (TKR2 (Rank sh) r) + -> Delta target (TKX2 sh r) + XFromS :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r) + => Delta target (TKS2 sh r) + -> Delta target (TKX2 sh' r) HToH :: HVector (Delta target) -> Delta target TKUntyped MapAccumR @@ -696,8 +702,11 @@ shapeDeltaFull = \case CastR d -> FTKR (shapeDelta d) FTKScalar RFromS @sh d | Dict <- lemKnownNatRankS (knownShS @sh) -> case shapeDeltaFull d of - FTKS _ x -> FTKR (listToShape $ shapeT @sh) x - RFromH d i -> FTKR (listToShape $ shapeVoidDynamic (shapeDeltaH d V.! i)) FTKScalar + FTKS _ x -> FTKR (fromList $ shapeT @sh) x + RFromX @sh d + | Dict <- lemKnownNatRankX (knownShX @sh) -> case shapeDeltaFull d of + FTKX shx x -> FTKR (fromList $ toList shx) x + RFromH d i -> FTKR (fromList $ shapeVoidDynamic (shapeDeltaH d V.! i)) FTKScalar IndexS d _ix -> case shapeDeltaFull d of FTKS _ x -> FTKS knownShS x @@ -735,8 +744,11 @@ shapeDeltaFull = \case FTKR _ x -> FTKS knownShS x SFromX d -> case shapeDeltaFull d of FTKX _ x -> FTKS knownShS x - XFromS @sh d -> case shapeDeltaFull d of - FTKS _ x -> FTKX (shCvtSX (knownShS @sh)) x + XFromR @sh d + | Dict <- lemKnownNatRankX (knownShX @sh) -> case shapeDeltaFull d of + FTKR shr x -> FTKX (fromList $ toList shr) x + XFromS d -> case shapeDeltaFull d of + FTKS sh x -> FTKX (fromList $ toList sh) x SFromH{} -> FTKS knownShS FTKScalar IndexX{} -> error "TODO" @@ -1198,6 +1210,8 @@ evalSame !s !c = \case RFromS (SFromR d) -> evalSame s c d -- no information lost, so no checks RFromS @sh d | Dict <- lemKnownNatRankS (knownShS @sh) -> evalSame s (sfromR c) d + RFromX @sh d | Dict <- lemKnownNatRankX (knownShX @sh) -> + evalSame s (xfromR c) d RFromH d i -> let cs = V.map dynamicFromVoid $ shapeDeltaH d ci = DynamicRanked c @@ -1262,13 +1276,12 @@ evalSame !s !c = \case case sameShape @sh @sh2 of Just Refl -> evalSame s c d _ -> error "evalSame: different shapes in SFromX(XFromS)" - SFromX d -> case shapeDeltaFull d of - FTKX sh' _ -> case shxEqual (xshape (xfromS c)) sh' of - Just Refl -> evalSame s (xfromS c) d - Nothing -> error "evalSame: wrong shapes in SFromX" + SFromX d -> + evalSame s (xfromS c) d -- impossible, shapes may differ: XFromS (SFromX d) -> evalSame s c d + XFromR @sh d | Dict <- lemKnownNatRankX (knownShX @sh) -> + evalSame s (rfromX c) d XFromS @sh d -> - gcastWith (unsafeCoerce Refl :: Rank sh :~: Rank (Nested.MapJust sh)) $ evalSame s (sfromX c) d SFromH d i -> let cs = V.map dynamicFromVoid $ shapeDeltaH d @@ -1584,6 +1597,7 @@ fwdSame params s = \case RFromS (SFromR d) -> fwdSame params s d -- no information lost, so no checks RFromS d -> second rfromS $ fwdSame params s d + RFromX d -> second rfromX $ fwdSame params s d RFromH d i -> let (s2, v) = fwdSame params s d in (s2, rfromD $ dunHVector v V.! i) @@ -1630,6 +1644,8 @@ fwdSame params s = \case Just Refl -> fwdSame params s d _ -> error "fwdSame: different shapes in SFromR(RFromS)" SFromR d -> second sfromR $ fwdSame params s d + XFromR @sh d | Dict <- lemKnownNatRankX (knownShX @sh) -> + second xfromR $ fwdSame params s d XFromS d -> second xfromS $ fwdSame params s d SFromX @sh (XFromS @sh2 d) -> case sameShape @sh @sh2 of diff --git a/src/HordeAd/Core/OpsADVal.hs b/src/HordeAd/Core/OpsADVal.hs index 21f1ba5c..e0217c01 100644 --- a/src/HordeAd/Core/OpsADVal.hs +++ b/src/HordeAd/Core/OpsADVal.hs @@ -26,7 +26,7 @@ import GHC.TypeLits (KnownNat, sameNat, type (+), type (<=)) import Type.Reflection (typeRep) import Data.Array.Mixed.Permutation qualified as Permutation -import Data.Array.Nested (IShR, KnownShS (..), KnownShX, Rank) +import Data.Array.Nested (IShR, KnownShS (..), KnownShX (..), Rank) import Data.Array.Nested qualified as Nested import Data.Array.Nested.Internal.Shape qualified as Nested.Internal.Shape @@ -322,6 +322,12 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target) => Delta target (TKS2 sh2 r2) -> Delta target (TKR2 (Rank sh2) r2) dRFromS (SFromR d) = d -- no information lost, so no checks dRFromS d = RFromS d + rfromX (D u u') = dDnotShared (rfromX u) (dRFromX u') + where + dRFromX :: (TensorKind1 r2, KnownShX sh2) + => Delta target (TKX2 sh2 r2) -> Delta target (TKR2 (Rank sh2) r2) + dRFromX (XFromR d) = d -- no information lost, so no checks + dRFromX d = RFromX d rtoScalar (D t d) = dDnotShared (rtoScalar t) (ToScalarG $ SFromR d) rfromScalar (D t d) = dDnotShared (rfromScalar t) (RFromS $ FromScalarG d) @@ -342,6 +348,11 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target) xprimalPart (D u _) = u xdualPart (D _ u') = u' xD t d = dD t d + xfromR :: forall sh r. (KnownShX sh, TensorKind1 r) + => ADVal target (TKR2 (Rank sh) r) -> ADVal target (TKX2 sh r) + xfromR (D u u') | Dict <- lemKnownNatRankX (knownShX @sh) = + dDnotShared (xfromR u) (XFromR u') + xfromS (D u u') = dDnotShared (xfromS u) (XFromS u') sminIndex (D u _) = let v = sminIndex u @@ -425,8 +436,7 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target) _ -> error "sfromR: different shapes in SFromR(RFromS)" dSFromR d = SFromR d sfromX :: forall r sh sh'. - ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh' - , KnownShX (Nested.MapJust sh), TensorKind1 r ) + ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r ) => ADVal target (TKX2 sh' r) -> ADVal target (TKS2 sh r) sfromX (D u u') = dDnotShared (sfromX u) (dSFromX u') where @@ -435,7 +445,6 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target) Just Refl -> d _ -> error "sfromR: different shapes in SFromR(RFromS)" dSFromX d = SFromX d - xfromS (D u u') = dDnotShared (xfromS u) (XFromS u') stoScalar (D t d) = dDnotShared (stoScalar t) (ToScalarG d) sfromScalar (D t d) = dDnotShared (sfromScalar t) (FromScalarG d) diff --git a/src/HordeAd/Core/OpsAst.hs b/src/HordeAd/Core/OpsAst.hs index 2a1d9f15..32f14df1 100644 --- a/src/HordeAd/Core/OpsAst.hs +++ b/src/HordeAd/Core/OpsAst.hs @@ -374,6 +374,7 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where rcast = astCastR rfromIntegral = fromPrimal . astFromIntegralR . astSpanPrimal rfromS = astRFromS + rfromX = astRFromX rtoScalar = AstToScalar . AstSFromR rfromScalar = AstRFromS . AstFromScalar @@ -384,7 +385,7 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where rScale s t = astDualPart $ AstFromPrimal s * AstD (rzero (rshape s)) t xshape t = case ftkAst t of - FTKX sh FTKScalar -> sh + FTKX sh _ -> sh xindex v ix = AstIndexX v ix xfromVector = AstFromVectorX xreplicate = AstReplicate SNat @@ -394,6 +395,8 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where xprimalPart = astSpanPrimal xdualPart = astSpanDual xD u u' = astSpanD u u' + xfromR = astXFromR + xfromS = astXFromS sminIndex = fromPrimal . AstMinIndexS . astSpanPrimal smaxIndex = fromPrimal . AstMaxIndexS . astSpanPrimal @@ -429,7 +432,6 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where sunNest = astUnNestS sfromR = astSFromR sfromX = astSFromX - xfromS = astXFromS stoScalar = AstToScalar sfromScalar = AstFromScalar @@ -639,6 +641,7 @@ instance AstSpan s => BaseTensor (AstRaw s) where rfromIntegral = AstRaw . fromPrimal . AstFromIntegralR . astSpanPrimalRaw . unAstRaw rfromS = AstRaw . AstRFromS . unAstRaw + rfromX = AstRaw . AstRFromX . unAstRaw rtoScalar = AstRaw . AstToScalar . AstSFromR . unAstRaw rfromScalar = AstRaw . AstRFromS . AstFromScalar . unAstRaw @@ -651,7 +654,7 @@ instance AstSpan s => BaseTensor (AstRaw s) where * AstD (unAstRaw $ rzero (rshape s)) t xshape t = case ftkAst $ unAstRaw t of - FTKX sh FTKScalar -> sh + FTKX sh _ -> sh xindex v ix = AstRaw $ AstIndexX (unAstRaw v) (unAstRaw <$> ix) xfromVector = AstRaw . AstFromVectorX . V.map unAstRaw @@ -662,6 +665,8 @@ instance AstSpan s => BaseTensor (AstRaw s) where xprimalPart = AstRaw . astSpanPrimalRaw . unAstRaw xdualPart = astSpanDualRaw . unAstRaw xD u u' = AstRaw $ astSpanD (unAstRaw u) u' + xfromR = AstRaw . AstXFromR . unAstRaw + xfromS = AstRaw . AstXFromS . unAstRaw sminIndex = AstRaw . fromPrimal . AstMinIndexS . astSpanPrimalRaw . unAstRaw smaxIndex = AstRaw . fromPrimal . AstMaxIndexS . astSpanPrimalRaw . unAstRaw @@ -696,7 +701,6 @@ instance AstSpan s => BaseTensor (AstRaw s) where sunNest = AstRaw . AstUnNestS . unAstRaw sfromR = AstRaw . AstSFromR . unAstRaw sfromX = AstRaw . AstSFromX . unAstRaw - xfromS = AstRaw . AstXFromS . unAstRaw stoScalar = AstRaw . AstToScalar . unAstRaw sfromScalar = AstRaw . AstFromScalar . unAstRaw @@ -879,6 +883,7 @@ instance AstSpan s => BaseTensor (AstNoVectorize s) where rcast = AstNoVectorize . rcast . unAstNoVectorize rfromIntegral = AstNoVectorize . rfromIntegral . unAstNoVectorize rfromS = AstNoVectorize . rfromS . unAstNoVectorize + rfromX = AstNoVectorize . rfromX . unAstNoVectorize rtoScalar = AstNoVectorize . rtoScalar . unAstNoVectorize rfromScalar = AstNoVectorize . rfromScalar . unAstNoVectorize @@ -889,7 +894,7 @@ instance AstSpan s => BaseTensor (AstNoVectorize s) where rScale s t = rScale @(AstTensor AstMethodLet PrimalSpan) (unAstNoVectorize s) t xshape t = case ftkAst $ unAstNoVectorize t of - FTKX sh FTKScalar -> sh + FTKX sh _ -> sh xindex v ix = AstNoVectorize $ xindex (unAstNoVectorize v) (unAstNoVectorize <$> ix) xfromVector = AstNoVectorize . xfromVector . V.map unAstNoVectorize @@ -901,6 +906,8 @@ instance AstSpan s => BaseTensor (AstNoVectorize s) where xdualPart = xdualPart . unAstNoVectorize xD u u' = AstNoVectorize $ xD (unAstNoVectorize u) u' + xfromR = AstNoVectorize . xfromR . unAstNoVectorize + xfromS = AstNoVectorize . xfromS . unAstNoVectorize sminIndex = AstNoVectorize . sminIndex . unAstNoVectorize smaxIndex = AstNoVectorize . smaxIndex . unAstNoVectorize @@ -936,7 +943,6 @@ instance AstSpan s => BaseTensor (AstNoVectorize s) where sunNest = AstNoVectorize . astUnNestS . unAstNoVectorize sfromR = AstNoVectorize . sfromR . unAstNoVectorize sfromX = AstNoVectorize . sfromX . unAstNoVectorize - xfromS = AstNoVectorize . xfromS . unAstNoVectorize stoScalar = AstNoVectorize . stoScalar . unAstNoVectorize sfromScalar = AstNoVectorize . sfromScalar . unAstNoVectorize @@ -1112,6 +1118,7 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where rfromIntegral = AstNoSimplify . fromPrimal . AstFromIntegralR . astSpanPrimal . unAstNoSimplify rfromS = AstNoSimplify . AstRFromS . unAstNoSimplify + rfromX = AstNoSimplify . AstRFromX . unAstNoSimplify rtoScalar = AstNoSimplify . AstToScalar . AstSFromR . unAstNoSimplify rfromScalar = AstNoSimplify . AstRFromS . AstFromScalar . unAstNoSimplify @@ -1124,7 +1131,7 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where * AstD (rzero (rshape s)) t xshape t = case ftkAst $ unAstNoSimplify t of - FTKX sh FTKScalar -> sh + FTKX sh _ -> sh xindex v ix = AstNoSimplify $ AstIndexX (unAstNoSimplify v) (unAstNoSimplify <$> ix) xfromVector = AstNoSimplify . AstFromVectorX . V.map unAstNoSimplify @@ -1135,6 +1142,8 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where xprimalPart = AstNoSimplify . astSpanPrimal . unAstNoSimplify xdualPart = astSpanDual . unAstNoSimplify xD u u' = AstNoSimplify $ astSpanD (unAstNoSimplify u) u' + xfromR = AstNoSimplify . AstXFromR . unAstNoSimplify + xfromS = AstNoSimplify . AstXFromS . unAstNoSimplify sminIndex = AstNoSimplify . fromPrimal . AstMinIndexS . astSpanPrimal . unAstNoSimplify @@ -1177,7 +1186,6 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where sunNest = AstNoSimplify . AstUnNestS . unAstNoSimplify sfromR = AstNoSimplify . AstSFromR . unAstNoSimplify sfromX = AstNoSimplify . AstSFromX . unAstNoSimplify - xfromS = AstNoSimplify . AstXFromS . unAstNoSimplify stoScalar = AstNoSimplify . AstToScalar . unAstNoSimplify sfromScalar = AstNoSimplify . AstFromScalar . unAstNoSimplify diff --git a/src/HordeAd/Core/OpsConcrete.hs b/src/HordeAd/Core/OpsConcrete.hs index fabd7ac2..a64e0b3a 100644 --- a/src/HordeAd/Core/OpsConcrete.hs +++ b/src/HordeAd/Core/OpsConcrete.hs @@ -19,7 +19,7 @@ import System.Random import Type.Reflection (typeRep) import Unsafe.Coerce (unsafeCoerce) -import Data.Array.Nested (KnownShS (..), Rank) +import Data.Array.Nested (KnownShS (..), KnownShX (..), Rank) import Data.Array.Nested qualified as Nested import HordeAd.Core.Adaptor @@ -157,6 +157,13 @@ instance BaseTensor RepN where xprimalPart = id xdualPart _ = DummyDualTarget xD u _ = u + xfromR :: forall sh r. (KnownShX sh, TensorKind1 r) + => RepN (TKR2 (Rank sh) r) -> RepN (TKX2 sh r) + xfromR = RepN . Nested.rcastToMixed (knownShX @sh) . unRepN + xfromS :: forall sh sh' r. + (KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r) + => RepN (TKS2 sh r) -> RepN (TKX2 sh' r) + xfromS = RepN . Nested.scastToMixed (knownShX @sh') . unRepN sminIndex = RepN . tminIndexS . unRepN smaxIndex = RepN . tmaxIndexS . unRepN @@ -232,7 +239,6 @@ instance BaseTensor RepN where sunNest t = RepN $ Nested.sunNest $ unRepN t sfromR = RepN . flip Nested.rcastToShaped knownShS . unRepN sfromX = RepN . flip Nested.mcastToShaped knownShS . unRepN - xfromS = RepN . Nested.stoMixed. unRepN stoScalar = RepN . Nested.sunScalar . unRepN sfromScalar = RepN . Nested.sscalar . unRepN diff --git a/src/HordeAd/Core/TensorClass.hs b/src/HordeAd/Core/TensorClass.hs index 7fab57dc..a1c87c4d 100644 --- a/src/HordeAd/Core/TensorClass.hs +++ b/src/HordeAd/Core/TensorClass.hs @@ -426,6 +426,8 @@ class ( Num (IntOf target) -> target (TKProduct (TKR2 n y) (TKR2 n z)) rfromS :: (TensorKind1 r, KnownShS sh) => target (TKS2 sh r) -> target (TKR2 (Rank sh) r) + rfromX :: (TensorKind1 r, KnownShX sh) + => target (TKX2 sh r) -> target (TKR2 (Rank sh) r) rtoScalar :: GoodScalar r => target (TKR 0 r) -> target (TKScalar r) rfromScalar :: GoodScalar r => target (TKScalar r) -> target (TKR 0 r) -- Prevents wrong shape in @0@ with ranked (but not shaped) tensors @@ -500,7 +502,7 @@ class ( Num (IntOf target) => IShX sh -> target (TKX sh r) xzero sh = xrepl sh 0 xfromPrimal :: (GoodScalar r, KnownShX sh) - => PrimalOf target (TKX sh r) -> target (TKX sh r) + => PrimalOf target (TKX sh r) -> target (TKX sh r) xprimalPart :: (GoodScalar r, KnownShX sh) => target (TKX sh r) -> PrimalOf target (TKX sh r) xdualPart :: (GoodScalar r, KnownShX sh) @@ -508,6 +510,10 @@ class ( Num (IntOf target) xD :: (GoodScalar r, KnownShX sh) => PrimalOf target (TKX sh r)-> DualOf target (TKX sh r) -> target (TKX sh r) + xfromR :: (KnownShX sh, KnownNat (Rank sh), TensorKind1 r) + => target (TKR2 (Rank sh) r) -> target (TKX2 sh r) + xfromS :: (KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r) + => target (TKS2 sh r) -> target (TKX2 sh' r) -- Integer codomain sshape :: forall sh r. (TensorKind2 r, KnownShS sh) @@ -872,11 +878,8 @@ class ( Num (IntOf target) -> target (TKProduct (TKS2 sh y) (TKS2 sh z)) sfromR :: (TensorKind1 r, KnownShS sh, KnownNat (Rank sh)) => target (TKR2 (Rank sh) r) -> target (TKS2 sh r) - sfromX :: ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh' - , KnownShX (Nested.MapJust sh), TensorKind1 r ) + sfromX :: ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh', TensorKind1 r ) => target (TKX2 sh' r) -> target (TKS2 sh r) - xfromS :: (KnownShS sh, KnownShX sh', sh' ~ Nested.MapJust sh, TensorKind1 r) - => target (TKS2 sh r) -> target (TKX2 sh' r) stoScalar :: GoodScalar r => target (TKS '[] r) -> target (TKScalar r) sfromScalar :: GoodScalar r => target (TKScalar r) -> target (TKS '[] r) diff --git a/src/HordeAd/Core/TensorKind.hs b/src/HordeAd/Core/TensorKind.hs index c9a0420f..4653b233 100644 --- a/src/HordeAd/Core/TensorKind.hs +++ b/src/HordeAd/Core/TensorKind.hs @@ -221,11 +221,11 @@ type role FullTensorKind nominal data FullTensorKind y where FTKScalar :: GoodScalar r => FullTensorKind (TKScalar r) FTKR :: forall n x. Nested.Elt (RepORArray x) - => IShR n -> FullTensorKind x -> FullTensorKind (TKR2 n x) + => IShR n -> FullTensorKind x -> FullTensorKind (TKR2 n x) FTKS :: forall sh x. Nested.Elt (RepORArray x) - => ShS sh -> FullTensorKind x -> FullTensorKind (TKS2 sh x) + => ShS sh -> FullTensorKind x -> FullTensorKind (TKS2 sh x) FTKX :: forall sh x. Nested.Elt (RepORArray x) - => IShX sh -> FullTensorKind x -> FullTensorKind (TKX2 sh x) + => IShX sh -> FullTensorKind x -> FullTensorKind (TKX2 sh x) FTKProduct :: (Nested.Elt (RepORArray y), Nested.Elt (RepORArray z)) => FullTensorKind y -> FullTensorKind z -> FullTensorKind (TKProduct y z) diff --git a/src/HordeAd/Core/Types.hs b/src/HordeAd/Core/Types.hs index 7481202d..b81b55f0 100644 --- a/src/HordeAd/Core/Types.hs +++ b/src/HordeAd/Core/Types.hs @@ -10,7 +10,7 @@ module HordeAd.Core.Types , withKnownShS, withKnownShX , sshapeKnown, slistKnown, sixKnown, knownShR , shapeT, shapeP, sizeT, sizeP - , withShapeP, sameShape, matchingRank, lemKnownNatRankS + , withShapeP, sameShape, matchingRank, lemKnownNatRankS, lemKnownNatRankX , Dict(..), PermC, trustMeThisIsAPermutation , Take, Drop, Last, Init -- * Kinds of the functors that determine the structure of a tensor type @@ -57,7 +57,16 @@ import Data.Array.Mixed.Permutation qualified as Permutation import Data.Array.Mixed.Shape (withKnownShX) import Data.Array.Mixed.Types (Dict (..)) import Data.Array.Nested - (IxR, IxS (..), IxX, KnownShS (..), ListS (..), Rank, ShR (..), ShS (..)) + ( IxR + , IxS (..) + , IxX + , KnownShS (..) + , ListS (..) + , Rank + , ShR (..) + , ShS (..) + , StaticShX (..) + ) import Data.Array.Nested qualified as Nested import Data.Array.Nested.Internal.Mixed qualified as Nested.Internal.Mixed import Data.Array.Nested.Internal.Shape (shsToList, withKnownShS) @@ -136,6 +145,10 @@ lemKnownNatRankS :: ShS sh -> Dict KnownNat (Rank sh) lemKnownNatRankS ZSS = Dict lemKnownNatRankS (_ :$$ sh) | Dict <- lemKnownNatRankS sh = Dict +lemKnownNatRankX :: StaticShX sh -> Dict KnownNat (Rank sh) +lemKnownNatRankX ZKX = Dict +lemKnownNatRankX (_ :!% sh) | Dict <- lemKnownNatRankX sh = Dict + class Permutation.IsPermutation is => PermC is instance Permutation.IsPermutation is => PermC is