-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support conditionals #185
Conversation
7d1efd3
to
5f91fcd
Compare
This comment was marked as outdated.
This comment was marked as outdated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reactant.jl Benchmarks
Benchmark suite | Current: 9ab07db | Previous: babeb7c | Ratio |
---|---|---|---|
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1255240093 ns |
1418797944 ns |
0.88 |
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant |
1248946910 ns |
1230657063 ns |
1.01 |
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1214828214 ns |
1210055514 ns |
1.00 |
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2293717092 ns |
2321453182 ns |
0.99 |
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux |
217679848 ns |
215031968 ns |
1.01 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) |
6955740554 ns |
5458708327 ns |
1.27 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant |
5427479279 ns |
5179301625 ns |
1.05 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) |
5156546580 ns |
5152065959 ns |
1.00 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) |
6875151231 ns |
6914653384 ns |
0.99 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux |
31688836759 ns |
29634509034 ns |
1.07 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1238682669 ns |
1303933391 ns |
0.95 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant |
1228539126 ns |
1288941570.5 ns |
0.95 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1185066846.5 ns |
1246884488 ns |
0.95 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2478959292 ns |
2588209027 ns |
0.96 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux |
8594471 ns |
8825930 ns |
0.97 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) |
1626866559 ns |
1637260762 ns |
0.99 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant |
1656544289 ns |
1607338067 ns |
1.03 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) |
1694371308 ns |
1592753746 ns |
1.06 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) |
2819361701 ns |
2888392716 ns |
0.98 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux |
2489135988 ns |
2959415354 ns |
0.84 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1330343963 ns |
1320589513 ns |
1.01 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant |
1344203917 ns |
1232647002.5 ns |
1.09 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1330861556.5 ns |
1233197730.5 ns |
1.08 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2415936657 ns |
2510219663 ns |
0.96 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux |
22524604 ns |
22686905 ns |
0.99 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) |
2168532504 ns |
2195365921 ns |
0.99 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant |
2169097763 ns |
2173148463 ns |
1.00 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) |
2154109260 ns |
2160517237 ns |
1.00 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) |
3420971865 ns |
3389252115 ns |
1.01 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux |
7540417154.5 ns |
5458754250.5 ns |
1.38 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1364552764 ns |
1336344147 ns |
1.02 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant |
1246917061.5 ns |
1284165465.5 ns |
0.97 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1220403778 ns |
1264413606 ns |
0.97 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2559847580 ns |
2388755659 ns |
1.07 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux |
7050323 ns |
7116389 ns |
0.99 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) |
1436903713.5 ns |
1494584041 ns |
0.96 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant |
1422707192 ns |
1490742502 ns |
0.95 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) |
1424497991 ns |
1473980569 ns |
0.97 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) |
2668906211 ns |
2796807816 ns |
0.95 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux |
1243275774 ns |
1669183460 ns |
0.74 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1234161204.5 ns |
1220367359.5 ns |
1.01 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant |
1284924628.5 ns |
1264274640.5 ns |
1.02 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1253249271 ns |
1345724410.5 ns |
0.93 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2420194943 ns |
2566724316 ns |
0.94 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux |
12120447 ns |
12278807 ns |
0.99 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) |
1727430271 ns |
1777269725 ns |
0.97 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant |
1717646566 ns |
1763977334 ns |
0.97 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) |
1711914668 ns |
1773537556 ns |
0.97 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) |
2921070609 ns |
3105794746 ns |
0.94 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux |
3123095202 ns |
3076042064.5 ns |
1.02 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1273894038 ns |
1271346742 ns |
1.00 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant |
1289592679.5 ns |
1246562750 ns |
1.03 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1246482766 ns |
1309043330 ns |
0.95 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2533335173 ns |
2442642621 ns |
1.04 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux |
27297511 ns |
27314834 ns |
1.00 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) |
2175906741 ns |
2242544865 ns |
0.97 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant |
2170417538 ns |
2216501128 ns |
0.98 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) |
2179733438 ns |
2196805969 ns |
0.99 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) |
3392091960 ns |
3556647163 ns |
0.95 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux |
5811384324 ns |
5559034960 ns |
1.05 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1215072309 ns |
1242324757 ns |
0.98 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant |
1211762639.5 ns |
1298352031 ns |
0.93 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1237794082 ns |
1230035861 ns |
1.01 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2487677602 ns |
2637986128 ns |
0.94 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux |
52668259.5 ns |
52652664 ns |
1.00 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) |
3014647673 ns |
3060315768 ns |
0.99 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant |
3034207527 ns |
3106069884 ns |
0.98 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) |
3029060826 ns |
3053865991 ns |
0.99 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) |
4315985968 ns |
4567618226 ns |
0.94 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux |
12286351744 ns |
9483960261 ns |
1.30 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1211908977 ns |
1231844420 ns |
0.98 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant |
1209609157 ns |
1232961844 ns |
0.98 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1314913978 ns |
1246778659.5 ns |
1.05 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2497022835 ns |
2387803469 ns |
1.05 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux |
70873199 ns |
70768943 ns |
1.00 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) |
3143712998 ns |
3264057828 ns |
0.96 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant |
3158645011 ns |
3289278023 ns |
0.96 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) |
3185501103 ns |
3264052473 ns |
0.98 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) |
4493364988 ns |
4733831239 ns |
0.95 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux |
11851265767 ns |
10856363466 ns |
1.09 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) |
1179042431 ns |
1204313355 ns |
0.98 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant |
1199915158 ns |
1195727515.5 ns |
1.00 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) |
1192754860 ns |
1220121107.5 ns |
0.98 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) |
2336617348 ns |
2407152921 ns |
0.97 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux |
20590451 ns |
20638923 ns |
1.00 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) |
1882994217 ns |
1946630937 ns |
0.97 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant |
1873533211 ns |
1945684183 ns |
0.96 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) |
1818163076 ns |
1942898142 ns |
0.94 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) |
3080332047 ns |
3272416211 ns |
0.94 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux |
3464042438.5 ns |
3630039016.5 ns |
0.95 |
This comment was automatically generated by workflow using github-action-benchmark.
f23a978
to
1a9cff0
Compare
Benchmark Results
Benchmark PlotsA plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR. |
using Reactant
function conditional_fn(x)
v = sum(x)
ff = -1.0
@trace if v > 0
z = 1.0
v2 = sum(x)
# x = x .+ ff
# elseif v > -2.0
# z = 2.0
# elseif v > -4.0
# z = 3.0
else
z = 0.0
v2 = -sum(x)
end
return z, v2
end
conditional_fn(rand(2, 3))
x = rand(2, 3)
x_ra = Reactant.to_rarray(x)
conditional_fn_compiled = @compile conditional_fn(x_ra) Error
Can I get some help in the MLIR operation part here? I am sure I am doing something illegal here |
e3c7069
to
cf99d6e
Compare
src/ControlFlow.jl
Outdated
|
||
macro trace(expr) | ||
expr.head == :if && return esc(trace_if(__module__, expr)) | ||
return error("Only `if-elseif-else` blocks are currently supported by `@trace`") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should make sure to test [bare if end / if else end / if elseif end / if elseif else end] x [returning different values from the if / overwriting a local value (e.g. x = x + 2) / in place updating values (e.g. x[2] = 4) ]
src/ControlFlow.jl
Outdated
@assert expr.head == :if | ||
@assert length(expr.args) == 3 "`@trace` expects an `else` block for `if` blocks." | ||
# XXX: support `elseif` blocks | ||
@assert expr.args[3].head == :block "`elseif` blocks are not supported yet." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well above tests for what's supported at least xD
Probably the solution is to not pass existing regions, but then to move the
child blocks from the old regions into the if
…On Sun, Oct 20, 2024 at 4:28 PM Avik Pal ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In src/ControlFlow.jl
<#185 (comment)>:
> +
+# Generate this dummy function and later we remove it during tracing
+function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn}
+ if cond
+ return true_fn(args...)
+ else
+ return false_fn(args...)
+ end
+end
+
+function traced_if(
+ cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
+) where {TFn,FFn}
+ _, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results = Reactant.make_mlir_fn(
+ true_fn, args, (), string(gensym("true_branch")), false
+ )
The above example will be fine, since the returns are union-ed. But the
following errors:
function f(x, y)
@Traced if x < y
x = x + 1
p = 2
else
y = y + 1
endend
—
Reply to this email directly, view it on GitHub
<#185 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAJTUXDETHJF572TDPBJ2LLZ4Q4B3AVCNFSM6AAAAABQH3EMSGVHI2DSMVQWIX3LMV43YUDVNRWFEZLROVSXG5CSMV3GSZLXHMZDGOBQG4YDANBWHA>
.
You are receiving this because your review was requested.Message ID:
***@***.***>
|
d330af1
to
c31f4cb
Compare
using Reactant
function conditional_fn(x)
v = sum(x)
@trace if v > 0
v2 = v
else
v2 = -v
end
return v2, x
end
conditional_fn(rand(2, 3))
x = rand(2, 3)
x_ra = Reactant.to_rarray(x)
conditional_fn_compiled = @compile conditional_fn(x_ra) julia> @code_hlo optimize=false conditional_fn(x_ra)
Module:
"builtin.module"() ({
"func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "identity_broadcast_scalar", sym_visibility = "private"}> ({
^bb0(%arg5: tensor<f64>):
%17 = "stablehlo.transpose"(%arg5) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%18 = "stablehlo.transpose"(%17) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
"func.return"(%18) : (tensor<f64>) -> ()
}) : () -> ()
"func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "##true_branch#320", sym_visibility = "private"}> ({
}) : () -> ()
"func.func"() <{function_type = (tensor<f64>) -> (tensor<f64>, tensor<f64>), sym_name = "##false_branch#321", sym_visibility = "private"}> ({
}) : () -> ()
"func.func"() <{function_type = (tensor<3x2xf64>) -> (tensor<f64>, tensor<3x2xf64>), sym_name = "main"}> ({
^bb0(%arg0: tensor<3x2xf64>):
%0 = "stablehlo.transpose"(%arg0) <{permutation = array<i64: 1, 0>}> : (tensor<3x2xf64>) -> tensor<2x3xf64>
%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64>
%2 = "stablehlo.broadcast_in_dim"(%0) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<2x3xf64>) -> tensor<2x3xf64>
%3 = "enzyme.batch"(%2) <{batch_shape = array<i64: 2, 3>, fn = @identity_broadcast_scalar}> : (tensor<2x3xf64>) -> tensor<2x3xf64>
%4 = "stablehlo.reduce"(%3, %1) <{dimensions = array<i64: 0, 1>}> ({
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
%16 = "stablehlo.add"(%arg3, %arg4) : (tensor<f64>, tensor<f64>) -> tensor<f64>
"stablehlo.return"(%16) : (tensor<f64>) -> ()
}) : (tensor<2x3xf64>, tensor<f64>) -> tensor<f64>
%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64>
%6 = "stablehlo.compare"(%4, %5) <{comparison_direction = #stablehlo<comparison_direction GT>}> : (tensor<f64>, tensor<f64>) -> tensor<i1>
%7 = "stablehlo.if"(%6) ({
^bb0(%arg2: tensor<f64>):
%14 = "stablehlo.transpose"(%arg2) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%15 = "stablehlo.transpose"(%14) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
"stablehlo.return"(%15) : (tensor<f64>) -> ()
}, {
^bb0(%arg1: tensor<f64>):
%10 = "stablehlo.transpose"(%arg1) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%11 = "stablehlo.negate"(%10) : (tensor<f64>) -> tensor<f64>
%12 = "stablehlo.transpose"(%11) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%13 = "stablehlo.transpose"(%10) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
"stablehlo.return"(%12, %13) : (tensor<f64>, tensor<f64>) -> ()
}) : (tensor<i1>) -> tensor<f64>
%8 = "stablehlo.transpose"(%7) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%9 = "stablehlo.transpose"(%0) <{permutation = array<i64: 1, 0>}> : (tensor<2x3xf64>) -> tensor<3x2xf64>
"func.return"(%8, %9) : (tensor<f64>, tensor<3x2xf64>) -> ()
}) : () -> ()
}) : () -> () Doesn't yet work completely. The branches in |
so this shouldn't be bad to fix (assuming that the true/false regions just accept the args from the parent -- which is notably not how while works (which takes them as blockarguments like you have setup here). Essentially we need to do a replace all users of the blocks with the actual values in the parent (there should be a replace api call in mlir.jl cc @mofeing) and then delete the now unused blockargs |
Module:
"builtin.module"() ({
"func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "identity_broadcast_scalar", sym_visibility = "private"}> ({
^bb0(%arg5: tensor<f64>):
%17 = "stablehlo.transpose"(%arg5) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%18 = "stablehlo.transpose"(%17) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
"func.return"(%18) : (tensor<f64>) -> ()
}) : () -> ()
"func.func"() <{function_type = (tensor<f64>) -> tensor<f64>, sym_name = "##true_branch#282", sym_visibility = "private"}> ({
}) : () -> ()
"func.func"() <{function_type = (tensor<f64>) -> (tensor<f64>, tensor<f64>), sym_name = "##false_branch#283", sym_visibility = "private"}> ({
}) : () -> ()
"func.func"() <{function_type = (tensor<3x2xf64>) -> (tensor<f64>, tensor<3x2xf64>), sym_name = "main"}> ({
^bb0(%arg0: tensor<3x2xf64>):
%0 = "stablehlo.transpose"(%arg0) <{permutation = array<i64: 1, 0>}> : (tensor<3x2xf64>) -> tensor<2x3xf64>
%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64>
%2 = "stablehlo.broadcast_in_dim"(%0) <{broadcast_dimensions = array<i64: 0, 1>}> : (tensor<2x3xf64>) -> tensor<2x3xf64>
%3 = "enzyme.batch"(%2) <{batch_shape = array<i64: 2, 3>, fn = @identity_broadcast_scalar}> : (tensor<2x3xf64>) -> tensor<2x3xf64>
%4 = "stablehlo.reduce"(%3, %1) <{dimensions = array<i64: 0, 1>}> ({
^bb0(%arg3: tensor<f64>, %arg4: tensor<f64>):
%16 = "stablehlo.add"(%arg3, %arg4) : (tensor<f64>, tensor<f64>) -> tensor<f64>
"stablehlo.return"(%16) : (tensor<f64>) -> ()
}) : (tensor<2x3xf64>, tensor<f64>) -> tensor<f64>
%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f64>}> : () -> tensor<f64>
%6 = "stablehlo.compare"(%4, %5) <{comparison_direction = #stablehlo<comparison_direction GT>}> : (tensor<f64>, tensor<f64>) -> tensor<i1>
%7 = "stablehlo.if"(%6) ({
^bb0(%arg2: tensor<f64>):
%14 = "stablehlo.transpose"(%arg2) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%15 = "stablehlo.transpose"(%4) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
"stablehlo.return"(%15) : (tensor<f64>) -> ()
}, {
^bb0(%arg1: tensor<f64>):
%10 = "stablehlo.transpose"(%arg1) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%11 = "stablehlo.negate"(%4) : (tensor<f64>) -> tensor<f64>
%12 = "stablehlo.transpose"(%11) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%13 = "stablehlo.transpose"(%4) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
"stablehlo.return"(%12, %13) : (tensor<f64>, tensor<f64>) -> ()
}) : (tensor<i1>) -> tensor<f64>
%8 = "stablehlo.transpose"(%7) <{permutation = array<i64>}> : (tensor<f64>) -> tensor<f64>
%9 = "stablehlo.transpose"(%0) <{permutation = array<i64: 1, 0>}> : (tensor<2x3xf64>) -> tensor<3x2xf64>
"func.return"(%8, %9) : (tensor<f64>, tensor<3x2xf64>) -> ()
}) : () -> ()
}) : () -> () How do I delete the arguments? I don't see an equivalent to https://mlir.llvm.org/doxygen/classmlir_1_1Block.html#ab497ed8c6f52f6faa5e9a8ac0c6b1014 in reactant |
666469d
to
66c82c1
Compare
yeah it seems that |
Remember that Reactant uses LLVM 20 currently, because it has its own LLVM build with XLA. |
Remember that Reactant uses LLVM 20-DEV currently, because it has its own LLVM build with XLA. |
oh nice, then we need to regenerate the LibMLIR file to get |
03aa1e1
to
feea37e
Compare
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
e3fd076
to
9ab07db
Compare
the setindex issue still needs fixing |
If we merge this before fixing that, would you mind opening an issue so that we can keep track of that please? |
TODOs
ControlFlow.traced_if
ReactantCore
?