Skip to content

Commit

Permalink
Generalize shareIx
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Nov 30, 2024
1 parent 8e24e80 commit a2acdfb
Showing 1 changed file with 19 additions and 22 deletions.
41 changes: 19 additions & 22 deletions src/HordeAd/Core/AstSimplify.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import Data.Strict.Vector qualified as Data.Vector
import Data.Type.Equality (gcastWith, testEquality, (:~:) (Refl))
import Data.Type.Ord (Compare)
import Data.Vector.Generic qualified as V
import GHC.Exts (IsList (..))
import GHC.TypeLits
( KnownNat
, Nat
Expand Down Expand Up @@ -644,9 +645,9 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |
Ast.AstDualPart{} -> Ast.AstIndexS v0 ix
Ast.AstFromPrimal v -> Ast.AstFromPrimal $ astIndex v ix
Ast.AstD u u' ->
shareIxS ix $ \ !ix2 -> Ast.AstD (astIndexRec u ix2) (astIndexRec u' ix2)
shareIx ix $ \ !ix2 -> Ast.AstD (astIndexRec u ix2) (astIndexRec u' ix2)
Ast.AstCond b v w ->
shareIxS ix $ \ !ix2 -> astCond b (astIndexRec v ix2) (astIndexRec w ix2)
shareIx ix $ \ !ix2 -> astCond b (astIndexRec v ix2) (astIndexRec w ix2)
Ast.AstReplicate @y2 _snat v -> case stensorKind @y2 of
STKS sh _ -> withKnownShS sh $ astIndex v rest1
Ast.AstBuild1 @y2 _snat (var2, v) -> case stensorKind @y2 of
Expand Down Expand Up @@ -694,20 +695,20 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |
-- sfromIntegral . sfromPrimal . sfromR . rfromScalar $ interpretAstPrimal env i
Ast.AstIotaS -> Ast.AstIndexS v0 ix
AstN1S opCode u ->
shareIxS ix $ \ !ix2 -> AstN1S opCode (astIndexRec u ix2)
shareIx ix $ \ !ix2 -> AstN1S opCode (astIndexRec u ix2)
AstN2S opCode u v ->
shareIxS ix $ \ !ix2 -> AstN2S opCode (astIndexRec u ix2) (astIndexRec v ix2)
shareIx ix $ \ !ix2 -> AstN2S opCode (astIndexRec u ix2) (astIndexRec v ix2)
Ast.AstR1S opCode u ->
shareIxS ix
shareIx ix
$ \ !ix2 -> Ast.AstR1S opCode (astIndexRec u ix2)
Ast.AstR2S opCode u v ->
shareIxS ix
shareIx ix
$ \ !ix2 -> Ast.AstR2S opCode (astIndexRec u ix2) (astIndexRec v ix2)
Ast.AstI2S opCode u v ->
shareIxS ix
shareIx ix
$ \ !ix2 -> Ast.AstI2S opCode (astIndexRec u ix2) (astIndexRec v ix2)
AstSumOfListS args ->
shareIxS ix $ \ !ix2 -> astSumOfListS (map (`astIndexRec` ix2) args)
shareIx ix $ \ !ix2 -> astSumOfListS (map (`astIndexRec` ix2) args)
Ast.AstIndexS v (ix2 :: AstIxS AstMethodLet sh4) ->
gcastWith (unsafeCoerce Refl
:: (sh4 ++ shm) ++ shn :~: sh4 ++ (shm ++ shn)) $
Expand Down Expand Up @@ -759,7 +760,7 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |
Ast.AstFromVectorS{} | ZIS <- rest1 -> -- normal form
Ast.AstIndexS v0 ix
Ast.AstFromVectorS l ->
shareIxS rest1 $ \ !ix2 ->
shareIx rest1 $ \ !ix2 ->
Ast.AstIndexS @'[in1] @shn (astFromVectorS $ V.map (`astIndexRec` ix2) l)
(ShapedList.singletonIndex i1)
Ast.AstAppendS @_ @m u v ->
Expand Down Expand Up @@ -838,27 +839,23 @@ astIndexKnobsS knobs v0 ix@((:.$) @in1 i1 (rest1 :: AstIxS AstMethodLet shm1)) |

-- TODO: compared to tletIx, it adds many lets, not one, but does not
-- create other (and non-simplified!) big terms and also uses astIsSmall,
-- so it's probably more efficient. Use this instead of tletIx/sletIx
-- so it's probably more efficient. Use this instead of tletIx
-- or design something even better.
shareIx :: (KnownNat n, GoodScalar r, KnownNat m)
=> AstIxR AstMethodLet n -> (AstIxR AstMethodLet n -> AstTensor AstMethodLet s (TKR m r))
-> AstTensor AstMethodLet s (TKR m r)
shareIx :: (TensorKind y, IsList indexType, Item indexType ~ AstInt AstMethodLet)
=> indexType
-> (indexType -> AstTensor AstMethodLet s y)
-> AstTensor AstMethodLet s y
{-# NOINLINE shareIx #-}
shareIx ix f = unsafePerformIO $ do
let shareI :: AstInt AstMethodLet -> IO (Maybe (IntVarName, AstInt AstMethodLet), AstInt AstMethodLet)
let shareI :: AstInt AstMethodLet
-> IO (Maybe (IntVarName, AstInt AstMethodLet), AstInt AstMethodLet)
shareI i | astIsSmall True i = return (Nothing, i)
shareI i = funToAstIntVarIO $ \ (!varFresh, !astVarFresh) ->
(Just (varFresh, i), astVarFresh)
(bindings, ix2) <- mapAndUnzipM shareI (indexToList ix)
return $! foldr (uncurry Ast.AstLet) (f $ listToIndex ix2)
(bindings, ix2) <- mapAndUnzipM shareI (toList ix)
return $! foldr (uncurry Ast.AstLet) (f $ fromList ix2)
(catMaybes bindings)

shareIxS :: -- (KnownShS shn, KnownShS shm)
AstIxS AstMethodLet shn -> (AstIxS AstMethodLet shn -> AstTensor AstMethodLet s (TKS shm r))
-> AstTensor AstMethodLet s (TKS shm r)
{-# NOINLINE shareIxS #-}
shareIxS ix f = f ix -- TODO

astGatherR
:: forall m n p s r.
(KnownNat m, KnownNat p, KnownNat n, GoodScalar r, AstSpan s)
Expand Down

0 comments on commit a2acdfb

Please sign in to comment.