chore: improve broadcast_dims function to work directly with `Dimension`s and `DimArray`s

The `broadcast_dims` function now supports broadcasting over `Dimension`s and references to them, in addition to `AbstractDimArray`s. This allows for more flexibility when broadcasting over combinations of `DimArray`s and `Dimension`s.
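The call patterns described above can be sketched as follows. This is a minimal illustration that assumes a DimensionalData.jl version containing this commit; the variable names are arbitrary, and the values mirror the docstring examples added in `src/utils.jl`.

```julia
using DimensionalData

# Dimensions with explicit lookup values
x, y, z = X(1:2:6), Y(10.5:1.0:13.5), Z(-0.5:0.5:0.5)

# 1. Build a new DimArray purely from Dimensions
A = broadcast_dims(*, x, y)        # 3×4, dims (X, Y)

# 2. Refer to a dimension that already exists in a DimArray;
#    `Y`, `:Y` and `Y(:)` are meant to be interchangeable here
B = ones(x, y)
C = broadcast_dims(+, B, Y)

# 3. Mix DimArrays with a new Dimension that carries its own lookup
D = broadcast_dims(+, A, B, z)     # 3×4×3, dims (X, Y, Z)
```

In the third call, `z` is not part of `A` or `B`, so it has to carry explicit lookup values; a bare `Z` would have nothing to look up.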
haakon-e committed Aug 20, 2024
1 parent d10d0fc commit c776e46
Showing 3 changed files with 130 additions and 9 deletions.
13 changes: 13 additions & 0 deletions docs/src/broadcast_dims.md
@@ -52,3 +52,16 @@ We can see the means of each month are scaled by the broadcast :
mean(eachslice(data; dims=(X, Y)))
mean(eachslice(scaled; dims=(X, Y)))
````

Broadcasting also works directly over `Dimension`s (or references to them).
For example, a new `DimArray` can be constructed by broadcasting a function over a set of dimensions:

````@ansi bd
broadcast_dims(*, x, y)
````

An existing dimension can be referenced by name, in which case its lookup values are used:

````@ansi bd
broadcast_dims(*, data, X) # or `:X` or `X(:)`
````
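A reference such as `X` or `:X` only works when the dimension is present in one of the arrays being broadcast. According to the docstring added in this commit, a reference to an unknown dimension throws an `ArgumentError`. A hedged sketch, assuming a DimensionalData.jl version that includes this commit (the array `B` is made up for illustration):

```julia
using DimensionalData

# Illustrative array with X and Y dimensions only
B = ones(X(1:2:5), Y(10.5:1.0:13.5))

# `Y` exists in `B`, so its lookup values (10.5, 11.5, ...) are used
C = broadcast_dims(+, B, Y)

# `Z` is not in `B` and carries no lookup values, so per the docstring
# this call is expected to fail with an ArgumentError
err = try
    broadcast_dims(+, B, Z)
    nothing
catch e
    e
end

# Passing explicit lookup values makes `Z` a new dimension instead
D = broadcast_dims(+, B, Z([0.1]))
```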
102 changes: 94 additions & 8 deletions src/utils.jl
@@ -108,25 +108,111 @@ function modify(f, index::AbstractArray)
end

"""
broadcast_dims(f, sources::AbstractDimArray...) => AbstractDimArray
broadcast_dims(f, sources::Union{AbstractDimArray, Dimension, Symbol}...) => AbstractDimArray
Broadcast function `f` over the `AbstractDimArray`s in `sources`, permuting and reshaping
dimensions to match where required. The result will contain all the dimensions in
all passed in arrays in the order in which they are found.
Broadcast function `f` over the `AbstractDimArray`s and/or `Dimension`s in `sources`, permuting and reshaping
dimensions to match where required. The result will contain all the dimensions in all passed-in arrays, in the
order in which they are found.
## Arguments
Existing dimensions can be referenced by e.g. `X`, `:X`, `X(:)`, `X(1.0:0.5:10.0)`.
New dimensions can be passed, but must have an explicit lookup, e.g. `X(1.0:0.5:10.0)`.
- `sources`: `AbstractDimArrays` to broadcast over with `f`.
# Arguments
- `sources`: `AbstractDimArray`s, `Dimension`s, or `Symbol`s to broadcast over with `f`.
This is like broadcasting over every slice of `A` if it is sliced by the dimensions of `B`.
# Throws
- `ArgumentError` if a `Dimension` without explicit lookup values is passed and it is not found among the passed in `DimArray`s.
# Extended help
## Examples
In the simplest use case, `broadcast_dims` can be used to construct a `DimArray` from multiple `Dimension`s:
```julia
julia> x, y, z = X(1:2:6), Y(10.5:1.0:13.5), Z(-0.5:0.5:0.5)
↓ X 1:2:5,
→ Y 10.5:1.0:13.5,
↗ Z -0.5:0.5:0.5
julia> A = broadcast_dims(*, x, y)
╭─────────────────────────╮
│ 3×4 DimArray{Float64,2} │
├─────────────────────────┴────────────────────────────────── dims ┐
↓ X Sampled{Int64} 1:2:5 ForwardOrdered Regular Points,
→ Y Sampled{Float64} 10.5:1.0:13.5 ForwardOrdered Regular Points
└──────────────────────────────────────────────────────────────────┘
↓ → 10.5 11.5 12.5 13.5
1 10.5 11.5 12.5 13.5
3 31.5 34.5 37.5 40.5
5 52.5 57.5 62.5 67.5
```
This is like broadcasting over every slice of `A` if it is
sliced by the dimensions of `B`.
We can also implicitly refer to existing dimensions in `DimArray`s:
```julia
julia> B = ones(x, y);
julia> broadcast_dims(+, B, Y) # also `Y(:)`, or `:Y` works
╭─────────────────────────╮
│ 3×4 DimArray{Float64,2} │
├─────────────────────────┴────────────────────────────────── dims ┐
↓ X Sampled{Int64} 1:2:5 ForwardOrdered Regular Points,
→ Y Sampled{Float64} 10.5:1.0:13.5 ForwardOrdered Regular Points
└──────────────────────────────────────────────────────────────────┘
↓ → 10.5 11.5 12.5 13.5
1 11.5 12.5 13.5 14.5
3 11.5 12.5 13.5 14.5
5 11.5 12.5 13.5 14.5
```
Finally, we can mix and match `DimArray`s and `Dimension`s:
```julia
julia> broadcast_dims(+, A, B, z)
╭───────────────────────────╮
│ 3×4×3 DimArray{Float64,3} │
├───────────────────────────┴───────────────────────────────── dims ┐
↓ X Sampled{Int64} 1:2:5 ForwardOrdered Regular Points,
→ Y Sampled{Float64} 10.5:1.0:13.5 ForwardOrdered Regular Points,
↗ Z Sampled{Float64} -0.5:0.5:0.5 ForwardOrdered Regular Points
└───────────────────────────────────────────────────────────────────┘
[:, :, 1]
↓ → 10.5 11.5 12.5 13.5
1 11.0 12.0 13.0 14.0
3 32.0 35.0 38.0 41.0
5 53.0 58.0 63.0 68.0
```
"""
function broadcast_dims(f, As::AbstractBasicDimArray...)
    dims = combinedims(As...)
    T = Base.Broadcast.combine_eltypes(f, As)
    broadcast_dims!(f, similar(first(As), T, dims), As...)
end

function broadcast_dims(f, As::Union{AbstractBasicDimArray, Dimensions.Dimension, Type{<:Dimension}, Symbol}...)
    # To support references such as `X`, `Ti`, or `:X` as inputs, first collect the dims of any
    # actual DimArrays, because we need their lookup values
    existing_dims = combinedims(filter(Base.Fix2(isa, AbstractBasicDimArray), As)...)
    Bs = map(As) do A
        if A isa Dimension && !(parent(A) isa Colon)
            # A dimension with explicit lookup values is passed, so use it
            DimArray(parent(A), A)
        elseif A isa Dimension || A isa Type{<:Dimension} || A isa Symbol
            # If a reference to a dimension, e.g. `X(:)`, `X` or `:X` is passed, look up values from `existing_dims`
            dim = dims(existing_dims, A)
            # If `A` isn't among the existing dimensions, and since we don't have its lookup values, we can't proceed
            isnothing(dim) && throw(ArgumentError("Dimension $A not found among the passed in `DimArray`s"))
            # otherwise, construct a `DimArray` with the looked up values
            DimArray(parent(dim), dim)
        else
            # finally, if it's actually a `DimArray`, just pass it through
            A
        end
    end # map(As)
    broadcast_dims(f, Bs...)
end

function broadcast_dims(f, As::Union{AbstractDimStack,AbstractBasicDimArray}...)
    st = _firststack(As...)
    nts = _as_extended_nts(NamedTuple(st), As...)
24 changes: 23 additions & 1 deletion test/utils.jl
@@ -143,7 +143,6 @@ end
@test dc1 == [2, 4, 6]
dc2 = broadcast_dims(+, da2, db1)
@test dc2 == [2 4 6; 5 7 9]
dc4 = broadcast_dims(+, da2, db1)

A3 = cat([1 2 3; 4 5 6], [11 12 13; 14 15 16]; dims=3)
da3 = DimArray(A3, (X, Y, Z))
@@ -157,6 +156,29 @@ end
dc3 = broadcast_dims(+, da3, db1)
@test dc3 == cat([2 4 6; 5 7 9], [12 14 16; 15 17 19]; dims=3)

@testset "works directly with Dimensions" begin
    x, y, z = X([1, 2, 3]), Y([1, 2]), Z([0.1])

    # construct a DimArray from dimensions, using `broadcast_dims`
    da_from_dims = broadcast_dims(+, x, y)
    @test da_from_dims == [2 3; 3 4; 4 5]

    # different ways to refer to existing dimensions
    da_with_reference_da = broadcast_dims(+, da_from_dims, DimArray(parent(y), y)) # reference computation
    da_and_existing_dims = broadcast_dims(+, da_from_dims, Y)
    da_and_existing_dims2 = broadcast_dims(+, da_from_dims, Y(:))
    da_and_existing_dims3 = broadcast_dims(+, da_from_dims, :Y)
    da_and_existing_dims4 = broadcast_dims(+, da_from_dims, y)
    @test da_and_existing_dims == [3 5; 4 6; 5 7]
    @test da_and_existing_dims == da_with_reference_da
    @test da_and_existing_dims == da_and_existing_dims2
    @test da_and_existing_dims == da_and_existing_dims3
    @test da_and_existing_dims == da_and_existing_dims4

    # combine `DimArray` and `Dimension`
    da_and_new_dims = broadcast_dims(+, da_from_dims, z)
    @test da_and_new_dims == [2.1 3.1; 3.1 4.1; 4.1 5.1;;;]
end
@testset "works with permuted dims" begin
db2p = permutedims(da2)
dc3p = broadcast_dims(+, da3, db2p)
