Skip to content

Commit

Permalink
make Base.reduced_indices more type-stable (#52905)
Browse files Browse the repository at this point in the history
This fixes #35199 by rewriting `Base.reduced_indices` to be
type stable (and grounded). I was also able to remove a method since
that case is covered by the general case.

The changes are illustrated by the following quick benchmarks:

```julia
julia> VERSION
v"1.10.0"

julia> using BenchmarkTools

julia> M = [1 2; 3 4]
2×2 Matrix{Int64}:
 1  2
 3  4

julia> @Btime sum($M, dims=$(2))
  194.816 ns (5 allocations: 160 bytes)
2×1 Matrix{Int64}:
 3
 7

julia> @Btime sum($M, dims=$((2,)))
  209.385 ns (5 allocations: 224 bytes)
2×1 Matrix{Int64}:
 3
 7

julia> function my_reduced_indices(inds::Base.Indices{N}, region) where N
           rinds = inds
           for i in region
               isa(i, Integer) || throw(ArgumentError("reduced dimension(s) must be integers"))
               d = Int(i)
               if d < 1
                   throw(ArgumentError("region dimension(s) must be ≥ 1, got $d"))
               elseif d <= N
                   rinds = let rinds_=rinds
                       ntuple(j -> j == d ? Base.reduced_index(rinds_[d]) : rinds_[j], Val(N))
                   end
               end
           end
           rinds
       end
my_reduced_indices (generic function with 1 method)

julia> Base.reduced_indices(inds::Base.Indices{N}, region::Int) where N = my_reduced_indices(inds, region)

julia> Base.reduced_indices(inds::Base.Indices{N}, region) where N = my_reduced_indices(inds, region)

julia> @Btime sum($M, dims=$(2))
  43.582 ns (1 allocation: 80 bytes)
2×1 Matrix{Int64}:
 3
 7

julia> @Btime sum($M, dims=$((2,)))
  43.882 ns (1 allocation: 80 bytes)
2×1 Matrix{Int64}:
 3
 7
```

I also rewrote `Base.reduced_indices0` in the same fashion. I wasn't
sure how to add tests for this since the improvements are to
type-groundedness. Since these changes affect all reductions I hope this
solution is robust.
  • Loading branch information
lxvm authored Jan 30, 2024
1 parent 432d248 commit 4b1bbeb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 48 deletions.
58 changes: 10 additions & 48 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,59 +17,21 @@ reduced_indices(a::AbstractArrayOrBroadcasted, region) = reduced_indices(axes(a)
# for reductions that keep 0 dims as 0
reduced_indices0(a::AbstractArray, region) = reduced_indices0(axes(a), region)

function reduced_indices(inds::Indices{N}, d::Int) where N
d < 1 && throw(ArgumentError("dimension must be ≥ 1, got $d"))
if d == 1
return (reduced_index(inds[1]), tail(inds)...)::typeof(inds)
elseif 1 < d <= N
return tuple(inds[1:d-1]..., oftype(inds[d], reduced_index(inds[d])), inds[d+1:N]...)::typeof(inds)
else
return inds
end
function reduced_indices(axs::Indices{N}, region) where N
_check_valid_region(region)
ntuple(d -> d in region ? reduced_index(axs[d]) : axs[d], Val(N))
end

function reduced_indices0(inds::Indices{N}, d::Int) where N
d < 1 && throw(ArgumentError("dimension must be ≥ 1, got $d"))
if d <= N
ind = inds[d]
rd = isempty(ind) ? ind : reduced_index(inds[d])
if d == 1
return (rd, tail(inds)...)::typeof(inds)
else
return tuple(inds[1:d-1]..., oftype(inds[d], rd), inds[d+1:N]...)::typeof(inds)
end
else
return inds
end
function reduced_indices0(axs::Indices{N}, region) where N
_check_valid_region(region)
ntuple(d -> d in region && !isempty(axs[d]) ? reduced_index(axs[d]) : axs[d], Val(N))
end

function reduced_indices(inds::Indices{N}, region) where N
rinds = collect(inds)
for i in region
isa(i, Integer) || throw(ArgumentError("reduced dimension(s) must be integers"))
d = Int(i)
if d < 1
throw(ArgumentError("region dimension(s) must be ≥ 1, got $d"))
elseif d <= N
rinds[d] = reduced_index(rinds[d])
end
end
tuple(rinds...)::typeof(inds)
end

function reduced_indices0(inds::Indices{N}, region) where N
rinds = collect(inds)
for i in region
isa(i, Integer) || throw(ArgumentError("reduced dimension(s) must be integers"))
d = Int(i)
if d < 1
throw(ArgumentError("region dimension(s) must be ≥ 1, got $d"))
elseif d <= N
rind = rinds[d]
rinds[d] = isempty(rind) ? rind : reduced_index(rind)
end
function _check_valid_region(region)
for d in region
isa(d, Integer) || throw(ArgumentError("reduced dimension(s) must be integers"))
Int(d) < 1 && throw(ArgumentError("region dimension(s) must be ≥ 1, got $d"))
end
tuple(rinds...)::typeof(inds)
end

###### Generic reduction functions #####
Expand Down
12 changes: 12 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ fill!(r, -6.3)
fill!(r, -1.1)
@test sum!(abs2, r, Breduc, init=false) safe_sumabs2(Breduc, 1) .- 1.1

# issue #35199
function issue35199_test(sizes, dims)
M = rand(Float64, sizes)
ax = axes(M)
n1 = @allocations Base.reduced_indices(ax, dims)
return @test n1 == 0
end
for dims in (1, 2, (1,), (2,), (1,2))
sizes = (64, 3)
issue35199_test(sizes, dims)
end

# Small arrays with init=false
let A = reshape(1:15, 3, 5)
R = fill(1, 3)
Expand Down

0 comments on commit 4b1bbeb

Please sign in to comment.