Skip to content

Commit

Permalink
Generalize rbuild and sbuild to nested arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 2, 2024
1 parent 5b2b30f commit 17808ef
Show file tree
Hide file tree
Showing 14 changed files with 235 additions and 189 deletions.
2 changes: 1 addition & 1 deletion src/HordeAd/Core/AstInterpret.hs
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ interpretAst !env = \case
$ gcastWith (unsafeCoerce Refl :: sh2 :~: sh2 ++ Drop p sh)
-- transitivity of type equality doesn't work, by design,
-- so this direct cast is needed instead of more basic laws
$ sbuild @target @r @(Rank sh2)
$ sbuild @target @(TKScalar r) @(Rank sh2)
(interpretLambdaIndexS
interpretAst env
(vars, fromPrimal @s $ AstFromIntegralS $ AstFromScalar i))
Expand Down
9 changes: 4 additions & 5 deletions src/HordeAd/Core/CarriersADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ instance OrdF f => OrdF (ADVal f) where
D u _ >. D v _ = u >. v
D u _ >=. D v _ = u >=. v

indexPrimal :: ( ADReadyNoLet target
, KnownNat m, KnownNat n, TensorKind2 r )
indexPrimal :: ( ADReadyNoLet target, TensorKind2 r
, KnownNat m, KnownNat n )
=> ADVal target (TKR2 (m + n) r) -> IxROf target m
-> ADVal target (TKR2 n r)
indexPrimal (D u u') ix = dD (rindex u ix) (IndexR u' ix)
Expand Down Expand Up @@ -338,9 +338,8 @@ instance ( ADReadyNoLet target
(fromList [ifF b 0 1])
_ -> error "TODO"

indexPrimalS :: ( ADReadyNoLet target
, TensorKind2 r, KnownShS sh1, KnownShS sh2
, KnownShS (sh1 ++ sh2) )
indexPrimalS :: ( ADReadyNoLet target, TensorKind2 r
, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2) )
=> ADVal target (TKS2 (sh1 ++ sh2) r) -> IxSOf target sh1
-> ADVal target (TKS2 sh2 r)
indexPrimalS (D u u') ix = dD (sindex u ix) (IndexS u' ix)
Expand Down
4 changes: 3 additions & 1 deletion src/HordeAd/Core/CarriersConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ import HordeAd.Core.Types
-- (also in sum1Inner and extremum and maybe tdot0R):
-- LA.sumElements $ OI.toUnorderedVectorT sh t

type TensorKind2 y = (TensorKind y, Default (RepORArray y), Nested.KnownElt (RepORArray y), Show (RepORArray y))
type TensorKind2 y =
( TensorKind y, Default (RepORArray y), Nested.KnownElt (RepORArray y)
, Show (RepORArray y), Num (RepORArray (ADTensorKind y)) )

type family RepORArray (y :: TensorKindType) where
RepORArray (TKScalar r) = r
Expand Down
2 changes: 1 addition & 1 deletion src/HordeAd/Core/Delta.hs
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ data Delta :: Target -> TensorKindType -> Type where
RFromH :: (KnownNat n, GoodScalar r)
=> Delta target TKUntyped -> Int -> Delta target (TKR n r)

IndexS :: (KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2), TensorKind2 r)
IndexS :: (TensorKind2 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> Delta target (TKS2 (sh1 ++ sh2) r)
-> IxSOf target sh1
-> Delta target (TKS2 sh2 r)
Expand Down
24 changes: 14 additions & 10 deletions src/HordeAd/Core/OpsADVal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,14 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
rreshape sh t@(D u u') = case sameNat (Proxy @m) (Proxy @n) of
Just Refl | sh == rshape u -> t
_ -> dD (rreshape sh u) (ReshapeR sh u')
rbuild1 :: forall r n. (GoodScalar r, KnownNat n)
=> Int -> (IntOf (ADVal target) -> ADVal target (TKR n r))
-> ADVal target (TKR (1 + n) r)
rbuild1 :: forall r n. (TensorKind2 r, KnownNat n)
=> Int -> (IntOf (ADVal target) -> ADVal target (TKR2 n r))
-> ADVal target (TKR2 (1 + n) r)
rbuild1 0 _ = case sameNat (Proxy @n) (Proxy @0) of
Just Refl -> rconcrete Nested.remptyArray
-- the only case where we can guess sh
Nothing -> error "rbuild1: shape ambiguity"
Just Refl -> case stensorKind @r of
STKScalar{} -> rconcrete Nested.remptyArray
_ -> error "rbuild1: empty nested array"
Nothing -> error "rbuild1: shape ambiguity"
rbuild1 k f = rfromList $ NonEmpty.map (f . fromIntegral)
$ (0 :: Int) :| [1 .. k - 1]
-- element-wise (POPL) version
Expand Down Expand Up @@ -394,11 +395,14 @@ instance (ADReadyNoLet target, ShareTensor target, ShareTensor (PrimalOf target)
sreshape t@(D u u') = case sameShape @sh2 @sh of
Just Refl -> t
_ -> dD (sreshape u) (ReshapeS u')
sbuild1 :: forall r n sh. (GoodScalar r, KnownNat n, KnownShS sh)
=> (IntOf (ADVal target) -> ADVal target (TKS sh r))
-> ADVal target (TKS (n ': sh) r)
sbuild1 :: forall r n sh. (TensorKind2 r, KnownNat n, KnownShS sh)
=> (IntOf (ADVal target) -> ADVal target (TKS2 sh r))
-> ADVal target (TKS2 (n ': sh) r)
sbuild1 f = case sameNat (Proxy @n) (Proxy @0) of
Just Refl -> sconcrete $ Nested.semptyArray (knownShS @sh)
Just Refl -> case stensorKind @r of
STKScalar{} ->
sconcrete $ Nested.semptyArray (knownShS @sh)
_ -> error "sbuild1: empty nested array"
Nothing -> sfromList $ NonEmpty.map (f . fromIntegral)
$ (0 :: Int) :| [1 .. valueOf @n - 1]
-- element-wise (POPL) version
Expand Down
25 changes: 13 additions & 12 deletions src/HordeAd/Core/OpsAst.hs
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,10 @@ instance AstSpan s => BaseTensor (AstTensor AstMethodLet s) where
sreverse = astReverseS
stranspose perm = astTransposeS perm
sreshape = astReshapeS
sbuild1 :: forall r n sh. (GoodScalar r, KnownNat n, KnownShS sh)
=> (IntOf (AstTensor AstMethodLet s) -> AstTensor AstMethodLet s (TKS sh r))
-> AstTensor AstMethodLet s (TKS (n ': sh) r)
sbuild1 :: forall r n sh. (TensorKind2 r, KnownNat n, KnownShS sh)
=> (IntOf (AstTensor AstMethodLet s)
-> AstTensor AstMethodLet s (TKS2 sh r))
-> AstTensor AstMethodLet s (TKS2 (n ': sh) r)
sbuild1 f =
astBuild1Vectorize (SNat @n) f
sgather t f = astGatherStepS t
Expand Down Expand Up @@ -679,9 +680,9 @@ instance AstSpan s => BaseTensor (AstRaw s) where
sreverse = AstRaw . AstReverseS . unAstRaw
stranspose perm = AstRaw . AstTransposeS perm . unAstRaw
sreshape = AstRaw . AstReshapeS . unAstRaw
sbuild1 :: forall r n sh. (GoodScalar r, KnownNat n, KnownShS sh)
=> (IntOf (AstRaw s) -> AstRaw s (TKS sh r))
-> AstRaw s (TKS (n ': sh) r)
sbuild1 :: forall r n sh. (TensorKind2 r, KnownNat n, KnownShS sh)
=> (IntOf (AstRaw s) -> AstRaw s (TKS2 sh r))
-> AstRaw s (TKS2 (n ': sh) r)
sbuild1 f = AstRaw $ AstBuild1 (SNat @n)
$ funToAstI -- this introduces new variable names
$ unAstRaw . f . AstRaw
Expand Down Expand Up @@ -921,9 +922,9 @@ instance AstSpan s => BaseTensor (AstNoVectorize s) where
stranspose perm =
AstNoVectorize . stranspose perm . unAstNoVectorize
sreshape = AstNoVectorize . sreshape . unAstNoVectorize
sbuild1 :: forall r n sh. (GoodScalar r, KnownNat n, KnownShS sh)
=> (IntOf (AstNoVectorize s) -> AstNoVectorize s (TKS sh r))
-> AstNoVectorize s (TKS (n ': sh) r)
sbuild1 :: forall r n sh. (TensorKind2 r, KnownNat n, KnownShS sh)
=> (IntOf (AstNoVectorize s) -> AstNoVectorize s (TKS2 sh r))
-> AstNoVectorize s (TKS2 (n ': sh) r)
sbuild1 f = AstNoVectorize $ AstBuild1 (SNat @n)
$ funToAstI -- this introduces new variable names
$ unAstNoVectorize . f . AstNoVectorize
Expand Down Expand Up @@ -1160,9 +1161,9 @@ instance AstSpan s => BaseTensor (AstNoSimplify s) where
stranspose perm =
AstNoSimplify . AstTransposeS perm . unAstNoSimplify
sreshape = AstNoSimplify . AstReshapeS . unAstNoSimplify
sbuild1 :: forall r n sh. (GoodScalar r, KnownNat n, KnownShS sh)
=> (IntOf (AstNoSimplify s) -> AstNoSimplify s (TKS sh r))
-> AstNoSimplify s (TKS (n ': sh) r)
sbuild1 :: forall r n sh. (TensorKind2 r, KnownNat n, KnownShS sh)
=> (IntOf (AstNoSimplify s) -> AstNoSimplify s (TKS2 sh r))
-> AstNoSimplify s (TKS2 (n ': sh) r)
sbuild1 f =
AstNoSimplify
$ astBuild1Vectorize (SNat @n) (unAstNoSimplify . f . AstNoSimplify)
Expand Down
53 changes: 44 additions & 9 deletions src/HordeAd/Core/OpsConcrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ import Data.Function ((&))
import Data.List (foldl', mapAccumL, mapAccumR, scanl')
import Data.List.NonEmpty qualified as NonEmpty
import Data.Proxy (Proxy (Proxy))
import Data.Type.Equality ((:~:) (Refl))
import Data.Type.Equality (gcastWith, (:~:) (Refl))
import Data.Vector.Generic qualified as V
import GHC.TypeLits (KnownNat)
import System.Random
import Unsafe.Coerce (unsafeCoerce)

import Data.Array.Nested (KnownShS (..), Rank)
import Data.Array.Nested qualified as Nested
Expand Down Expand Up @@ -105,10 +106,24 @@ instance BaseTensor RepN where
rtranspose perm = RepN . ttransposeR perm . unRepN
rreshape sh = RepN . treshapeR sh . unRepN
rbuild1 k f = RepN $ tbuild1R k (unRepN . f . RepN)
rmap0N f t = RepN $ tmap0NR (unRepN . f . RepN) (unRepN t)
rzipWith0N f t u =
RepN $ tzipWith0NR (\v w -> unRepN $ f (RepN v) (RepN w))
(unRepN t) (unRepN u)
rmap0N :: forall r r1 n target.
(target ~ RepN, TensorKind2 r, TensorKind2 r1, KnownNat n)
=> (target (TKR2 0 r1) -> target (TKR2 0 r)) -> target (TKR2 n r1)
-> target (TKR2 n r)
rmap0N f t = case (stensorKind @r1, stensorKind @r) of
(STKScalar{}, STKScalar{}) -> RepN $ tmap0NR (unRepN . f . RepN) (unRepN t)
_ -> -- TODO: how to call the default implementation?
rbuild (rshape t) (f . rindex0 t)
rzipWith0N :: forall r1 r2 r n target.
(target ~ RepN, TensorKind2 r1, TensorKind2 r2, TensorKind2 r, KnownNat n)
=> (target (TKR2 0 r1) -> target (TKR2 0 r2) -> target (TKR2 0 r))
-> target (TKR2 n r1) -> target (TKR2 n r2) -> target (TKR2 n r)
rzipWith0N f t u = case (stensorKind @r1, stensorKind @r2, stensorKind @r) of
(STKScalar{}, STKScalar{}, STKScalar{}) ->
RepN $ tzipWith0NR (\v w -> unRepN $ f (RepN v) (RepN w))
(unRepN t) (unRepN u)
_ -> -- TODO: how to call the default implementation?
rbuild (rshape u) (\ix -> f (rindex0 t ix) (rindex0 u ix))
rgather sh t f = RepN $ tgatherZR sh (unRepN t)
(fmap unRepN . f . fmap RepN)
rgather1 k t f = RepN $ tgatherZ1R k (unRepN t)
Expand Down Expand Up @@ -182,10 +197,30 @@ instance BaseTensor RepN where
stranspose perm = RepN . ttransposeS perm . unRepN
sreshape = RepN . treshapeS . unRepN
sbuild1 f = RepN $ tbuild1S (unRepN . f . RepN)
smap0N f t = RepN $ tmap0NS (unRepN . f . RepN) (unRepN t)
szipWith0N f t u =
RepN $ tzipWith0NS (\v w -> unRepN $ f (RepN v) (RepN w))
(unRepN t) (unRepN u)
smap0N :: forall r1 r sh target.
(target ~ RepN, TensorKind2 r1, TensorKind2 r, KnownShS sh)
=> (target (TKS2 '[] r1) -> target (TKS2 '[] r)) -> target (TKS2 sh r1)
-> target (TKS2 sh r)
smap0N f v = case (stensorKind @r1, stensorKind @r) of
(STKScalar{}, STKScalar{}) ->
RepN $ tmap0NS (unRepN . f . RepN) (unRepN v)
_ -> -- TODO: how to call the default implementation?
gcastWith (unsafeCoerce Refl :: Drop (Rank sh) sh :~: '[])
$ gcastWith (unsafeCoerce Refl :: Take (Rank sh) sh :~: sh)
$ sbuild @target @r @(Rank sh) (f . sindex0 v)
szipWith0N :: forall r1 r2 r sh target.
( target ~ RepN, TensorKind2 r1, TensorKind2 r2, TensorKind2 r
, KnownShS sh )
=> (target (TKS2 '[] r1) -> target (TKS2 '[] r2) -> target (TKS2 '[] r))
-> target (TKS2 sh r1) -> target (TKS2 sh r2) -> target (TKS2 sh r)
szipWith0N f t u = case (stensorKind @r1, stensorKind @r2, stensorKind @r) of
(STKScalar{}, STKScalar{}, STKScalar{}) ->
RepN $ tzipWith0NS (\v w -> unRepN $ f (RepN v) (RepN w))
(unRepN t) (unRepN u)
_ -> -- TODO: how to call the default implementation?
gcastWith (unsafeCoerce Refl :: Drop (Rank sh) sh :~: '[])
$ gcastWith (unsafeCoerce Refl :: Take (Rank sh) sh :~: sh)
$ sbuild @target @_ @(Rank sh) (\ix -> f (sindex0 t ix) (sindex0 u ix))
sgather t f = RepN $ tgatherZS (unRepN t)
(fmap unRepN . f . fmap RepN)
sgather1 t f = RepN $ tgatherZ1S (unRepN t)
Expand Down
Loading

0 comments on commit 17808ef

Please sign in to comment.