Skip to content

Commit

Permalink
When pretty-printing, insert explicit rscalar applications
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Nov 27, 2024
1 parent 35e1d70 commit c8b707e
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 229 deletions.
12 changes: 9 additions & 3 deletions src/HordeAd/Core/AstPrettyPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,15 @@ printAstAux cfg d = \case
. printAst cfg 11 v
AstToShare v -> printAstAux cfg d v -- ignored
AstConcrete FTKScalar a -> shows a
AstConcrete (FTKR ZSR FTKScalar) a -> shows $ Nested.runScalar $ unRepN a
AstConcrete (FTKS ZSS FTKScalar) a -> shows $ Nested.sunScalar $ unRepN a
AstConcrete (FTKX ZSX FTKScalar) a -> shows $ Nested.munScalar $ unRepN a
AstConcrete (FTKR ZSR FTKScalar) a -> showParen (d > 10)
$ showString "rscalar "
. shows (Nested.runScalar $ unRepN a)
AstConcrete (FTKS ZSS FTKScalar) a -> showParen (d > 10)
$ showString "sscalar "
. shows (Nested.sunScalar $ unRepN a)
AstConcrete (FTKX ZSX FTKScalar) a -> showParen (d > 10)
$ showString "xscalar "
. shows (Nested.munScalar $ unRepN a)
AstConcrete ftk a -> showParen (d > 10)
$ showString ("tconcrete (" ++ show ftk ++ ") ")
. (showParen True
Expand Down
10 changes: 5 additions & 5 deletions test/SimplifiedOnlyTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ tests :: TestTree
tests =
testGroup "Tests for simplified horde-ad"
[ testGroup "Short_tests"
(TestGatherSimplified.testTrees
++ TestHighRankSimplified.testTrees
(TestAdaptorSimplified.testTrees
++ TestConvSimplified.testTrees
++ TestAdaptorSimplified.testTrees
++ TestGatherSimplified.testTrees
++ TestHighRankSimplified.testTrees
++ TestRevFwdFold.testTrees)
, testGroup "Neural_network_tests"
(TestMnistFCNNR.testTrees
++ TestMnistCNNR.testTrees
(TestMnistCNNR.testTrees
++ TestMnistFCNNR.testTrees
++ TestMnistRNNR.testTrees
++ TestMnistRNNS.testTrees)
]
180 changes: 104 additions & 76 deletions test/simplified/TestAdaptorSimplified.hs

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions test/simplified/TestConvSimplified.hs

Large diffs are not rendered by default.

44 changes: 22 additions & 22 deletions test/simplified/TestGatherSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,14 @@ testGatherSimpPP23 = do
(t * rreplicate0N [6, 2] (rfromIndex0 i))))
$ AstVar (FTKR [6, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 217
length (show (simplifyInline @(TKR 3 Float) t1)) @?= 619
length (show (simplifyInline @(TKR 3 Float) t1)) @?= 695
resetVarCounter
let !t2 = (\t -> rbuild1 4 (\i ->
rreshape @(AstTensor AstMethodLet PrimalSpan) @Float @2 @2 [2, 6]
(t * rreplicate0N [6, 2] (rfromIndex0 i))))
$ AstVar (FTKR [6, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 217
length (show (simplifyInline @(TKR 3 Float) t2)) @?= 619
length (show (simplifyInline @(TKR 3 Float) t2)) @?= 695

-- Depending on if and how transpose it desugared, this may or may not result
-- in dozens of nested gathers that should vanish after simplification.
Expand Down Expand Up @@ -450,31 +450,31 @@ testGatherSimpPP33 = do
resetVarCounter
let !t1 = gatherTranspose33 @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 591
length (show (simplifyInline @(TKR 2 Float) t1)) @?= 591
length (show t1) @?= 614
length (show (simplifyInline @(TKR 2 Float) t1)) @?= 614
resetVarCounter
let !t2 = (\t -> rmatmul2 (rreshape [6, 8] (rconcrete $ unRepN t48))
(rreshape @(AstTensor AstMethodLet PrimalSpan) @Float @10 [8, 16] t))
$ AstVar (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 510
length (show (simplifyInline @(TKR 2 Float) t2)) @?= 510
length (show t2) @?= 533
length (show (simplifyInline @(TKR 2 Float) t2)) @?= 533

testGatherSimpPP34 :: Assertion
testGatherSimpPP34 = do
resetVarCounter
let !t1 = (\t -> rbuild1 4 (\i ->
gatherTranspose33 @(AstTensor AstMethodLet PrimalSpan) (t * rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rfromIndex0 i))))
$ AstVar (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 936
length (show (simplifyInline @(TKR 3 Float) t1)) @?= 936
length (show t1) @?= 959
length (show (simplifyInline @(TKR 3 Float) t1)) @?= 959
resetVarCounter
let !t2 = (\t -> rbuild1 4 (\i ->
(\t' -> rmatmul2 (rreshape [6, 8] (rconcrete $ unRepN t48))
(rreshape @(AstTensor AstMethodLet PrimalSpan) @Float @10 [8, 16] t'))
(t * rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rfromIndex0 i))))
$ AstVar (FTKR [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 689
length (show (simplifyInline @(TKR 3 Float) t2)) @?= 689
length (show t2) @?= 712
length (show (simplifyInline @(TKR 3 Float) t2)) @?= 712

-- scatters instead of gathers

Expand Down Expand Up @@ -534,12 +534,12 @@ testScatterSimpPP1 :: Assertion
testScatterSimpPP1 = do
resetVarCounter
let !t1 = scatterNested1 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (FTKR [7, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 371
length (show t1) @?= 390
resetVarCounter
let !t2 = scatter1 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (FTKR [7, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 556
length (show (simplifyInline @(TKR 1 Float) t1)) @?= 371
length (show (simplifyInline @(TKR 1 Float) t2)) @?= 556
length (show t2) @?= 632
length (show (simplifyInline @(TKR 1 Float) t1)) @?= 390
length (show (simplifyInline @(TKR 1 Float) t2)) @?= 632

scatterNested2 :: forall target r. (ADReady target, GoodScalar r)
=> target (TKR 2 r) -> target (TKR 2 r)
Expand Down Expand Up @@ -600,12 +600,12 @@ testScatterSimpPP2 :: Assertion
testScatterSimpPP2 = do
resetVarCounter
let !t1 = scatterNested2 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (FTKR [7, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 1470
length (show t1) @?= 1660
resetVarCounter
let !t2 = scatter2 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (FTKR [7, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 839
length (show (simplifyInline @(TKR 2 Float) t1)) @?= 1470
length (show (simplifyInline @(TKR 2 Float) t2)) @?= 839
length (show t2) @?= 915
length (show (simplifyInline @(TKR 2 Float) t1)) @?= 1660
length (show (simplifyInline @(TKR 2 Float) t2)) @?= 915

scatterNested12 :: forall target r. (ADReady target, GoodScalar r)
=> target (TKR 2 r) -> target (TKR 2 r)
Expand Down Expand Up @@ -668,9 +668,9 @@ testScatterSimpPP12 :: Assertion
testScatterSimpPP12 = do
resetVarCounter
let !t1 = scatterNested12 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (FTKR [7, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 1246
length (show t1) @?= 1398
resetVarCounter
let !t2 = scatter12 @(AstTensor AstMethodLet PrimalSpan) $ AstVar (FTKR [7, 2] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 839
length (show (simplifyInline @(TKR 2 Float) t1)) @?= 1246
length (show (simplifyInline @(TKR 2 Float) t2)) @?= 839
length (show t2) @?= 915
length (show (simplifyInline @(TKR 2 Float) t1)) @?= 1398
length (show (simplifyInline @(TKR 2 Float) t2)) @?= 915
29 changes: 0 additions & 29 deletions test/simplified/TestHighRankSimplified.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import Data.Array.Nested
import Data.Array.Nested qualified as Nested

import HordeAd
import HordeAd.Core.AstFreshId (resetVarCounter)
import HordeAd.Util.ShapedList qualified as ShapedList

import CrossTesting
Expand Down Expand Up @@ -66,8 +65,6 @@ testTrees =
, testCase "3barReluADValDt2" testBarReluADValDt2
, testCase "3barReluADVal" testBarReluADVal
, testCase "3barReluADVal3" testBarReluADVal3
, testCase "3reluSimpPP" testReluSimpPP
, testCase "3barReluADVal320" testBarReluADVal320
, testCase "3braidedBuilds" testBraidedBuilds
, testCase "3braidedBuilds1" testBraidedBuilds1
, testCase "3recycled" testRecycled
Expand Down Expand Up @@ -533,32 +530,6 @@ testBarReluADVal3 =
(rev' @Double @7 barRelu
(rmap0N (* (rscalar 0.001)) t48))

barRelu10xSlower
:: ( ADReady target, GoodScalar r, KnownNat n, Differentiable r )
=> target (TKR n r) -> target (TKR n r)
barRelu10xSlower x = let t = rmap0N (* rscalar 0.001) x
in relu $ bar (t, relu t)

testReluSimpPP :: Assertion
testReluSimpPP = do
resetVarCounter
let !t1 = barRelu10xSlower @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (FTKR [1,2,2,1,2,2,2,2,2,1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t1) @?= 15152
length (show (simplifyInline @(TKR 10 Float) t1)) @?= 15152
resetVarCounter
let !t2 = barRelu @(AstTensor AstMethodLet PrimalSpan)
$ AstVar (FTKR [1,2,2,1,2,2,2,2,2,1] FTKScalar) (mkAstVarName . intToAstVarId $ 100000000)
length (show t2) @?= 12144
length (show (simplifyInline @(TKR 10 Float) t2)) @?= 15152

testBarReluADVal320 :: Assertion
testBarReluADVal320 =
assertEqualUpToEpsilonShort 1e-10
(ringestData [1,2,2,1,2,2,2,2,2,1] [2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885034988157713e-4,2.885923176600045e-4,2.887454843457817e-4,2.886097295122454e-4,2.8846476339094805e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.8851415976532735e-4,2.885923176600045e-4,2.887454843457817e-4,2.8849246223035154e-4,2.884182085399516e-4,2.884075468755327e-4,2.8842176240868867e-4,2.8840399312321096e-4,0.0,2.887454843457817e-4,2.886097295122454e-4,2.887454843457817e-4,2.88599069218435e-4,2.887454843457817e-4,2.886097295122454e-4,2.8846476339094805e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885145151321922e-4,2.885145151321922e-4,2.8854294397024206e-4,2.8858878438222746e-4,2.885923176600045e-4,0.0,2.884007943794131e-4,0.0,2.884469945274759e-4,2.8843242392031246e-4,2.884288700806792e-4,0.0,2.885034988157713e-4,2.884110805753153e-4,0.0,2.8849283778617973e-4,2.884075468755327e-4,2.884075468755327e-4,2.884075468755327e-4,2.884075468755327e-4,0.0,0.0,0.0,0.0,2.884892851579934e-4,2.884892851579934e-4,2.884892851579934e-4,2.884892851579934e-4,0.0,0.0,0.0,0.0,2.884892851579934e-4,2.884892851579934e-4,2.884892851579934e-4,2.884892851579934e-4,2.8854294397024206e-4,2.884288700806792e-4,2.884395315486472e-4,0.0,2.8849246223035154e-4,2.8850276789489724e-4,0.0,2.8849212704517413e-4,2.8854294397024206e-4,2.884288700806792e-4,2.884395315486472e-4,0.0,2.8849246223035154e-4,2.8850276789489724e-4,0.0,2.8849212704517413e-4,2.8842922547482884e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885145151321922e-4,2.8854294397024206e-4,2.894378297730782e-4,2.885923176600045e-4,2.887454843457817e-4,2.88599069218435e-4,2.887454843457817e-4,2.887056688523444e-4,2.887454843457817e-4,2.887056688523444e-4,2.8846476339094805e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885145151321922e-4,2.8854294397024206e-4,2.884786229769816e-4,2.885923176600045e-4,2.887454843457817e-4,2.886950092188272e-4,2.887454843457817e-4,2.884818011261814e-4,2.887454843457817e-4,2.886097295122454e-4,2.8846476339094805e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885145151321922e-4,2.8854294397024206e-4,2.887167039107226e-4,2.885923176600045e-4,2.887454843457817e-4,2.8860262265516213e-4,2.887454843457817e-4,2.885884088500461e-4,2.887454843457817e-4,2.88599069218435e-4])
(rev' @Double @10 barRelu10xSlower
(rmap0N (* (rscalar 0.001)) t128))

braidedBuilds :: forall target n r. (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
=> target (TKR (1 + n) r) -> target (TKR 2 r)
braidedBuilds r =
Expand Down
Loading

0 comments on commit c8b707e

Please sign in to comment.