diff --git a/src/AbstractOperations/conditional_operations.jl b/src/AbstractOperations/conditional_operations.jl index 7eba3699f7..7f9c7251cb 100644 --- a/src/AbstractOperations/conditional_operations.jl +++ b/src/AbstractOperations/conditional_operations.jl @@ -5,23 +5,27 @@ import Oceananigans.Architectures: on_architecture import Oceananigans.Fields: condition_operand, conditional_length, set!, compute_at!, indices # For conditional reductions such as mean(u * v, condition = u .> 0)) -struct ConditionalOperation{LX, LY, LZ, O, F, G, C, M, T} <: AbstractOperation{LX, LY, LZ, G, T} +struct ConditionalOperation{LX, LY, LZ, F, C, O, G, M, T} <: AbstractOperation{LX, LY, LZ, G, T} operand :: O func :: F grid :: G condition :: C mask :: M - function ConditionalOperation{LX, LY, LZ}(operand::O, func::F, grid::G, - condition::C, mask::M) where {LX, LY, LZ, O, F, G, C, M} + function ConditionalOperation{LX, LY, LZ}(operand::O, func, grid::G, + condition::C, mask::M) where {LX, LY, LZ, O, G, C, M} + if func === Base.identity + func = nothing + end T = eltype(operand) - return new{LX, LY, LZ, O, F, G, C, M, T}(operand, func, grid, condition, mask) + F = typeof(func) + return new{LX, LY, LZ, F, C, O, G, M, T}(operand, func, grid, condition, mask) end end """ ConditionalOperation(operand::AbstractField; - func = identity, + func = nothing, condition = nothing, mask = 0) @@ -31,19 +35,19 @@ described by `func(operand)`. Positional arguments ==================== -- `operand`: The `AbstractField` to be masked (it must have a `grid` property!) +- `operand`: The `AbstractField` to be masked. Keyword arguments ================= - `func`: A unary transformation applied element-wise to the field `operand` at locations where - `condition == true`. Default is `identity`. + `condition == true`. Default is `nothing` which applies no transformation. - `condition`: either a function of `(i, j, k, grid, operand)` returning a Boolean, or a 3-dimensional Boolean `AbstractArray`. At locations where `condition == false`, - operand will be masked by `mask` + operand will be masked by `mask`. -- `mask`: the scalar mask +- `mask`: the scalar mask. Default: 0. `condition_operand` is a convenience function used to construct a `ConditionalOperation` @@ -78,10 +82,9 @@ julia> d[2, 1, 1] ``` """ function ConditionalOperation(operand::AbstractField; - func = identity, + func = nothing, condition = nothing, mask = zero(eltype(operand))) - LX, LY, LZ = location(operand) return ConditionalOperation{LX, LY, LZ}(operand, func, operand.grid, condition, mask) end @@ -95,29 +98,44 @@ function ConditionalOperation(c::ConditionalOperation; return ConditionalOperation{LX, LY, LZ}(c.operand, func, c.grid, condition, mask) end -struct TrueCondition end +@inline function Base.getindex(co::ConditionalOperation, i, j, k) + conditioned = evaluate_condition(co.condition, i, j, k, c.grid, c) + value = getindex(co.operand, i, j, k) + func_value = co.func(value) + return ifelse(conditioned, value, c.mask) +end + +# Some special cases +const NoFuncCO = ConditionalOperation{<:Any, <:Any, <:Any, Nothing} +const NoConditionCO = ConditionalOperation{<:Any, <:Any, <:Any, <:Any, Nothing} +const NoFuncNoConditionCO = ConditionalOperation{<:Any, <:Any, <:Any, Nothing, Nothing} + +using Base: @propagate_inbounds +@propagate_inbounds function Base.getindex(co::NoConditionCO, i, j, k) + value = getindex(co.operand, i, j, k) + return co.func(value) +end + +@propagate_inbounds Base.getindex(co::NoFuncNoConditionCO, i, j, k) = getindex(co.operand, i, j, k) -@inline function Base.getindex(c::ConditionalOperation, i, j, k) - return ifelse(evaluate_condition(c.condition, i, j, k, c.grid, c), - c.func(getindex(c.operand, i, j, k)), - c.mask) +@propagate_inbounds function Base.getindex(co::NoFuncCO, i, j, k) + conditioned = evaluate_condition(co.condition, i, j, k, co.grid, co) + value = getindex(co.operand, i, j, k) + return ifelse(conditioned, value, co.mask) end -@inline evaluate_condition(condition, i, j, k, grid, args...) = condition(i, j, k, grid, args...) -@inline evaluate_condition(::TrueCondition, i, j, k, grid, args...) = true -@inline evaluate_condition(condition::AbstractArray, i, j, k, grid, args...) = @inbounds condition[i, j, k] +# Conditions: general, nothing, array +@inline evaluate_condition(condition, i, j, k, grid, args...) = condition(i, j, k, grid, args...) +@inline evaluate_condition(::Nothing, i, j, k, grid, args...) = true +@propagate_inbounds evaluate_condition(condition::AbstractArray, i, j, k, grid, args...) = condition[i, j, k] -@inline condition_operand(func::Function, op::AbstractField, condition, mask) = ConditionalOperation(op; func, condition, mask) -@inline condition_operand(func::Function, op::AbstractField, ::Nothing, mask) = ConditionalOperation(op; func, condition=TrueCondition(), mask) +@inline condition_operand(func, op, condition, mask) = ConditionalOperation(op; func, condition, mask) -@inline function condition_operand(func::Function, operand::AbstractField, condition::AbstractArray, mask) +@inline function condition_operand(func, operand, condition::AbstractArray, mask) condition = on_architecture(architecture(operand.grid), condition) return ConditionalOperation(operand; func, condition, mask) end -@inline condition_operand(func::typeof(identity), c::ConditionalOperation, ::Nothing, mask) = ConditionalOperation(c; mask) -@inline condition_operand(func::Function, c::ConditionalOperation, ::Nothing, mask) = ConditionalOperation(c; func, mask) - @inline materialize_condition!(c::ConditionalOperation) = set!(c.operand, c) function materialize_condition(c::ConditionalOperation) @@ -126,14 +144,17 @@ function materialize_condition(c::ConditionalOperation) return f end -@inline condition_onefield(c::ConditionalOperation{LX, LY, LZ}, mask) where {LX, LY, LZ} = - ConditionalOperation{LX, LY, LZ}(OneField(Int), identity, c.grid, c.condition, mask) +@inline function conditional_one(c::ConditionalOperation, mask) + LX, LY, LZ = location(c) + one_field = OneField(Int) + return ConditionalOperation{LX, LY, LZ}(one_field, nothing, c.grid, c.condition, mask) +end -@inline conditional_length(c::ConditionalOperation) = sum(condition_onefield(c, 0)) -@inline conditional_length(c::ConditionalOperation, dims) = sum(condition_onefield(c, 0); dims = dims) +@inline conditional_length(c::ConditionalOperation) = sum(conditional_one(c, 0)) +@inline conditional_length(c::ConditionalOperation, dims) = sum(conditional_one(c, 0); dims = dims) Adapt.adapt_structure(to, c::ConditionalOperation{LX, LY, LZ}) where {LX, LY, LZ} = - ConditionalOperation{LX, LY, LZ}(adapt(to, c.operand), + ConditionalOperation{LX, LY, LZ}(adapt(to, c.operand), adapt(to, c.func), adapt(to, c.grid), adapt(to, c.condition), @@ -152,10 +173,10 @@ compute_at!(c::ConditionalOperation, time) = compute_at!(c.operand, time) indices(c::ConditionalOperation) = indices(c.operand) Base.show(io::IO, operation::ConditionalOperation) = - print(io, - "ConditionalOperation at $(location(operation))", "\n", - "├── operand: ", summary(operation.operand), "\n", - "├── grid: ", summary(operation.grid), "\n", - "├── func: ", summary(operation.func), "\n", - "├── condition: ", summary(operation.condition), "\n", - "└── mask: ", operation.mask) + print(io, "ConditionalOperation at $(location(operation))", '\n', + "├── operand: ", summary(operation.operand), '\n', + "├── grid: ", summary(operation.grid), '\n', + "├── func: ", summary(operation.func), '\n', + "├── condition: ", summary(operation.condition), '\n', + "└── mask: ", operation.mask) + diff --git a/src/Fields/abstract_field.jl b/src/Fields/abstract_field.jl index c8c9b7b5f1..6f552221e5 100644 --- a/src/Fields/abstract_field.jl +++ b/src/Fields/abstract_field.jl @@ -90,8 +90,6 @@ end return (ax, ay, az, at) end - - """ total_size(field::AbstractField) @@ -126,3 +124,22 @@ for f in (:+, :-) @eval Base.$f(ϕ::AbstractField, ψ::AbstractArray) = $f(interior(ϕ), ψ) end +const XReducedAF = AbstractField{Nothing} +const YReducedAF = AbstractField{<:Any, Nothing} +const ZReducedAF = AbstractField{<:Any, <:Any, Nothing} + +const YZReducedAF = AbstractField{<:Any, Nothing, Nothing} +const XZReducedAF = AbstractField{Nothing, <:Any, Nothing} +const XYReducedAF = AbstractField{Nothing, Nothing, <:Any} + +const XYZReducedAF = AbstractField{Nothing, Nothing, Nothing} + +reduced_dimensions(field::AbstractField) = () +reduced_dimensions(field::XReducedAF) = tuple(1) +reduced_dimensions(field::YReducedAF) = tuple(2) +reduced_dimensions(field::ZReducedAF) = tuple(3) +reduced_dimensions(field::YZReducedAF) = (2, 3) +reduced_dimensions(field::XZReducedAF) = (1, 3) +reduced_dimensions(field::XYReducedAF) = (1, 2) +reduced_dimensions(field::XYZReducedAF) = (1, 2, 3) + diff --git a/src/Fields/field.jl b/src/Fields/field.jl index ee806db849..98760de274 100644 --- a/src/Fields/field.jl +++ b/src/Fields/field.jl @@ -513,15 +513,6 @@ const ReducedField = Union{XReducedField, XYReducedField, XYZReducedField} -reduced_dimensions(field::Field) = () -reduced_dimensions(field::XReducedField) = tuple(1) -reduced_dimensions(field::YReducedField) = tuple(2) -reduced_dimensions(field::ZReducedField) = tuple(3) -reduced_dimensions(field::YZReducedField) = (2, 3) -reduced_dimensions(field::XZReducedField) = (1, 3) -reduced_dimensions(field::XYReducedField) = (1, 2) -reduced_dimensions(field::XYZReducedField) = (1, 2, 3) - @propagate_inbounds Base.getindex(r::XReducedField, i, j, k) = getindex(r.data, 1, j, k) @propagate_inbounds Base.getindex(r::YReducedField, i, j, k) = getindex(r.data, i, 1, k) @propagate_inbounds Base.getindex(r::ZReducedField, i, j, k) = getindex(r.data, i, j, 1) @@ -628,15 +619,14 @@ function reduced_dimension(loc) return dims end -## Allow support for ConditionalOperation - get_neutral_mask(::Union{AllReduction, AnyReduction}) = true -get_neutral_mask(::Union{SumReduction, MeanReduction}) = 0 -get_neutral_mask(::MinimumReduction) = Inf -get_neutral_mask(::MaximumReduction) = - Inf -get_neutral_mask(::ProdReduction) = 1 +get_neutral_mask(::Union{SumReduction, MeanReduction}) = 0 +get_neutral_mask(::ProdReduction) = 1 + +# TODO make this Float32 friendly +get_neutral_mask(::MinimumReduction) = +Inf +get_neutral_mask(::MaximumReduction) = -Inf -# If func = identity and condition = nothing, nothing happens """ condition_operand(f::Function, op::AbstractField, condition, mask) @@ -646,8 +636,11 @@ If `f isa identity` and `isnothing(condition)` then `op` is returned without wra Otherwise return `ConditionedOperand`, even when `isnothing(condition)` but `!(f isa identity)`. """ -@inline condition_operand(op::AbstractField, condition, mask) = condition_operand(identity, op, condition, mask) -@inline condition_operand(::typeof(identity), operand::AbstractField, ::Nothing, mask) = operand +@inline condition_operand(op::AbstractField, condition, mask) = condition_operand(nothing, op, condition, mask) + +# Do NOT condition if condition=nothing. +# All non-trivial conditioning is found in AbstractOperations/conditional_operations.jl +@inline condition_operand(::Nothing, operand, ::Nothing, mask) = operand @inline conditional_length(c::AbstractField) = length(c) @inline conditional_length(c::AbstractField, dims) = mapreduce(i -> size(c, i), *, unique(dims); init=1) @@ -692,10 +685,10 @@ for reduction in (:sum, :maximum, :minimum, :all, :any, :prod) mask = get_neutral_mask(Base.$(reduction!)), dims = :) + conditioned_c = condition_operand(f, c, condition, mask) T = filltype(Base.$(reduction!), c) loc = reduced_location(location(c); dims) r = Field(loc, c.grid, T; indices=indices(c)) - conditioned_c = condition_operand(f, c, condition, mask) initialize_reduced_field!(Base.$(reduction!), identity, r, conditioned_c) Base.$(reduction!)(identity, r, conditioned_c, init=false) @@ -767,3 +760,4 @@ function fill_halo_regions!(field::Field, args...; kwargs...) return nothing end + diff --git a/src/ImmersedBoundaries/immersed_reductions.jl b/src/ImmersedBoundaries/immersed_reductions.jl index a9969363c6..a5a5224cb1 100644 --- a/src/ImmersedBoundaries/immersed_reductions.jl +++ b/src/ImmersedBoundaries/immersed_reductions.jl @@ -11,41 +11,61 @@ import Oceananigans.Fields: condition_operand, conditional_length @inline truefunc(args...) = true struct NotImmersed{F} <: Function - func :: F + condition :: F end +NotImmersed() = NotImmersed(nothing) +Base.summary(::NotImmersed{Nothing}) = "NotImmersed()" +Base.summary(::NotImmersed) = string("NotImmersed(", summary(condition), ")") + # ImmersedField const IF = AbstractField{<:Any, <:Any, <:Any, <:ImmersedBoundaryGrid} -@inline condition_operand(func::Function, op::IF, cond, mask) = ConditionalOperation(op; func, condition=NotImmersed(cond), mask) -@inline condition_operand(func::Function, op::IF, ::Nothing, mask) = ConditionalOperation(op; func, condition=NotImmersed(truefunc), mask) -@inline condition_operand(func::typeof(identity), op::IF, ::Nothing, mask) = ConditionalOperation(op; func, condition=NotImmersed(truefunc), mask) +function ConditionalOperation(operand::IF; + func = nothing, + condition = nothing, + mask = zero(eltype(operand))) -@inline function condition_operand(func::Function, op::IF, cond::AbstractArray, mask) - arch = architecture(op.grid) - arch_condition = on_architecture(arch, cond) - ni_condition = NotImmersed(arch_condition) - return ConditionalOperation(op; func, condition=ni_condition, mask) + immersed_condition = NotImmersed(condition) + LX, LY, LZ = location(operand) + grid = operand.grid + return ConditionalOperation{LX, LY, LZ}(operand, func, grid, immersed_condition, mask) end -@inline conditional_length(c::IF) = conditional_length(condition_operand(identity, c, nothing, 0)) -@inline conditional_length(c::IF, dims) = conditional_length(condition_operand(identity, c, nothing, 0), dims) +@inline conditional_length(c::IF) = conditional_length(condition_operand(c, nothing, 0)) +@inline conditional_length(c::IF, dims) = conditional_length(condition_operand(c, nothing, 0), dims) + +@inline function evaluate_condition(::NotImmersed{Nothing}, + i, j, k, + grid::ImmersedBoundaryGrid, + co::ConditionalOperation) #, args...) -@inline function evaluate_condition(condition::NotImmersed, i, j, k, ibg, co::ConditionalOperation, args...) ℓx, ℓy, ℓz = map(instantiate, location(co)) - immersed = immersed_peripheral_node(i, j, k, ibg, ℓx, ℓy, ℓz) | inactive_node(i, j, k, ibg, ℓx, ℓy, ℓz) - return !immersed & evaluate_condition(condition.func, i, j, k, ibg, args...) + immersed = immersed_peripheral_node(i, j, k, grid, ℓx, ℓy, ℓz) | inactive_node(i, j, k, grid, ℓx, ℓy, ℓz) + return !immersed +end + +@inline function evaluate_condition(ni::NotImmersed, + i, j, k, + grid::ImmersedBoundaryGrid, + co::ConditionalOperation, args...) + + ℓx, ℓy, ℓz = map(instantiate, location(co)) + immersed = immersed_peripheral_node(i, j, k, grid, ℓx, ℓy, ℓz) | inactive_node(i, j, k, grid, ℓx, ℓy, ℓz) + return !immersed & evaluate_condition(ni.condition, i, j, k, grid, co, args...) end ##### ##### Reduction operations on Reduced Fields test the immersed condition on the entirety of the immersed direction ##### -struct NotImmersedColumn{IC, F} <:Function +struct NotImmersedColumn{F, IC} <:Function immersed_column :: IC - func :: F + condition :: F end +NotImmersedColumn(immersed_column) = NotImmersedColumn(immersed_column, nothing) + using Oceananigans.Fields: reduced_dimensions, OneField using Oceananigans.AbstractOperations: ConditionalOperation @@ -62,15 +82,16 @@ const XYZIRF = AbstractField{Nothing, Nothing, Nothing, <:ImmersedBoundaryGrid} const IRF = Union{XIRF, YIRF, ZIRF, YZIRF, XZIRF, XYIRF, XYZIRF} -@inline condition_operand(func::Function, op::IRF, cond, mask) = ConditionalOperation(op; func, condition=NotImmersedColumn(immersed_column(op), cond ), mask) -@inline condition_operand(func::Function, op::IRF, ::Nothing, mask) = ConditionalOperation(op; func, condition=NotImmersedColumn(immersed_column(op), truefunc), mask) -@inline condition_operand(func::typeof(identity), op::IRF, ::Nothing, mask) = ConditionalOperation(op; func, condition=NotImmersedColumn(immersed_column(op), truefunc), mask) +@inline function condition_operand(func, op::IRF, condition, mask) + immersed_condition = NotImmersedColumn(immersed_column(op), condition) + return ConditionalOperation(op; func, condition, mask) +end @inline function immersed_column(field::IRF) grid = field.grid reduced_dims = reduced_dimensions(field) LX, LY, LZ = map(center_to_nothing, location(field)) - one_field = ConditionalOperation{LX, LY, LZ}(OneField(Int), identity, grid, NotImmersed(truefunc), zero(grid)) + one_field = ConditionalOperation{LX, LY, LZ}(OneField(Int), identity, grid, NotImmersed(), zero(grid)) return sum(one_field, dims=reduced_dims) end @@ -78,9 +99,14 @@ end @inline center_to_nothing(::Type{Center}) = Center @inline center_to_nothing(::Type{Nothing}) = Center -@inline function evaluate_condition(condition::NotImmersedColumn, i, j, k, ibg, co::ConditionalOperation, args...) +@inline function evaluate_condition(nic::NotImmersedColumn, + i, j, k, + grid::ImmersedBoundaryGrid, + co::ConditionalOperation, args...) LX, LY, LZ = location(co) - return evaluate_condition(condition.func, i, j, k, ibg, args...) & !(is_immersed_column(i, j, k, condition.immersed_column)) + immersed = is_immersed_column(i, j, k, nic.immersed_column) + return !immersed & evaluate_condition(nic.condition, i, j, k, grid, args...) end @inline is_immersed_column(i, j, k, column) = @inbounds column[i, j, k] == 0 +