Skip to content

Commit

Permalink
Merge pull request #255 from ReactiveBayes/240-error-when-running-inf…
Browse files Browse the repository at this point in the history
…er-due-to-splatting

Implement splatting variables
  • Loading branch information
wouterwln authored Oct 8, 2024
2 parents c538d83 + 4fcf2c8 commit c7dca52
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 38 deletions.
1 change: 1 addition & 0 deletions docs/src/developers_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ GraphPPL.NodeData
GraphPPL.NodeLabel
GraphPPL.EdgeLabel
GraphPPL.ProxyLabel
GraphPPL.Splat
GraphPPL.indexed_last
GraphPPL.lift_index
GraphPPL.datalabel
Expand Down
38 changes: 37 additions & 1 deletion src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,15 @@ Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label))
Base.:(==)(label1::EdgeLabel, label2::EdgeLabel) = label1.name == label2.name && label1.index == label2.index
Base.hash(label::EdgeLabel, h::UInt) = hash(label.name, hash(label.index, h))

"""
Splat{T}
A type used to represent splatting in the model macro. Any call on the right hand side of ~ that uses splatting will be wrapped in this type.
"""
struct Splat{T}
collection::T
end

"""
ProxyLabel(name, index, proxied)
Expand All @@ -306,6 +315,9 @@ is_proxied(::Type{T}) where {T <: NodeLabel} = True()
is_proxied(::Type{T}) where {T <: ProxyLabel} = True()
is_proxied(::Type{T}) where {T <: AbstractArray} = is_proxied(eltype(T))

proxylabel(name::Symbol, proxied::Splat{T}, index, maycreate) where {T} =
[proxylabel(name, proxiedelement, index, maycreate) for proxiedelement in proxied.collection]

# By default, `proxylabel` set `maycreate` to `False`
proxylabel(name::Symbol, proxied, index) = proxylabel(name, proxied, index, False())
proxylabel(name::Symbol, proxied, index, maycreate) = proxylabel(is_proxied(proxied), name, proxied, index, maycreate)
Expand Down Expand Up @@ -1115,6 +1127,28 @@ __check_external_collection_compatibility(label::VariableRef, collection::Number
# For all other we simply don't know so we assume we are compatible
__check_external_collection_compatibility(label::VariableRef, collection, indices::Tuple) = true

function Base.iterate(ref::VariableRef, state)
if !isnothing(external_collection(ref))
return iterate(external_collection(ref), state)
elseif !isnothing(internal_collection(ref))
return iterate(internal_collection(ref), state)
elseif haskey(ref.context, ref.name)
return iterate(ref.context[ref.name], state)
end
error("Cannot iterate over $(ref.name). The underlying collection for `$(ref.name)` has undefined shape.")
end

function Base.iterate(ref::VariableRef)
if !isnothing(external_collection(ref))
return iterate(external_collection(ref))
elseif !isnothing(internal_collection(ref))
return iterate(internal_collection(ref))
elseif haskey(ref.context, ref.name)
return iterate(ref.context[ref.name])
end
error("Cannot iterate over $(ref.name). The underlying collection for `$(ref.name)` has undefined shape.")
end

function Base.broadcastable(ref::VariableRef)
if !isnothing(external_collection(ref))
# If we have an underlying collection (e.g. data), we should instantiate all variables at the point of broadcasting
Expand Down Expand Up @@ -2132,7 +2166,9 @@ Calls a plugin specific logic after the model has been created. By default does
"""
postprocess_plugin(plugin, model) = nothing

function preprocess_plugins(type::AbstractPluginTraitType, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options)::Tuple{NodeLabel, NodeData}
function preprocess_plugins(
type::AbstractPluginTraitType, model::Model, context::Context, label::NodeLabel, nodedata::NodeData, options
)::Tuple{NodeLabel, NodeData}
plugins = filter(type, getplugins(model))
return foldl(plugins; init = (label, nodedata)) do (label, nodedata), plugin
return preprocess_plugin(plugin, model, context, label, nodedata, options)::Tuple{NodeLabel, NodeData}
Expand Down
2 changes: 2 additions & 0 deletions src/model_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ function proxy_args_rhs(rhs)
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), $rlabel, $(Expr(:tuple, index...)), GraphPPL.False()))
elseif @capture(rhs, new(rlabel_[index__]))
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), $rlabel, $(Expr(:tuple, index...)), GraphPPL.True()))
elseif @capture(rhs, rlabel_...)
return :(GraphPPL.proxylabel($(QuoteNode(rlabel)), GraphPPL.Splat($rlabel), nothing, GraphPPL.False())...)
end
return :(GraphPPL.proxylabel(:anonymous, $rhs, nothing, GraphPPL.False()))
end
Expand Down
Loading

0 comments on commit c7dca52

Please sign in to comment.