feat: support conditionals #185

Merged: 28 commits, Nov 1, 2024

Conversation

avik-pal
Collaborator

@avik-pal avik-pal commented Oct 19, 2024

using Reactant

function conditional_fn(x)
    v = sum(x)
    ff = -1.0
    @trace if v > 0
        z = 1.0
        v2 = sum(x)
        x = x .+ ff
    # elseif v > -2.0
    #     z = 2.0
    # elseif v > -4.0
    #     z = 3.0
    else
        z = 0.0
        v2 = -1.0
    end
    return z
end

conditional_fn(rand(2, 3))

TODOs

  • general variable tracking that can be used for loops later on
  • rewrite the julia expression into something that can be traced
  • tracing overload for ControlFlow.traced_if
  • Can we rewrite it to an expression without a performance penalty? The current rewrite leads to boxing.
  • support elseif blocks
  • different variables defined in each block
    • nice to have but additionally helps support blocks without an else
  • tests
  • move into ReactantCore?
  • CI scripts to install ReactantCore
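
Editorial sketch for the "rewrite the julia expression into something that can be traced" item: one way a macro can turn an `if ... else ... end` expression into a call that receives both branches as closures. The names `@sketch_trace` and `sketch_if` are hypothetical and are not Reactant's implementation (the PR itself mentions `ControlFlow.traced_if`, whose signature is not shown here). The sketch returns the branch value instead of assigning variables inside the closures, because assigning to captured variables from a closure is what makes Julia box them.

```julia
# Hypothetical names: `sketch_if` and `@sketch_trace` are for illustration only.

# Fallback for an ordinary Bool condition: run only the selected branch.
# A tracing backend could add a method for its own condition type that
# evaluates both closures and emits a conditional op instead.
sketch_if(cond::Bool, true_fn, false_fn) = cond ? true_fn() : false_fn()

macro sketch_trace(expr)
    expr isa Expr && expr.head == :if ||
        error("this sketch only handles `if` expressions")
    length(expr.args) == 3 && expr.args[3] isa Expr && expr.args[3].head == :block ||
        error("this sketch requires an `else` block and no `elseif`")
    cond, true_block, false_block = expr.args
    # Wrap each branch in a zero-argument closure and hand both to `sketch_if`.
    return esc(:(sketch_if($cond, () -> $true_block, () -> $false_block)))
end

function abs_sum(x)
    v = sum(x)
    v2 = @sketch_trace if v > 0
        v
    else
        -v
    end
    return v2
end

abs_sum([1.0, -3.0])  # the sum is negative, so the else branch runs
```

Supporting assignments inside the branches, as in the PR's example, would additionally require collecting the assigned variables and returning them from each closure, which this sketch leaves out.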

@avik-pal

This comment was marked as outdated.

Contributor

@github-actions github-actions bot left a comment

Reactant.jl Benchmarks

Benchmark suite Current: 9ab07db Previous: babeb7c Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1255240093 ns 1418797944 ns 0.88
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1248946910 ns 1230657063 ns 1.01
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1214828214 ns 1210055514 ns 1.00
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2293717092 ns 2321453182 ns 0.99
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 217679848 ns 215031968 ns 1.01
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 6955740554 ns 5458708327 ns 1.27
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5427479279 ns 5179301625 ns 1.05
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5156546580 ns 5152065959 ns 1.00
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 6875151231 ns 6914653384 ns 0.99
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 31688836759 ns 29634509034 ns 1.07
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1238682669 ns 1303933391 ns 0.95
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1228539126 ns 1288941570.5 ns 0.95
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1185066846.5 ns 1246884488 ns 0.95
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2478959292 ns 2588209027 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8594471 ns 8825930 ns 0.97
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1626866559 ns 1637260762 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1656544289 ns 1607338067 ns 1.03
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1694371308 ns 1592753746 ns 1.06
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 2819361701 ns 2888392716 ns 0.98
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2489135988 ns 2959415354 ns 0.84
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1330343963 ns 1320589513 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1344203917 ns 1232647002.5 ns 1.09
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1330861556.5 ns 1233197730.5 ns 1.08
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2415936657 ns 2510219663 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 22524604 ns 22686905 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2168532504 ns 2195365921 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2169097763 ns 2173148463 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2154109260 ns 2160517237 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3420971865 ns 3389252115 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 7540417154.5 ns 5458754250.5 ns 1.38
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1364552764 ns 1336344147 ns 1.02
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1246917061.5 ns 1284165465.5 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1220403778 ns 1264413606 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2559847580 ns 2388755659 ns 1.07
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7050323 ns 7116389 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1436903713.5 ns 1494584041 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1422707192 ns 1490742502 ns 0.95
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1424497991 ns 1473980569 ns 0.97
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 2668906211 ns 2796807816 ns 0.95
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1243275774 ns 1669183460 ns 0.74
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1234161204.5 ns 1220367359.5 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1284924628.5 ns 1264274640.5 ns 1.02
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1253249271 ns 1345724410.5 ns 0.93
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2420194943 ns 2566724316 ns 0.94
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 12120447 ns 12278807 ns 0.99
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1727430271 ns 1777269725 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1717646566 ns 1763977334 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1711914668 ns 1773537556 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 2921070609 ns 3105794746 ns 0.94
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3123095202 ns 3076042064.5 ns 1.02
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1273894038 ns 1271346742 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1289592679.5 ns 1246562750 ns 1.03
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1246482766 ns 1309043330 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2533335173 ns 2442642621 ns 1.04
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 27297511 ns 27314834 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2175906741 ns 2242544865 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2170417538 ns 2216501128 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2179733438 ns 2196805969 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3392091960 ns 3556647163 ns 0.95
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 5811384324 ns 5559034960 ns 1.05
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1215072309 ns 1242324757 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1211762639.5 ns 1298352031 ns 0.93
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1237794082 ns 1230035861 ns 1.01
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2487677602 ns 2637986128 ns 0.94
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 52668259.5 ns 52652664 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 3014647673 ns 3060315768 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3034207527 ns 3106069884 ns 0.98
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 3029060826 ns 3053865991 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4315985968 ns 4567618226 ns 0.94
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 12286351744 ns 9483960261 ns 1.30
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1211908977 ns 1231844420 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1209609157 ns 1232961844 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1314913978 ns 1246778659.5 ns 1.05
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2497022835 ns 2387803469 ns 1.05
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 70873199 ns 70768943 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3143712998 ns 3264057828 ns 0.96
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3158645011 ns 3289278023 ns 0.96
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3185501103 ns 3264052473 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 4493364988 ns 4733831239 ns 0.95
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 11851265767 ns 10856363466 ns 1.09
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1179042431 ns 1204313355 ns 0.98
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1199915158 ns 1195727515.5 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1192754860 ns 1220121107.5 ns 0.98
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2336617348 ns 2407152921 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 20590451 ns 20638923 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1882994217 ns 1946630937 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1873533211 ns 1945684183 ns 0.96
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1818163076 ns 1942898142 ns 0.94
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3080332047 ns 3272416211 ns 0.94
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3464042438.5 ns 3630039016.5 ns 0.95

This comment was automatically generated by workflow using github-action-benchmark.

Contributor

github-actions bot commented Oct 19, 2024

Benchmark Results

main 03aa1e1... main/03aa1e18c78859...
comptime/NN/ViT base (optimize = :after_enzyme) 7.76 s 7.1 s 1.09
comptime/NN/ViT base (optimize = :all) 7.24 s 7.05 s 1.03
comptime/NN/ViT base (optimize = :before_enzyme) 7.16 s 6.83 s 1.05
comptime/NN/ViT base (optimize = :only_enzyme) 7.93 s 7.16 s 1.11
comptime/NN/ViT tiny (optimize = :after_enzyme) 6.4 s 6.21 s 1.03
comptime/NN/ViT tiny (optimize = :all) 6.81 s 6.4 s 1.06
comptime/NN/ViT tiny (optimize = :before_enzyme) 6.29 s 6.35 s 0.991
comptime/NN/ViT tiny (optimize = :only_enzyme) 6.45 s 6.34 s 1.02
comptime/NN/vgg11 bn=false (optimize = :after_enzyme) 0.455 ± 0.043 s 0.419 ± 0.036 s 1.09
comptime/NN/vgg11 bn=false (optimize = :all) 0.444 ± 0.026 s 0.425 ± 0.036 s 1.05
comptime/NN/vgg11 bn=false (optimize = :before_enzyme) 0.447 ± 0.033 s 0.461 ± 0.05 s 0.968
comptime/NN/vgg11 bn=false (optimize = :only_enzyme) 0.412 ± 0.014 s 0.407 ± 0.052 s 1.01
comptime/NN/vgg11 bn=true (optimize = :after_enzyme) 1.2 ± 0.034 s 1.11 ± 0.0083 s 1.08
comptime/NN/vgg11 bn=true (optimize = :all) 1.13 ± 0.026 s 1.13 ± 0.027 s 1
comptime/NN/vgg11 bn=true (optimize = :before_enzyme) 1.15 ± 0.029 s 1.17 ± 0.19 s 0.988
comptime/NN/vgg11 bn=true (optimize = :only_enzyme) 1.09 ± 0.039 s 1.15 ± 0.0069 s 0.946
comptime/NN/vgg13 bn=false (optimize = :after_enzyme) 0.55 ± 0.045 s 0.486 ± 0.053 s 1.13
comptime/NN/vgg13 bn=false (optimize = :all) 0.508 ± 0.029 s 0.485 ± 0.03 s 1.05
comptime/NN/vgg13 bn=false (optimize = :before_enzyme) 0.535 ± 0.036 s 0.512 ± 0.015 s 1.05
comptime/NN/vgg13 bn=false (optimize = :only_enzyme) 0.479 ± 0.0099 s 0.489 ± 0.029 s 0.978
comptime/NN/vgg13 bn=true (optimize = :after_enzyme) 1.35 ± 0.065 s 1.31 ± 0.0046 s 1.03
comptime/NN/vgg13 bn=true (optimize = :all) 1.32 ± 0.0075 s 1.34 ± 0.023 s 0.987
comptime/NN/vgg13 bn=true (optimize = :before_enzyme) 1.37 ± 0.0013 s 1.46 ± 0.029 s 0.937
comptime/NN/vgg13 bn=true (optimize = :only_enzyme) 1.33 ± 0.029 s 1.31 ± 0.0091 s 1.02
comptime/NN/vgg16 bn=false (optimize = :after_enzyme) 0.545 ± 0.016 s 0.558 ± 0.0087 s 0.977
comptime/NN/vgg16 bn=false (optimize = :all) 0.617 ± 0.058 s 0.571 ± 0.0081 s 1.08
comptime/NN/vgg16 bn=false (optimize = :before_enzyme) 0.555 ± 0.0081 s 0.578 ± 0.018 s 0.96
comptime/NN/vgg16 bn=false (optimize = :only_enzyme) 0.587 ± 0.026 s 0.569 ± 0.043 s 1.03
comptime/NN/vgg16 bn=true (optimize = :after_enzyme) 1.68 ± 0.055 s 1.67 ± 0.025 s 1
comptime/NN/vgg16 bn=true (optimize = :all) 1.67 ± 0.0026 s 1.71 ± 0.012 s 0.976
comptime/NN/vgg16 bn=true (optimize = :before_enzyme) 1.77 ± 0.00089 s 1.8 ± 0.024 s 0.985
comptime/NN/vgg16 bn=true (optimize = :only_enzyme) 1.92 ± 0.016 s 1.72 ± 0.063 s 1.11
comptime/NN/vgg19 bn=false (optimize = :after_enzyme) 0.666 ± 0.024 s 0.635 ± 0.053 s 1.05
comptime/NN/vgg19 bn=false (optimize = :all) 0.786 ± 0.031 s 0.655 ± 0.0053 s 1.2
comptime/NN/vgg19 bn=false (optimize = :before_enzyme) 0.652 ± 0.0082 s 0.646 ± 0.0057 s 1.01
comptime/NN/vgg19 bn=false (optimize = :only_enzyme) 0.642 ± 0.036 s 0.65 ± 0.02 s 0.988
comptime/NN/vgg19 bn=true (optimize = :after_enzyme) 2.04 ± 0.0022 s 2.01 ± 0.017 s 1.01
comptime/NN/vgg19 bn=true (optimize = :all) 2.21 ± 0.0079 s 2.06 ± 0.041 s 1.07
comptime/NN/vgg19 bn=true (optimize = :before_enzyme) 2.23 ± 0.0077 s 2.13 ± 0.025 s 1.05
comptime/NN/vgg19 bn=true (optimize = :only_enzyme) 2 s 2.02 ± 0.011 s 0.991
comptime/basics/2D sum (optimize = :after_enzyme) 28 ± 1 ms 27.4 ± 1.2 ms 1.02
comptime/basics/2D sum (optimize = :all) 0.032 ± 0.0041 s 0.0317 ± 0.0027 s 1.01
comptime/basics/2D sum (optimize = :before_enzyme) 30.4 ± 1.3 ms 29.3 ± 1.1 ms 1.04
comptime/basics/2D sum (optimize = :only_enzyme) 23.5 ± 0.76 ms 24.5 ± 2.1 ms 0.959
comptime/basics/cos.(x) (optimize = :after_enzyme) 0.033 ± 0.0009 s 0.0332 ± 0.002 s 0.994
comptime/basics/cos.(x) (optimize = :all) 0.0367 ± 0.00073 s 0.0364 ± 0.0014 s 1.01
comptime/basics/cos.(x) (optimize = :before_enzyme) 0.0369 ± 0.0032 s 0.0355 ± 0.0039 s 1.04
comptime/basics/cos.(x) (optimize = :only_enzyme) 31.2 ± 2.7 ms 0.0325 ± 0.0034 s 0.961
comptime/basics/∇cos (optimize = :all) 0.0519 ± 0.0016 s 0.0528 ± 0.0036 s 0.984
runtime/NN/ViT base (optimize = :after_enzyme) 6.33 s 6.42 s 0.986
runtime/NN/ViT base (optimize = :all) 6.31 s 6.3 s 1
runtime/NN/ViT base (optimize = :before_enzyme) 6.3 s 6.35 s 0.992
runtime/NN/ViT base (optimize = :only_enzyme) 7.54 s 7.65 s 0.986
runtime/NN/ViT tiny (optimize = :after_enzyme) 1.68 s 1.65 s 1.02
runtime/NN/ViT tiny (optimize = :all) 1.72 s 1.69 s 1.02
runtime/NN/ViT tiny (optimize = :before_enzyme) 1.65 s 1.61 s 1.03
runtime/NN/ViT tiny (optimize = :only_enzyme) 2.67 s 2.63 s 1.02
runtime/NN/vgg11 bn=false (optimize = :after_enzyme) 2.2 s 2.13 s 1.04
runtime/NN/vgg11 bn=false (optimize = :all) 2.17 s 2.19 s 0.99
runtime/NN/vgg11 bn=false (optimize = :before_enzyme) 2.15 s 2.18 s 0.986
runtime/NN/vgg11 bn=false (optimize = :only_enzyme) 1.99 s 2.01 s 0.992
runtime/NN/vgg11 bn=true (optimize = :after_enzyme) 2.36 s 2.34 s 1.01
runtime/NN/vgg11 bn=true (optimize = :all) 2.3 s 2.32 s 0.99
runtime/NN/vgg11 bn=true (optimize = :before_enzyme) 2.34 s 2.33 s 1.01
runtime/NN/vgg11 bn=true (optimize = :only_enzyme) 2.39 s 2.44 s 0.979
runtime/NN/vgg13 bn=false (optimize = :after_enzyme) 3.03 s 2.95 s 1.03
runtime/NN/vgg13 bn=false (optimize = :all) 3.17 s 3.02 s 1.05
runtime/NN/vgg13 bn=false (optimize = :before_enzyme) 3.1 s 2.99 s 1.04
runtime/NN/vgg13 bn=false (optimize = :only_enzyme) 2.9 s 2.87 s 1.01
runtime/NN/vgg13 bn=true (optimize = :after_enzyme) 3.26 s 3.31 s 0.985
runtime/NN/vgg13 bn=true (optimize = :all) 3.31 s 3.3 s 1
runtime/NN/vgg13 bn=true (optimize = :before_enzyme) 3.17 s 3.22 s 0.984
runtime/NN/vgg13 bn=true (optimize = :only_enzyme) 3.47 s 3.51 s 0.988
runtime/NN/vgg16 bn=false (optimize = :after_enzyme) 3.82 s 3.82 s 0.998
runtime/NN/vgg16 bn=false (optimize = :all) 3.91 s 3.85 s 1.01
runtime/NN/vgg16 bn=false (optimize = :before_enzyme) 3.74 s 3.79 s 0.986
runtime/NN/vgg16 bn=false (optimize = :only_enzyme) 3.75 s 3.77 s 0.995
runtime/NN/vgg16 bn=true (optimize = :after_enzyme) 4.15 s 4.22 s 0.983
runtime/NN/vgg16 bn=true (optimize = :all) 4.19 s 4.13 s 1.01
runtime/NN/vgg16 bn=true (optimize = :before_enzyme) 4.24 s 4.22 s 1.01
runtime/NN/vgg16 bn=true (optimize = :only_enzyme) 4.62 s 4.39 s 1.05
runtime/NN/vgg19 bn=false (optimize = :after_enzyme) 4.72 s 4.66 s 1.01
runtime/NN/vgg19 bn=false (optimize = :all) 4.71 s 4.6 s 1.02
runtime/NN/vgg19 bn=false (optimize = :before_enzyme) 4.61 s 4.51 s 1.02
runtime/NN/vgg19 bn=false (optimize = :only_enzyme) 4.49 s 4.48 s 1
runtime/NN/vgg19 bn=true (optimize = :after_enzyme) 5.09 s 5.11 s 0.996
runtime/NN/vgg19 bn=true (optimize = :all) 5.22 s 5.18 s 1.01
runtime/NN/vgg19 bn=true (optimize = :before_enzyme) 5.21 s 5.34 s 0.977
runtime/NN/vgg19 bn=true (optimize = :only_enzyme) 5.64 s 5.58 s 1.01
time_to_load 1.95 ± 0.043 s 2 ± 0.065 s 0.974

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

@avik-pal
Collaborator Author

using Reactant

function conditional_fn(x)
    v = sum(x)
    ff = -1.0
    @trace if v > 0
        z = 1.0
        v2 = sum(x)
        # x = x .+ ff
    # elseif v > -2.0
    #     z = 2.0
    # elseif v > -4.0
    #     z = 3.0
    else
        z = 0.0
        v2 = -sum(x)
    end
    return z, v2
end

conditional_fn(rand(2, 3))

x = rand(2, 3)

x_ra = Reactant.to_rarray(x)

conditional_fn_compiled = @compile conditional_fn(x_ra)
Error
free(): invalid pointer

[970911] signal 6 (-6): Aborted
in expression starting at REPL[8]:1
unknown function (ip: 0x7b2c897013f4)
gsignal at /usr/lib/libc.so.6 (unknown line)
abort at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x7b2c89690353)
unknown function (ip: 0x7b2c8970b764)
unknown function (ip: 0x7b2c8970dc4b)
__libc_free at /usr/lib/libc.so.6 (unknown line)
_ZN4mlir14OperationStateD1Ev at /home/avikpal/.julia/artifacts/683e7a1671039feec0b9faea91812e8b4e08ded5/lib/libReactantExtra.so (unknown line)
mlirOperationCreate at /home/avikpal/.julia/artifacts/683e7a1671039feec0b9faea91812e8b4e08ded5/lib/libReactantExtra.so (unknown line)
mlirOperationCreate at /mnt/software/lux/Reactant.jl/src/mlir/libMLIR_h.jl:985 [inlined]
#create_operation#56 at /mnt/software/lux/Reactant.jl/src/mlir/IR/Operation.jl:315
create_operation at /mnt/software/lux/Reactant.jl/src/mlir/IR/Operation.jl:273 [inlined]
#if_#53 at /mnt/software/lux/Reactant.jl/src/mlir/Dialects/StableHLO.jl:2287
if_ at /mnt/software/lux/Reactant.jl/src/mlir/Dialects/StableHLO.jl:2274 [inlined]
traced_if at /mnt/software/lux/Reactant.jl/src/ControlFlow.jl:127
unknown function (ip: 0x7b2c57f34986)
macro expansion at /mnt/software/lux/Reactant.jl/src/ControlFlow.jl:70 [inlined]
conditional_fn at ./REPL[2]:4 [inlined]
opaque closure at ./<missing>:0
unknown function (ip: 0x7b2c57f33832)
#26 at /mnt/software/lux/Reactant.jl/src/utils.jl:113
block! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
unknown function (ip: 0x7b2c6dbd1d96)
#make_mlir_fn#20 at /mnt/software/lux/Reactant.jl/src/utils.jl:81
make_mlir_fn at /mnt/software/lux/Reactant.jl/src/utils.jl:30 [inlined]
#6 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:260 [inlined]
block! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
#5 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:259 [inlined]
mmodule! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:93
unknown function (ip: 0x7b2c6db5dcf6)
#compile_mlir!#4 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:256
compile_mlir! at /mnt/software/lux/Reactant.jl/src/Compiler.jl:255 [inlined]
#30 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:636
context! at /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:71
unknown function (ip: 0x7b2c6db56756)
#compile_xla#29 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:633
compile_xla at /mnt/software/lux/Reactant.jl/src/Compiler.jl:627 [inlined]
#compile#34 at /mnt/software/lux/Reactant.jl/src/Compiler.jl:660
compile at /mnt/software/lux/Reactant.jl/src/Compiler.jl:659
unknown function (ip: 0x7b2c6db4f5ed)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
do_call at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:126
eval_value at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:663
jl_interpret_toplevel_thunk at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:886
eval_body at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:625
eval_body at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:539
jl_interpret_toplevel_thunk at /cache/build/builder-demeter6-6/julialang/julia-master/src/interpreter.c:821
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/builder-demeter6-6/julialang/julia-master/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:245
repl_backend_loop at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:342
#start_repl_backend#59 at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:327
start_repl_backend at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:324
#run_repl#72 at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:483
run_repl at /cache/build/builder-demeter6-6/julialang/julia-master/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:469
jfptr_run_repl_10088 at /home/avikpal/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/compiled/v1.11/REPL/u0gqU_GYsA8.so (unknown line)
#1139 at ./client.jl:446
jfptr_YY.1139_14649 at /home/avikpal/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/share/julia/compiled/v1.11/REPL/u0gqU_GYsA8.so (unknown line)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/builder-demeter6-6/julialang/julia-master/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_72144.1 at /home/avikpal/.julia/juliaup/julia-1.11.1+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/builder-demeter6-6/julialang/julia-master/src/julia.h:2157 [inlined]
true_main at /cache/build/builder-demeter6-6/julialang/julia-master/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/builder-demeter6-6/julialang/julia-master/src/jlapi.c:1059
main at /cache/build/builder-demeter6-6/julialang/julia-master/cli/loader_exe.c:58
unknown function (ip: 0x7b2c89690e07)
__libc_start_main at /usr/lib/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 71349677 (Pool: 71347786; Big: 1891); GC: 47
[1]    970911 IOT instruction (core dumped)  julia --project=. --threads=12

Can I get some help with the MLIR operation part here? I am sure I am doing something illegal.

src/ControlFlow.jl (outdated)

macro trace(expr)
    expr.head == :if && return esc(trace_if(__module__, expr))
    return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
Member

We should make sure to test the full matrix: [bare `if ... end` / `if ... else ... end` / `if ... elseif ... end` / `if ... elseif ... else ... end`] x [returning different values from the `if` / overwriting a local value (e.g. `x = x + 2`) / updating values in place (e.g. `x[2] = 4`)].
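
Editorial sketch of the test matrix described above, written as plain Julia so the twelve combinations are explicit. The labels are illustrative only; this enumerates the cases and does not exercise `@trace`.

```julia
# Illustrative labels for the two axes of the requested test matrix.
forms = ["if-end", "if-else-end", "if-elseif-end", "if-elseif-else-end"]
effects = [
    "returns different values",
    "overwrites a local (x = x + 2)",
    "updates in place (x[2] = 4)",
]

# Cartesian product: every control-flow form paired with every kind of effect.
cases = vec([(form, effect) for form in forms, effect in effects])

for (form, effect) in cases
    println(rpad(form, 20), " x ", effect)
end
```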

@assert expr.head == :if
@assert length(expr.args) == 3 "`@trace` expects an `else` block for `if` blocks."
# XXX: support `elseif` blocks
@assert expr.args[3].head == :block "`elseif` blocks are not supported yet."
Member

well above tests for what's supported at least xD
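
Editorial note on what the assertions in the hunk above inspect: Julia parses the four `if` forms into `Expr` objects of different shapes, and the number and kind of arguments is what distinguishes them. The sketch below uses only Base Julia; the variable names are illustrative.

```julia
# Quote each form so we can inspect the parsed expression.
bare = :(if a
    1
end)

with_else = :(if a
    1
else
    2
end)

with_elseif = :(if a
    1
elseif b
    2
end)

full = :(if a
    1
elseif b
    2
else
    3
end)

# A bare `if` has two args: the condition and the body block.
println(length(bare.args))
# With `else`, the third arg is the else body, an Expr with head :block.
println(with_else.args[3].head)
# With `elseif`, the third arg is a nested Expr with head :elseif instead.
println(with_elseif.args[3].head)
# A trailing `else` after `elseif` lives inside that nested :elseif Expr.
println(length(full.args[3].args))
```

So `length(expr.args) == 3` rejects the bare form, and `expr.args[3].head == :block` rejects any form containing `elseif`.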

@wsmoses
Member

wsmoses commented Oct 20, 2024 via email

src/ControlFlow.jl (outdated)
@EnzymeAD EnzymeAD deleted a comment from Traced Oct 21, 2024
@avik-pal
Collaborator Author

using Reactant

function conditional_fn(x)
    v = sum(x)
    @trace if v > 0
        v2 = v
    else
        v2 = -v
    end
    return v2, x
end

conditional_fn(rand(2, 3))

x = rand(2, 3)

x_ra = Reactant.to_rarray(x)

conditional_fn_compiled = @compile conditional_fn(x_ra)
julia> @code_hlo optimize=false conditional_fn(x_ra)
Module:
"builtin.module"() ({
  "func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "identity_broadcast_scalar", sym_visibility = "private"}> ({
  ^bb0(%arg5: tensor<f64>):
    %17 = "stablehlo.transpose"(%arg5) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
    %18 = "stablehlo.transpose"(%17) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
    "func.return"(%18) : (tensor<f64>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "##true_branch#320", sym_visibility = "private"}> ({
  }) : () -> ()
  "func.func"() <{function_type = (tensor<f64>) -> (tensor<f64>, tensor<f64>), sym_name = "##false_branch#321", sym_visibility = "private"}> ({
  }) : () -> ()
  "func.func"() <{function_type = (tensor<3x2xf64>) -> (tensor<f64>, tensor<3x2xf64>), sym_name = "main"}> ({
  ^bb0(%arg0: tensor<3x2xf64>):
    %0 = "stablehlo.transpose"(%arg0) <{permutation = array<i64: 1, 0>}> : (tensor<3x2xf64>) -> tensor<2x3xf64>
    %1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64>
    %2 = "stablehlo.broadcast_in_dim"(%0) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %3 = "enzyme.batch"(%2) <{batch_shape = array<i64: 2, 3>, fn = @identity_broadcast_scalar}> : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %4 = "stablehlo.reduce"(%3, %1) <{dimensions = array<i64: 0, 1>}> ({
    ^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
      %16 = "stablehlo.add"(%arg3, %arg4) : (tensor<f64>, tensor<f64>) -> tensor<f64>
      "stablehlo.return"(%16) : (tensor<f64>) -> ()
    }) : (tensor<2x3xf64>, tensor<f64>) -> tensor<f64>
    %5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64>
    %6 = "stablehlo.compare"(%4, %5) <{comparison_direction = #stablehlo<comparison_direction GT>}> : (tensor<f64>, tensor<f64>) -> tensor<i1>
    %7 = "stablehlo.if"(%6) ({
    ^bb0(%arg2: tensor<f64>):
      %14 = "stablehlo.transpose"(%arg2) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      %15 = "stablehlo.transpose"(%14) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      "stablehlo.return"(%15) : (tensor<f64>) -> ()
    }, {
    ^bb0(%arg1: tensor<f64>):
      %10 = "stablehlo.transpose"(%arg1) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      %11 = "stablehlo.negate"(%10) : (tensor<f64>) -> tensor<f64>
      %12 = "stablehlo.transpose"(%11) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      %13 = "stablehlo.transpose"(%10) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      "stablehlo.return"(%12, %13) : (tensor<f64>, tensor<f64>) -> ()
    }) : (tensor<i1>) -> tensor<f64>
    %8 = "stablehlo.transpose"(%7) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
    %9 = "stablehlo.transpose"(%0) <{permutation = array<i64: 1, 0>}> : (tensor<2x3xf64>) -> tensor<3x2xf64>
    "func.return"(%8, %9) : (tensor<f64>, tensor<3x2xf64>) -> ()
  }) : () -> ()
}) : () -> ()

This doesn't work completely yet. The branches of stablehlo.if need to take zero arguments, so the generated code is currently incorrect.

@wsmoses
Member

wsmoses commented Oct 23, 2024

So this shouldn't be bad to fix, assuming that the true/false regions just accept the args from the parent. That is notably not how while works: it takes them as block arguments, like you have set up here. Essentially we need to replace all uses of the block arguments with the actual values in the parent (there should be a replace API call in mlir.jl, cc @mofeing) and then delete the now-unused block arguments.
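
Editorial sketch of the fix described above, on a toy IR in plain Julia. The `Value`, `Op`, and `Block` types and both functions are hypothetical stand-ins, not the MLIR.jl or Reactant API; they only show the two steps (replace uses of each block argument with the parent value, then erase the arguments).

```julia
# Toy stand-ins for IR concepts; none of this is MLIR.jl or Reactant API.
mutable struct Value      # mutable so that `===` compares identity, like SSA values
    name::String
end

struct Op
    name::String
    operands::Vector{Value}
    result::Union{Value,Nothing}
end

struct Block
    args::Vector{Value}
    ops::Vector{Op}
end

# Step 1: rewrite every operand that refers to `old` so it refers to `new`.
function replace_uses!(block::Block, old::Value, new::Value)
    for op in block.ops, i in eachindex(op.operands)
        if op.operands[i] === old
            op.operands[i] = new
        end
    end
    return block
end

# Step 2: after all uses are redirected, drop the now-unused block arguments.
function inline_parent_values!(block::Block, parent_values::Vector{Value})
    length(block.args) == length(parent_values) || error("arity mismatch")
    for (arg, parent) in zip(block.args, parent_values)
        replace_uses!(block, arg, parent)
    end
    empty!(block.args)
    return block
end

# Mirrors the shape of the false branch in the dump: negate the parent's value.
v = Value("%4")
arg = Value("%arg1")
neg = Value("%11")
false_branch = Block([arg], [
    Op("stablehlo.negate", [arg], neg),
    Op("stablehlo.return", [neg], nothing),
])

inline_parent_values!(false_branch, [v])
```

A real implementation would also have to walk nested regions, which this flat toy block does not model.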

@avik-pal
Collaborator Author

Module:
"builtin.module"() ({
  "func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "identity_broadcast_scalar", sym_visibility = "private"}> ({
  ^bb0(%arg5: tensor<f64>):
    %17 = "stablehlo.transpose"(%arg5) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
    %18 = "stablehlo.transpose"(%17) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
    "func.return"(%18) : (tensor<f64>) -> ()
  }) : () -> ()
  "func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "##true_branch#282", sym_visibility = "private"}> ({
  }) : () -> ()
  "func.func"() <{function_type = (tensor<f64>) -> (tensor<f64>, tensor<f64>), sym_name = "##false_branch#283", sym_visibility = "private"}> ({
  }) : () -> ()
  "func.func"() <{function_type = (tensor<3x2xf64>) -> (tensor<f64>, tensor<3x2xf64>), sym_name = "main"}> ({
  ^bb0(%arg0: tensor<3x2xf64>):
    %0 = "stablehlo.transpose"(%arg0) <{permutation = array<i64: 1, 0>}> : (tensor<3x2xf64>) -> tensor<2x3xf64>
    %1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64>
    %2 = "stablehlo.broadcast_in_dim"(%0) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %3 = "enzyme.batch"(%2) <{batch_shape = array<i64: 2, 3>, fn = @identity_broadcast_scalar}> : (tensor<2x3xf64>) -> tensor<2x3xf64>
    %4 = "stablehlo.reduce"(%3, %1) <{dimensions = array<i64: 0, 1>}> ({
    ^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
      %16 = "stablehlo.add"(%arg3, %arg4) : (tensor<f64>, tensor<f64>) -> tensor<f64>
      "stablehlo.return"(%16) : (tensor<f64>) -> ()
    }) : (tensor<2x3xf64>, tensor<f64>) -> tensor<f64>
    %5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64>
    %6 = "stablehlo.compare"(%4, %5) <{comparison_direction = #stablehlo<comparison_direction GT>}> : (tensor<f64>, tensor<f64>) -> tensor<i1>
    %7 = "stablehlo.if"(%6) ({
    ^bb0(%arg2: tensor<f64>):
      %14 = "stablehlo.transpose"(%arg2) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      %15 = "stablehlo.transpose"(%4) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      "stablehlo.return"(%15) : (tensor<f64>) -> ()
    }, {
    ^bb0(%arg1: tensor<f64>):
      %10 = "stablehlo.transpose"(%arg1) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      %11 = "stablehlo.negate"(%4) : (tensor<f64>) -> tensor<f64>
      %12 = "stablehlo.transpose"(%11) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      %13 = "stablehlo.transpose"(%4) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
      "stablehlo.return"(%12, %13) : (tensor<f64>, tensor<f64>) -> ()
    }) : (tensor<i1>) -> tensor<f64>
    %8 = "stablehlo.transpose"(%7) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
    %9 = "stablehlo.transpose"(%0) <{permutation = array<i64: 1, 0>}> : (tensor<2x3xf64>) -> tensor<3x2xf64>
    "func.return"(%8, %9) : (tensor<f64>, tensor<3x2xf64>) -> ()
  }) : () -> ()
}) : () -> ()

How do I delete the arguments? I don't see an equivalent of https://mlir.llvm.org/doxygen/classmlir_1_1Block.html#ab497ed8c6f52f6faa5e9a8ac0c6b1014 in Reactant.

@Pangoraw
Collaborator

How do I delete the arguments?

Yeah, it seems that mlirBlockEraseArgument is only available in LLVM >= v19. Could we try cloning the operations into a new block without arguments?

@mofeing
Collaborator

mofeing commented Oct 24, 2024

How do I delete the arguments?

yeah it seems that mlirBlockEraseArgument is only available in LLVM >= v19. We can try to clone operations in a new block without argument ?

Remember that Reactant uses LLVM 20-DEV currently, because it has its own LLVM build with XLA.

@Pangoraw
Collaborator

oh nice, then we need to regenerate the LibMLIR file to get mlirBlockEraseArgument

@avik-pal avik-pal force-pushed the ap/compile_conditionals branch 6 times, most recently from 03aa1e1 to feea37e Compare October 29, 2024 15:12
@avik-pal
Collaborator Author

avik-pal commented Nov 1, 2024

the setindex issue still needs fixing

@mofeing
Collaborator

mofeing commented Nov 1, 2024

the setindex issue still needs fixing

If we merge this before fixing that, would you mind opening an issue so that we can keep track of that please?

@avik-pal
Collaborator Author

avik-pal commented Nov 1, 2024

Opened #210 to track this; #211 tracks compiling closures.

@avik-pal avik-pal merged commit b6ee968 into main Nov 1, 2024
17 of 25 checks passed
@avik-pal avik-pal deleted the ap/compile_conditionals branch November 1, 2024 18:53