Specialize sum over arrays of static numbers #51

Open
oschulz opened this issue Apr 6, 2022 · 16 comments

@oschulz

oschulz commented Apr 6, 2022

We currently have

mapreduce(identity, +, fill(static(0), 5)) isa StaticInt{0}

but

sum(fill(static(0), 5)) isa Int

@oschulz
Author

oschulz commented Apr 6, 2022

Also, mapreduce returns a static number but is very slow (roughly 200 times slower than on plain Int in the benchmark below), which could affect generic code that doesn't care about the number type:

julia> using Static, BenchmarkTools

julia> @btime mapreduce(identity, +, fill(1, 10^5))
  89.239 μs (2 allocations: 781.30 KiB)
100000

julia> @btime mapreduce(identity, +, fill(static(1), 10^5))
  17.578 ms (1 allocation: 48 bytes)
static(100000)

@Tokazama
Collaborator

Tokazama commented Apr 6, 2022

Yeah, that's not great. Is this something you've been running into or is it just a result of thorough testing?

@oschulz
Author

oschulz commented Apr 6, 2022

I plan to use static numbers for statistical weights in BAT.jl (I often have scenarios where samples may or may not be weighted, depending on sampling algorithm), so I played around a bit to make sure that operations like sum(fill(static(...), ...)) and sum(FillArrays.Fill(static(...), ...)) are efficient.

Somewhat related: JuliaArrays/FillArrays.jl#176

@chriselrod
Collaborator

chriselrod commented Apr 6, 2022

This is what reduce_tup(+, x) is for.
It will also be type stable.

@generated function reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N}

Edit: oops, arrays, not tuples. Is your array type stable?
Or a small union?
Seems like this is an odd thing to want to use. What's the use case?

IMO, it should return Int, unless the length is known at compile time, or eltype is StaticInt{0}.

@Tokazama
Collaborator

Tokazama commented Apr 6, 2022

If you have 10^5 unique StaticInt values then I'm not sure it's the right type. If what you're actually dealing with is a handful of unique statically known values, but they are assigned to 10^5 unique variables then you probably want a different data structure than Vector.

That's not to say there's nothing we can do here to make the situation better, but it seems like one of those corner cases where an unintuitive or involved solution may be worse than just explicitly telling people that they should avoid that pattern. Just my initial thoughts.

@oschulz
Author

oschulz commented Apr 6, 2022

If you have 10^5 unique StaticInt values

You can easily end up with that in generic code, e.g. starting from FillArrays.Fill(static(1), 10^5) (I would use that to represent a static statistical weight for a large number of samples, for example). There are many operations that will turn a FillArrays.Fill into a standard Array.

@Tokazama
Collaborator

Tokazama commented Apr 6, 2022

Oh, so you're starting off with an ideal structure where the default is for everything to be 1 but then it gets converted when reducing? I can see how that would be easy to fall into.

I still think it could be problematic to implement a fix on our end. We'd have to somehow be aware of a dynamic number of iterations through a loop or awkwardly define Base.sum(a::AbstractArray{StaticInt}). But if someone has a clean solution I'm game.

@oschulz
Author

oschulz commented Apr 6, 2022

Oh, so you're starting off with an ideal structure where the default is for everything to be 1 but then it gets converted when reducing? I can see how that would be easy to fall into.

Yes, exactly. I have a lot of generic code that handles statistical samples (and will have more). The samples may all have weight 1 - in which case important optimizations can be made - or they may have arbitrary integer or floating-point weights. So it would seem natural to use something like Ones{StaticInt}(n) (n can be very large) or Fill(static(1), n). But even with certain possible specializations (see JuliaArrays/FillArrays.jl#177) one can still end up with an AbstractArray{StaticInt} after a while. So it would be important that standard operations like sum, mapreduce etc. function efficiently on arrays of static numbers, and ideally propagate the "staticness" of the element type.

@oschulz
Author

oschulz commented Apr 6, 2022

I still think it could be problematic to implement a fix on our end.

Hm

A = fill(static(1), 10)
sum(A)

comes down to

mapreduce(identity, Base.add_sum, A)

I think the culprit may be

julia> x = static(1)
static(1)

julia> x + x
static(2)

julia> Base.add_sum(x, x)
2

This seems to be due to the return-type assertion in the definition of Base.add_sum

add_sum(x::Real, y::Real)::Real = x + y

Since

julia> foo(a::Real, b::Real) = a + b
foo (generic function with 1 method)

julia> foo(x, x)
static(2)

julia> bar(a::Real, b::Real)::Real = a + b
bar (generic function with 1 method)

julia> bar(x, x)
2

which I don't get, since static(2) isa Real.

@devmotion
Member

The samples may all have weight 1 - in which case important optimizations can be made

A bit off-topic in this thread: Isn't this what StatsBase.UnitWeights is designed for?

@oschulz
Author

oschulz commented Apr 6, 2022

A bit off-topic in this thread: Isn't this what StatsBase.UnitWeights is designed for?

@devmotion In principle, yes. Though if we can get this to work, we could get rid of StatsBase.UnitWeights as yet another custom array of ones (I think we have several more scattered through the ecosystem). And uweights(5)[1] isa Integer, not a StaticInt, so when doing mathematical operations on UnitWeights the information that the weights are known and static gets lost immediately.

But I had actually meant to talk to the StatsBase maintainers in this context too. I just thought I should first explore how Static and FillArrays could play together better before making a proposal regarding UnitWeights. But I'd be very happy to bring this all together.

@devmotion
Member

I didn't want to initiate a longer discussion here, I don't think it's the right place for it. Just briefly: I don't think it's possible, and I don't even think it's desirable, to get rid of UnitWeights - e.g., in many cases an AbstractWeights supertype is needed/desired and the weight type hierarchy and traits are needed for other weight types anyway. That doesn't rule out that for other applications a fix for the issue here would be desirable but for the weights example using AbstractWeights seems a bit more natural.

@oschulz
Author

oschulz commented Apr 6, 2022

don't even think it's desirable, to get rid of UnitWeights

Sorry, I shouldn't have said "get rid of", I agree that the additional semantics (beyond "array of ones") of UnitWeights are very important for many applications. I didn't mean removing it, but potentially powering it with Static and FillArrays. Building a fully-featured fill-of-static type isn't trivial (I just opened JuliaStats/StatsBase.jl/issues/782 regarding UnitWeights and vcat), so it might help to have common implementations underneath.

@Tokazama
Collaborator

Tokazama commented Apr 6, 2022

Like Chris pointed out, this also depends on knowing the size of the array in some cases to be type stable. It's hard to know what the best approach is for addressing this until we get some closure on JuliaLang/julia#44538.

This seems like the sort of thing that should be part of the method's internal interface with something like

function _add_sum(A::AbstractArray{StaticInt{N}}) where {N}
    if known_length(A) === nothing
        return length(A) * N
    else
        return StaticInt{known_length(A) * N}()
    end
end

because it's usually a bad idea to dispatch on the element type of an abstract collection.

@oschulz
Author

oschulz commented Apr 6, 2022

Currently, mapreduce on static numbers tries to keep the staticness even if the array size isn't known, at huge performance cost:

julia> @btime mapreduce(sqrt, Base.add_sum, A)
  17.152 ms (0 allocations: 0 bytes)
static(204939.01531919395)

Could we add a specialization like

function Base.mapreduce(f, ::typeof(Base.add_sum), A::AbstractVector{T}) where {T<:StaticFloat64}
    length(eachindex(A)) * T()
end

it would get us

julia> @btime mapreduce(sqrt, Base.add_sum, A)
  73.947 ns (1 allocation: 16 bytes)
420000.0
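
A closed form like this only matches a generic mapreduce if the mapped function is applied to the constant before multiplying by the length. A minimal sketch in plain Julia, without Static.jl; `constant_mapreduce_sum` is a hypothetical helper, and the constant value is passed explicitly here instead of being read from the element type:

```julia
# Hypothetical helper (not part of Static.jl or Base): for an array whose
# elements all equal `v`, summing `f` over it is `length(A) * f(v)`.
constant_mapreduce_sum(f, A::AbstractArray, v) = length(A) * f(v)

A = fill(4.2, 10^5)

# Closed form: one call to `f` and one multiplication, independent of length.
closed = constant_mapreduce_sum(sqrt, A, 4.2)

# Generic reduction over every element, for comparison.
generic = mapreduce(sqrt, +, A)

# The two agree up to floating-point rounding.
@assert isapprox(closed, generic; rtol = 1e-10)
```

With `f = identity` this reduces to `length(A) * v`, which is the case relevant for `sum`.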

@oschulz
Author

oschulz commented Apr 6, 2022

Could we add a specialization like ...

Darn, we'd get an ambiguity with StaticArrays with that (possibly other packages too):

julia> A = SVector(fill(static(4.2), 5)...)
julia> mapreduce(sqrt, Base.add_sum, A)
ERROR: MethodError: mapreduce(::typeof(sqrt), ::typeof(Base.add_sum), ::SVector{5, StaticFloat64{4.2}}) is ambiguous. Candidates:
  mapreduce(f, ::typeof(Base.add_sum), A::AbstractVector{T}) where T<:StaticFloat64
  mapreduce(f, op, a::StaticArray, b::StaticArray...; dims, init)
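
The ambiguity arises because one method is more specific in the operator and element type while the other is more specific in the container type, so neither wins. A self-contained sketch of the same situation in plain Julia, assuming hypothetical stand-ins (`reduce_like` for mapreduce, `FixedVec` for a package-owned array type such as SVector):

```julia
# Hypothetical stand-in for a package-owned array type.
struct FixedVec{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::FixedVec) = size(v.data)
Base.getindex(v::FixedVec, i::Int) = v.data[i]

# Method 1: specialized on the operator and the element type.
reduce_like(f, ::typeof(+), A::AbstractVector{T}) where {T<:Integer} = :by_eltype
# Method 2: specialized on the container type.
reduce_like(f, op, A::FixedVec) = :by_container

v = FixedVec([1, 2, 3])

only_eltype = reduce_like(identity, +, [1, 2, 3])  # only method 1 applies
only_container = reduce_like(identity, *, v)       # only method 2 applies

# Both methods apply to this call and neither is more specific.
ambiguous = try
    reduce_like(identity, +, v)
    false
catch err
    err isa MethodError
end

# A method covering the intersection of both signatures resolves it.
reduce_like(f, ::typeof(+), A::FixedVec{T}) where {T<:Integer} = :resolved
resolved = reduce_like(identity, +, v)
```

In practice the resolving method has to be defined somewhere that can see both packages, so this illustrates the mechanism rather than proposing a fix.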
