Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle new variables being created inside a traced loop #303

Closed
wants to merge 4 commits into from

Conversation

jumerckx
Copy link
Contributor

RE #301

@jumerckx jumerckx changed the title Handle new variables being created inside a traced loo Handle new variables being created inside a traced loop Nov 25, 2024
@jumerckx
Copy link
Contributor Author

CI succeeds (?) but running the tests locally, I get a failure for the sinkhorn control flow test.

error: expect operands to be compatible with body block arguments but got 'tensor<i64>', 'tensor<5xf32>', 'tensor<5xf32>', 'tensor<10xf32>', 'tensor<10x5xf32>', 'tensor<10xf32>' vs 'tensor<5xf32>', 'tensor<i64>', 'tensor<5xf32>', 'tensor<10xf32>', 'tensor<10x5xf32>', 'tensor<10xf32>'
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ ~/Reactant.jl/src/mlir/IR/Pass.jl:79 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
    @ Reactant.Compiler ~/Reactant.jl/src/Compiler.jl:256
  [3] run_pass_pipeline!
    @ ~/Reactant.jl/src/Compiler.jl:250 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
    @ Reactant.Compiler ~/Reactant.jl/src/Compiler.jl:300

The block arguments of body_fn seem to be in the wrong order. Is the sinkhorn test run by CI? If yes, my local setup might just not be fully up to date or something.

cond_fn = ((var"##i#444179", var"##v#444181", ν, μ, K, u) -> begin
    local num_iters = div(10 - 1, 1, RoundDown)
    local num_iters = Reactant.promote_to(Reactant.TracedRNumber{Int64}, num_iters)
    var"##i#444179" < num_iters + 1
end)
body_fn = ((var"##i#444179", var"##v#444181", ν, μ, K, u) -> begin
    local step_ = 1
    local start_ = 1
    local _ = start_ + var"##i#444179" * step_
    begin
        v = ν ./ (K' * u)
        u = μ ./ (K * v)
    end
    !(var"##v#444181" isa ReactantCore.MissingTracedValue) && (var"##v#444181" = v)
    (var"##i#444179" + 1, var"##v#444181", ν, μ, K, u)
end)
%12:6 = "stablehlo.while"(%11, %10, %1, %0, %7#0, %9) ({
^bb0(%arg9: tensor<i64>, %arg10: tensor<5xf32>, %arg11: tensor<5xf32>, %arg12: tensor<10xf32>, %arg13: tensor<10x5xf32>, %arg14: tensor<10xf32>):
  // ...
  "stablehlo.return"(%64) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<5xf32>, %arg4: tensor<i64>, %arg5: tensor<5xf32>, %arg6: tensor<10xf32>, %arg7: tensor<10x5xf32>, %arg8: tensor<10xf32>):
  // ...
  "stablehlo.return"(%60, %44#0, %arg5, %arg6, %arg7, %58#0) : (tensor<i64>, tensor<5xf32>, tensor<5xf32>, tensor<10xf32>, tensor<10x5xf32>, tensor<10xf32>) -> ()
}) : (tensor<i64>, tensor<5xf32>, tensor<5xf32>, tensor<10xf32>, tensor<10x5xf32>, tensor<10xf32>) -> (tensor<i64>, tensor<5xf32>, tensor<5xf32>, tensor<10xf32>, tensor<10x5xf32>, tensor<10xf32>)

Copy link
Collaborator

@Pangoraw Pangoraw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have another possible solution in Pangoraw@32853eb. I think there might be edge cases though.

Comment on lines +183 to +184
let
args = $(args_init)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this changes the scope of args here, is it on purpose?

filter!(∉(SPECIAL_SYMBOLS), body_symbols.assignments)
filter!(∉(SPECIAL_SYMBOLS), body_symbols.references)

potentially_undefined = setdiff(body_symbols.assignments, body_symbols.references)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all body_symbols.assignments are potentially undefined here. ExpressionExplorer already takes care of not adding references after an assignment.

julia> ss = ExpressionExplorer.compute_symbols_state(quote
           x = x
           y = x
           y, x
       end)
SymbolsState(Set([:x]), Set([:y, :x]), Set{FunctionName}(), Dict{FunctionNameSignaturePair, SymbolsState}(), Set{FunctionName}())

julia> ss.assignments
Set{Symbol} with 2 elements:
  :y
  :x

julia> ss.references
Set{Symbol} with 1 element:
  :x

@@ -172,6 +196,7 @@ function trace_for(mod, expr)
local start_ = $start
local $induction = start_ + $counter * step_
$body
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it needs something before the body (opposite of updates) ?

!($def isa $(MissingTracedValue)) && ($arg = $def)

$body

!($def isa $(MissingTracedValue)) && ($def = $arg)

@jumerckx
Copy link
Contributor Author

jumerckx commented Nov 27, 2024

I have another possible solution in Pangoraw@32853eb. I think there might be edge cases though.

Thanks for the review @Pangoraw! I like the simplicity of your approach way better, though. I can close this pr and we can go with your branch.

One consideration I have is for cond_val(s) = :(isdefined($(mod), $(QuoteNode(s))) ? $s : nothing). This works for global variables but doesn't pick up function-local variables, I think?
There exists a special expression type for checking isdefined, it seems:

julia> @macroexpand @isdefined var
:($(Expr(:isdefined, :var)))

@Pangoraw
Copy link
Collaborator

There exists a special expression type for checking isdefined, it seems:

oh right, I missed that one. thank you! I opened #310. I think it would maybe better to use the MissingTracedValue there instead like in your pr to represent undefined values.

@jumerckx
Copy link
Contributor Author

jumerckx commented Nov 27, 2024

I think it would maybe better to use the MissingTracedValue

Yeah agreed, it doesn't change much but seems a bit nicer.

@jumerckx
Copy link
Contributor Author

closed in favor of #310

@jumerckx jumerckx closed this Nov 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants