Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Massive clean up for conditional operations #3794

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 58 additions & 37 deletions src/AbstractOperations/conditional_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,27 @@ import Oceananigans.Architectures: on_architecture
import Oceananigans.Fields: condition_operand, conditional_length, set!, compute_at!, indices

# For conditional reductions such as mean(u * v, condition = u .> 0))
struct ConditionalOperation{LX, LY, LZ, O, F, G, C, M, T} <: AbstractOperation{LX, LY, LZ, G, T}
struct ConditionalOperation{LX, LY, LZ, F, C, O, G, M, T} <: AbstractOperation{LX, LY, LZ, G, T}
operand :: O
func :: F
grid :: G
condition :: C
mask :: M

function ConditionalOperation{LX, LY, LZ}(operand::O, func::F, grid::G,
condition::C, mask::M) where {LX, LY, LZ, O, F, G, C, M}
function ConditionalOperation{LX, LY, LZ}(operand::O, func, grid::G,
condition::C, mask::M) where {LX, LY, LZ, O, G, C, M}
if func === Base.identity
func = nothing
end
T = eltype(operand)
return new{LX, LY, LZ, O, F, G, C, M, T}(operand, func, grid, condition, mask)
F = typeof(func)
return new{LX, LY, LZ, F, C, O, G, M, T}(operand, func, grid, condition, mask)
end
end

"""
ConditionalOperation(operand::AbstractField;
func = identity,
func = nothing,
condition = nothing,
mask = 0)

Expand All @@ -31,19 +35,19 @@ described by `func(operand)`.
Positional arguments
====================

- `operand`: The `AbstractField` to be masked (it must have a `grid` property!)
- `operand`: The `AbstractField` to be masked.

Keyword arguments
=================

- `func`: A unary transformation applied element-wise to the field `operand` at locations where
`condition == true`. Default is `identity`.
`condition == true`. Default is `nothing` which applies no transformation.

- `condition`: either a function of `(i, j, k, grid, operand)` returning a Boolean,
or a 3-dimensional Boolean `AbstractArray`. At locations where `condition == false`,
operand will be masked by `mask`
operand will be masked by `mask`.

- `mask`: the scalar mask
- `mask`: the scalar mask. Default: 0.

`condition_operand` is a convenience function used to construct a `ConditionalOperation`

Expand Down Expand Up @@ -78,10 +82,9 @@ julia> d[2, 1, 1]
```
"""
function ConditionalOperation(operand::AbstractField;
func = identity,
func = nothing,
condition = nothing,
mask = zero(eltype(operand)))

LX, LY, LZ = location(operand)
return ConditionalOperation{LX, LY, LZ}(operand, func, operand.grid, condition, mask)
end
Expand All @@ -95,29 +98,44 @@ function ConditionalOperation(c::ConditionalOperation;
return ConditionalOperation{LX, LY, LZ}(c.operand, func, c.grid, condition, mask)
end

struct TrueCondition end
@inline function Base.getindex(co::ConditionalOperation, i, j, k)
conditioned = evaluate_condition(co.condition, i, j, k, c.grid, c)
value = getindex(co.operand, i, j, k)
func_value = co.func(value)
return ifelse(conditioned, value, c.mask)
end

# Some special cases
const NoFuncCO = ConditionalOperation{<:Any, <:Any, <:Any, Nothing}
const NoConditionCO = ConditionalOperation{<:Any, <:Any, <:Any, <:Any, Nothing}
const NoFuncNoConditionCO = ConditionalOperation{<:Any, <:Any, <:Any, Nothing, Nothing}

using Base: @propagate_inbounds
@propagate_inbounds function Base.getindex(co::NoConditionCO, i, j, k)
value = getindex(co.operand, i, j, k)
return co.func(value)
end

@propagate_inbounds Base.getindex(co::NoFuncNoConditionCO, i, j, k) = getindex(co.operand, i, j, k)

@inline function Base.getindex(c::ConditionalOperation, i, j, k)
return ifelse(evaluate_condition(c.condition, i, j, k, c.grid, c),
c.func(getindex(c.operand, i, j, k)),
c.mask)
@propagate_inbounds function Base.getindex(co::NoFuncCO, i, j, k)
conditioned = evaluate_condition(co.condition, i, j, k, co.grid, co)
value = getindex(co.operand, i, j, k)
return ifelse(conditioned, value, co.mask)
end

@inline evaluate_condition(condition, i, j, k, grid, args...) = condition(i, j, k, grid, args...)
@inline evaluate_condition(::TrueCondition, i, j, k, grid, args...) = true
@inline evaluate_condition(condition::AbstractArray, i, j, k, grid, args...) = @inbounds condition[i, j, k]
# Conditions: general, nothing, array
@inline evaluate_condition(condition, i, j, k, grid, args...) = condition(i, j, k, grid, args...)
@inline evaluate_condition(::Nothing, i, j, k, grid, args...) = true
@propagate_inbounds evaluate_condition(condition::AbstractArray, i, j, k, grid, args...) = condition[i, j, k]

@inline condition_operand(func::Function, op::AbstractField, condition, mask) = ConditionalOperation(op; func, condition, mask)
@inline condition_operand(func::Function, op::AbstractField, ::Nothing, mask) = ConditionalOperation(op; func, condition=TrueCondition(), mask)
@inline condition_operand(func, op, condition, mask) = ConditionalOperation(op; func, condition, mask)

@inline function condition_operand(func::Function, operand::AbstractField, condition::AbstractArray, mask)
@inline function condition_operand(func, operand, condition::AbstractArray, mask)
condition = on_architecture(architecture(operand.grid), condition)
return ConditionalOperation(operand; func, condition, mask)
end

@inline condition_operand(func::typeof(identity), c::ConditionalOperation, ::Nothing, mask) = ConditionalOperation(c; mask)
@inline condition_operand(func::Function, c::ConditionalOperation, ::Nothing, mask) = ConditionalOperation(c; func, mask)

@inline materialize_condition!(c::ConditionalOperation) = set!(c.operand, c)

function materialize_condition(c::ConditionalOperation)
Expand All @@ -126,14 +144,17 @@ function materialize_condition(c::ConditionalOperation)
return f
end

@inline condition_onefield(c::ConditionalOperation{LX, LY, LZ}, mask) where {LX, LY, LZ} =
ConditionalOperation{LX, LY, LZ}(OneField(Int), identity, c.grid, c.condition, mask)
@inline function conditional_one(c::ConditionalOperation, mask)
LX, LY, LZ = location(c)
one_field = OneField(Int)
return ConditionalOperation{LX, LY, LZ}(one_field, nothing, c.grid, c.condition, mask)
end

@inline conditional_length(c::ConditionalOperation) = sum(condition_onefield(c, 0))
@inline conditional_length(c::ConditionalOperation, dims) = sum(condition_onefield(c, 0); dims = dims)
@inline conditional_length(c::ConditionalOperation) = sum(conditional_one(c, 0))
@inline conditional_length(c::ConditionalOperation, dims) = sum(conditional_one(c, 0); dims = dims)

Adapt.adapt_structure(to, c::ConditionalOperation{LX, LY, LZ}) where {LX, LY, LZ} =
ConditionalOperation{LX, LY, LZ}(adapt(to, c.operand),
ConditionalOperation{LX, LY, LZ}(adapt(to, c.operand),
adapt(to, c.func),
adapt(to, c.grid),
adapt(to, c.condition),
Expand All @@ -152,10 +173,10 @@ compute_at!(c::ConditionalOperation, time) = compute_at!(c.operand, time)
indices(c::ConditionalOperation) = indices(c.operand)

Base.show(io::IO, operation::ConditionalOperation) =
print(io,
"ConditionalOperation at $(location(operation))", "\n",
"├── operand: ", summary(operation.operand), "\n",
"├── grid: ", summary(operation.grid), "\n",
"├── func: ", summary(operation.func), "\n",
"├── condition: ", summary(operation.condition), "\n",
"└── mask: ", operation.mask)
print(io, "ConditionalOperation at $(location(operation))", '\n',
"├── operand: ", summary(operation.operand), '\n',
"├── grid: ", summary(operation.grid), '\n',
"├── func: ", summary(operation.func), '\n',
"├── condition: ", summary(operation.condition), '\n',
"└── mask: ", operation.mask)

21 changes: 19 additions & 2 deletions src/Fields/abstract_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ end
return (ax, ay, az, at)
end



"""
total_size(field::AbstractField)

Expand Down Expand Up @@ -126,3 +124,22 @@ for f in (:+, :-)
@eval Base.$f(ϕ::AbstractField, ψ::AbstractArray) = $f(interior(ϕ), ψ)
end

const XReducedAF = AbstractField{Nothing}
const YReducedAF = AbstractField{<:Any, Nothing}
const ZReducedAF = AbstractField{<:Any, <:Any, Nothing}

const YZReducedAF = AbstractField{<:Any, Nothing, Nothing}
const XZReducedAF = AbstractField{Nothing, <:Any, Nothing}
const XYReducedAF = AbstractField{Nothing, Nothing, <:Any}

const XYZReducedAF = AbstractField{Nothing, Nothing, Nothing}

reduced_dimensions(field::AbstractField) = ()
reduced_dimensions(field::XReducedAF) = tuple(1)
reduced_dimensions(field::YReducedAF) = tuple(2)
reduced_dimensions(field::ZReducedAF) = tuple(3)
reduced_dimensions(field::YZReducedAF) = (2, 3)
reduced_dimensions(field::XZReducedAF) = (1, 3)
reduced_dimensions(field::XYReducedAF) = (1, 2)
reduced_dimensions(field::XYZReducedAF) = (1, 2, 3)

32 changes: 13 additions & 19 deletions src/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -513,15 +513,6 @@ const ReducedField = Union{XReducedField,
XYReducedField,
XYZReducedField}

reduced_dimensions(field::Field) = ()
reduced_dimensions(field::XReducedField) = tuple(1)
reduced_dimensions(field::YReducedField) = tuple(2)
reduced_dimensions(field::ZReducedField) = tuple(3)
reduced_dimensions(field::YZReducedField) = (2, 3)
reduced_dimensions(field::XZReducedField) = (1, 3)
reduced_dimensions(field::XYReducedField) = (1, 2)
reduced_dimensions(field::XYZReducedField) = (1, 2, 3)

@propagate_inbounds Base.getindex(r::XReducedField, i, j, k) = getindex(r.data, 1, j, k)
@propagate_inbounds Base.getindex(r::YReducedField, i, j, k) = getindex(r.data, i, 1, k)
@propagate_inbounds Base.getindex(r::ZReducedField, i, j, k) = getindex(r.data, i, j, 1)
Expand Down Expand Up @@ -628,15 +619,14 @@ function reduced_dimension(loc)
return dims
end

## Allow support for ConditionalOperation

get_neutral_mask(::Union{AllReduction, AnyReduction}) = true
get_neutral_mask(::Union{SumReduction, MeanReduction}) = 0
get_neutral_mask(::MinimumReduction) = Inf
get_neutral_mask(::MaximumReduction) = - Inf
get_neutral_mask(::ProdReduction) = 1
get_neutral_mask(::Union{SumReduction, MeanReduction}) = 0
get_neutral_mask(::ProdReduction) = 1

# TODO make this Float32 friendly
get_neutral_mask(::MinimumReduction) = +Inf
get_neutral_mask(::MaximumReduction) = -Inf

# If func = identity and condition = nothing, nothing happens
"""
condition_operand(f::Function, op::AbstractField, condition, mask)

Expand All @@ -646,8 +636,11 @@ If `f isa identity` and `isnothing(condition)` then `op` is returned without wra

Otherwise return `ConditionedOperand`, even when `isnothing(condition)` but `!(f isa identity)`.
"""
@inline condition_operand(op::AbstractField, condition, mask) = condition_operand(identity, op, condition, mask)
@inline condition_operand(::typeof(identity), operand::AbstractField, ::Nothing, mask) = operand
@inline condition_operand(op::AbstractField, condition, mask) = condition_operand(nothing, op, condition, mask)

# Do NOT condition if condition=nothing.
# All non-trivial conditioning is found in AbstractOperations/conditional_operations.jl
@inline condition_operand(::Nothing, operand, ::Nothing, mask) = operand

@inline conditional_length(c::AbstractField) = length(c)
@inline conditional_length(c::AbstractField, dims) = mapreduce(i -> size(c, i), *, unique(dims); init=1)
Expand Down Expand Up @@ -692,10 +685,10 @@ for reduction in (:sum, :maximum, :minimum, :all, :any, :prod)
mask = get_neutral_mask(Base.$(reduction!)),
dims = :)

conditioned_c = condition_operand(f, c, condition, mask)
T = filltype(Base.$(reduction!), c)
loc = reduced_location(location(c); dims)
r = Field(loc, c.grid, T; indices=indices(c))
conditioned_c = condition_operand(f, c, condition, mask)
initialize_reduced_field!(Base.$(reduction!), identity, r, conditioned_c)
Base.$(reduction!)(identity, r, conditioned_c, init=false)

Expand Down Expand Up @@ -767,3 +760,4 @@ function fill_halo_regions!(field::Field, args...; kwargs...)

return nothing
end

70 changes: 48 additions & 22 deletions src/ImmersedBoundaries/immersed_reductions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,61 @@ import Oceananigans.Fields: condition_operand, conditional_length
@inline truefunc(args...) = true

struct NotImmersed{F} <: Function
func :: F
condition :: F
end

NotImmersed() = NotImmersed(nothing)
Base.summary(::NotImmersed{Nothing}) = "NotImmersed()"
Base.summary(::NotImmersed) = string("NotImmersed(", summary(condition), ")")

# ImmersedField
const IF = AbstractField{<:Any, <:Any, <:Any, <:ImmersedBoundaryGrid}

@inline condition_operand(func::Function, op::IF, cond, mask) = ConditionalOperation(op; func, condition=NotImmersed(cond), mask)
@inline condition_operand(func::Function, op::IF, ::Nothing, mask) = ConditionalOperation(op; func, condition=NotImmersed(truefunc), mask)
@inline condition_operand(func::typeof(identity), op::IF, ::Nothing, mask) = ConditionalOperation(op; func, condition=NotImmersed(truefunc), mask)
function ConditionalOperation(operand::IF;
func = nothing,
condition = nothing,
mask = zero(eltype(operand)))

@inline function condition_operand(func::Function, op::IF, cond::AbstractArray, mask)
arch = architecture(op.grid)
arch_condition = on_architecture(arch, cond)
ni_condition = NotImmersed(arch_condition)
return ConditionalOperation(op; func, condition=ni_condition, mask)
immersed_condition = NotImmersed(condition)
LX, LY, LZ = location(operand)
grid = operand.grid
return ConditionalOperation{LX, LY, LZ}(operand, func, grid, immersed_condition, mask)
end

@inline conditional_length(c::IF) = conditional_length(condition_operand(identity, c, nothing, 0))
@inline conditional_length(c::IF, dims) = conditional_length(condition_operand(identity, c, nothing, 0), dims)
@inline conditional_length(c::IF) = conditional_length(condition_operand(c, nothing, 0))
@inline conditional_length(c::IF, dims) = conditional_length(condition_operand(c, nothing, 0), dims)

@inline function evaluate_condition(::NotImmersed{Nothing},
i, j, k,
grid::ImmersedBoundaryGrid,
co::ConditionalOperation) #, args...)

@inline function evaluate_condition(condition::NotImmersed, i, j, k, ibg, co::ConditionalOperation, args...)
ℓx, ℓy, ℓz = map(instantiate, location(co))
immersed = immersed_peripheral_node(i, j, k, ibg, ℓx, ℓy, ℓz) | inactive_node(i, j, k, ibg, ℓx, ℓy, ℓz)
return !immersed & evaluate_condition(condition.func, i, j, k, ibg, args...)
immersed = immersed_peripheral_node(i, j, k, grid, ℓx, ℓy, ℓz) | inactive_node(i, j, k, grid, ℓx, ℓy, ℓz)
return !immersed
end

@inline function evaluate_condition(ni::NotImmersed,
i, j, k,
grid::ImmersedBoundaryGrid,
co::ConditionalOperation, args...)

ℓx, ℓy, ℓz = map(instantiate, location(co))
immersed = immersed_peripheral_node(i, j, k, grid, ℓx, ℓy, ℓz) | inactive_node(i, j, k, grid, ℓx, ℓy, ℓz)
return !immersed & evaluate_condition(ni.condition, i, j, k, grid, co, args...)
end

#####
##### Reduction operations on Reduced Fields test the immersed condition on the entirety of the immersed direction
#####

struct NotImmersedColumn{IC, F} <:Function
struct NotImmersedColumn{F, IC} <:Function
immersed_column :: IC
func :: F
condition :: F
end

NotImmersedColumn(immersed_column) = NotImmersedColumn(immersed_column, nothing)

using Oceananigans.Fields: reduced_dimensions, OneField
using Oceananigans.AbstractOperations: ConditionalOperation

Expand All @@ -62,25 +82,31 @@ const XYZIRF = AbstractField{Nothing, Nothing, Nothing, <:ImmersedBoundaryGrid}

const IRF = Union{XIRF, YIRF, ZIRF, YZIRF, XZIRF, XYIRF, XYZIRF}

@inline condition_operand(func::Function, op::IRF, cond, mask) = ConditionalOperation(op; func, condition=NotImmersedColumn(immersed_column(op), cond ), mask)
@inline condition_operand(func::Function, op::IRF, ::Nothing, mask) = ConditionalOperation(op; func, condition=NotImmersedColumn(immersed_column(op), truefunc), mask)
@inline condition_operand(func::typeof(identity), op::IRF, ::Nothing, mask) = ConditionalOperation(op; func, condition=NotImmersedColumn(immersed_column(op), truefunc), mask)
@inline function condition_operand(func, op::IRF, condition, mask)
immersed_condition = NotImmersedColumn(immersed_column(op), condition)
return ConditionalOperation(op; func, condition, mask)
end

@inline function immersed_column(field::IRF)
grid = field.grid
reduced_dims = reduced_dimensions(field)
LX, LY, LZ = map(center_to_nothing, location(field))
one_field = ConditionalOperation{LX, LY, LZ}(OneField(Int), identity, grid, NotImmersed(truefunc), zero(grid))
one_field = ConditionalOperation{LX, LY, LZ}(OneField(Int), identity, grid, NotImmersed(), zero(grid))
return sum(one_field, dims=reduced_dims)
end

@inline center_to_nothing(::Type{Face}) = Face
@inline center_to_nothing(::Type{Center}) = Center
@inline center_to_nothing(::Type{Nothing}) = Center

@inline function evaluate_condition(condition::NotImmersedColumn, i, j, k, ibg, co::ConditionalOperation, args...)
@inline function evaluate_condition(nic::NotImmersedColumn,
i, j, k,
grid::ImmersedBoundaryGrid,
co::ConditionalOperation, args...)
LX, LY, LZ = location(co)
return evaluate_condition(condition.func, i, j, k, ibg, args...) & !(is_immersed_column(i, j, k, condition.immersed_column))
immersed = is_immersed_column(i, j, k, nic.immersed_column)
return !immersed & evaluate_condition(nic.condition, i, j, k, grid, args...)
end

@inline is_immersed_column(i, j, k, column) = @inbounds column[i, j, k] == 0