
should fold and similar operations be unrolled? #227

Open · avik-pal opened this issue Nov 4, 2024 · 3 comments

avik-pal (Collaborator) commented Nov 4, 2024

(needs LuxDL/Lux.jl#1026 for the dump)

using Lux, Reactant, Random

model = Recurrence(RNNCell(4 => 4))
ps, st = Lux.setup(Xoshiro(123), model) |> Reactant.to_rarray
x = rand(Float32, 4, 2, 12) |> Reactant.ConcreteRArray  # 12 time steps

@code_hlo model(x, ps, st)

This currently gets unrolled into the following monstrosity: https://pastebin.com/QE00SCqb

avik-pal (Collaborator, Author) commented Nov 4, 2024

I will create an MWE later, but I'm opening this so that I don't forget. Recurrence internally just calls foldl, which we should be able to @reactant_override?
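
A minimal sketch of what such an override could compile to, assuming Reactant's @trace loop tracing handles the loop-carried accumulator (the traced_foldl name and the explicit matrix-column form here are illustrative, not Reactant API):

using Reactant

# Hypothetical: trace the fold body once and emit a single stablehlo.while,
# carrying the accumulator across iterations, instead of unrolling one
# slice/op pair per column.
function traced_foldl(op, x::AbstractMatrix)
    acc = x[:, 1]                   # init from the first column
    @trace for i in 2:size(x, 2)
        acc = op(acc, x[:, i])      # loop-carried accumulator
    end
    return acc
end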

avik-pal (Collaborator, Author) commented Nov 4, 2024

using Reactant

function custom_op(x::AbstractArray)
    function inner_op(xᵢ::AbstractVector, yᵢ::AbstractVector)
        return xᵢ .+ yᵢ
    end
    return foldl(inner_op, eachcol(x))
end

x_ra = Reactant.to_rarray(rand(2, 4))

@code_hlo custom_op(x_ra)
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<4x2xf64>) -> tensor<2xf64> {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x2xf64>) -> tensor<2x4xf64>
    %1 = stablehlo.slice %0 [0:2, 0:1] : (tensor<2x4xf64>) -> tensor<2x1xf64>
    %2 = stablehlo.reshape %1 : (tensor<2x1xf64>) -> tensor<2xf64>
    %3 = stablehlo.slice %0 [0:2, 1:2] : (tensor<2x4xf64>) -> tensor<2x1xf64>
    %4 = stablehlo.reshape %3 : (tensor<2x1xf64>) -> tensor<2xf64>
    %5 = stablehlo.add %2, %4 : tensor<2xf64>
    %6 = stablehlo.slice %0 [0:2, 2:3] : (tensor<2x4xf64>) -> tensor<2x1xf64>
    %7 = stablehlo.reshape %6 : (tensor<2x1xf64>) -> tensor<2xf64>
    %8 = stablehlo.add %5, %7 : tensor<2xf64>
    %9 = stablehlo.slice %0 [0:2, 3:4] : (tensor<2x4xf64>) -> tensor<2x1xf64>
    %10 = stablehlo.reshape %9 : (tensor<2x1xf64>) -> tensor<2xf64>
    %11 = stablehlo.add %8, %10 : tensor<2xf64>
    return %11 : tensor<2xf64>
  }
}

wsmoses (Member) commented Nov 5, 2024

Yeah we should probably lower this directly as a reduction of sorts
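
For the associative case in the MWE above, that reduction should already fall out today: writing the same computation as a sum over the slice dimension goes through Reactant's mapreduce overload and emits a single stablehlo.reduce rather than the unrolled slice/add chain. A hedged sketch (the sum_cols name is illustrative):

using Reactant

# Equivalent to foldl(+, eachcol(x)) when the reducer is associative;
# lowers to one stablehlo.reduce over the column dimension.
sum_cols(x::AbstractMatrix) = vec(sum(x; dims=2))

x_ra = Reactant.to_rarray(rand(2, 4))
@code_hlo sum_cols(x_ra)

A non-associative reducer with loop-carried state (e.g. the RNN cell inside Recurrence) can't become a stablehlo.reduce, so a sequential stablehlo.while would presumably be the lowering target there.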
