Should fold and similar operations be unrolled?
#227
I will create an MWE later, but I am opening this now so that I don't forget. Recurrence internally just calls
using Reactant
function custom_op(x::AbstractArray)
function inner_op(xᵢ::AbstractVector, yᵢ::AbstractVector)
return xᵢ .+ yᵢ
end
return foldl(inner_op, eachcol(x))
end
x_ra = Reactant.to_rarray(rand(2, 4))
@code_hlo custom_op(x_ra)
Module:
module attributes {transform.with_named_sequence} {
func.func @main(%arg0: tensor<4x2xf64>) -> tensor<2xf64> {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x2xf64>) -> tensor<2x4xf64>
%1 = stablehlo.slice %0 [0:2, 0:1] : (tensor<2x4xf64>) -> tensor<2x1xf64>
%2 = stablehlo.reshape %1 : (tensor<2x1xf64>) -> tensor<2xf64>
%3 = stablehlo.slice %0 [0:2, 1:2] : (tensor<2x4xf64>) -> tensor<2x1xf64>
%4 = stablehlo.reshape %3 : (tensor<2x1xf64>) -> tensor<2xf64>
%5 = stablehlo.add %2, %4 : tensor<2xf64>
%6 = stablehlo.slice %0 [0:2, 2:3] : (tensor<2x4xf64>) -> tensor<2x1xf64>
%7 = stablehlo.reshape %6 : (tensor<2x1xf64>) -> tensor<2xf64>
%8 = stablehlo.add %5, %7 : tensor<2xf64>
%9 = stablehlo.slice %0 [0:2, 3:4] : (tensor<2x4xf64>) -> tensor<2x1xf64>
%10 = stablehlo.reshape %9 : (tensor<2x1xf64>) -> tensor<2xf64>
%11 = stablehlo.add %8, %10 : tensor<2xf64>
return %11 : tensor<2xf64>
}
}
Yeah, we should probably lower this directly as a reduction of sorts.
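To make the "reduction" idea concrete, here is a minimal plain-Julia sketch (no Reactant involved). It only checks that the column-wise fold from the example above computes the same values as a single reduction over the second dimension. The name `custom_op_reduced` is hypothetical, and whether Reactant would emit a single `stablehlo.reduce` for the `sum` form is an assumption that this sketch does not verify.

```julia
# Column-wise fold, as in the example above: one elementwise add per column.
function custom_op(x::AbstractArray)
    inner_op(xᵢ::AbstractVector, yᵢ::AbstractVector) = xᵢ .+ yᵢ
    return foldl(inner_op, eachcol(x))
end

# Hypothetical equivalent: a single reduction over the column dimension.
# `sum(x; dims=2)` returns a 2×1 matrix, so `vec` flattens it to a vector.
custom_op_reduced(x::AbstractMatrix) = vec(sum(x; dims=2))

x = rand(2, 4)
@assert custom_op(x) ≈ custom_op_reduced(x)
```

This equivalence holds only because `inner_op` is a plain elementwise add. A general `foldl` with an arbitrary, possibly non-associative function cannot be rewritten this way, which is presumably why a generic lowering would need a loop or scan-like construct rather than a fixed reduction.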
(needs LuxDL/Lux.jl#1026 for the dump)
This currently gets unrolled into the following monstrosity: https://pastebin.com/QE00SCqb