Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #35199: type-stable Base.reduced_indices #52905

Merged
merged 8 commits into from
Jan 30, 2024
Merged

Conversation

lxvm
Copy link
Contributor

@lxvm lxvm commented Jan 15, 2024

Hi,
This is my first pr to Julia and I would appreciate feedback from the reviewers. It fixes #35199 by rewriting Base.reduced_indices to be type stable (and grounded). I was also able to remove a method since that case is covered by the general case.

The changes are illustrated by the following quick benchmarks

julia> VERSION
v"1.10.0"

julia> using BenchmarkTools

julia> M = [1 2; 3 4]
2×2 Matrix{Int64}:
 1  2
 3  4

julia> @btime sum($M, dims=$(2))
  194.816 ns (5 allocations: 160 bytes)
2×1 Matrix{Int64}:
 3
 7

julia> @btime sum($M, dims=$((2,)))
  209.385 ns (5 allocations: 224 bytes)
2×1 Matrix{Int64}:
 3
 7

julia> function my_reduced_indices(inds::Base.Indices{N}, region) where N
           rinds = inds
           for i in region
               isa(i, Integer) || throw(ArgumentError("reduced dimension(s) must be integers"))
               d = Int(i)
               if d < 1
                   throw(ArgumentError("region dimension(s) must be ≥ 1, got $d"))
               elseif d <= N
                   rinds = let rinds_=rinds
                       ntuple(j -> j == d ? Base.reduced_index(rinds_[d]) : rinds_[j], Val(N))
                   end
               end
           end
           rinds
       end
my_reduced_indices (generic function with 1 method)

julia> Base.reduced_indices(inds::Base.Indices{N}, region::Int) where N = my_reduced_indices(inds, region)

julia> Base.reduced_indices(inds::Base.Indices{N}, region) where N = my_reduced_indices(inds, region)

julia> @btime sum($M, dims=$(2))
  43.582 ns (1 allocation: 80 bytes)
2×1 Matrix{Int64}:
 3
 7

julia> @btime sum($M, dims=$((2,)))
  43.882 ns (1 allocation: 80 bytes)
2×1 Matrix{Int64}:
 3
 7

I also rewrote Base.reduced_indices0 in the same fashion. I wasn't sure how to add tests for this since the improvements are to type-groundedness. Since these changes affect all reductions I hope this solution is robust.

@KristofferC
Copy link
Sponsor Member

Hi and welcome. For a PR that fixes some existing issue it is almost always a good idea to add a test that confirms the fix is working and that it stays fixed in the future.

@KristofferC KristofferC added the needs tests Unit tests are required for this change label Jan 15, 2024
@lxvm
Copy link
Contributor Author

lxvm commented Jan 15, 2024

Thank you, @KristofferC ! I added some tests to make sure the only allocations made are for the returned array.

Update: the new tests appear to be too strict so I'll drop the sum test since this pr only addresses Base.reduced_indices

@lxvm
Copy link
Contributor Author

lxvm commented Jan 16, 2024

Now that I added tests, I simplified the code to a one-liner, with some checks to keep the previous behavior, and think it is ready for review. For completeness, I've put some benchmarks comparing this pr to master below showing a typical 50x speedup for reduced_indices on matrices.

Benchmarks for master
julia> using BenchmarkTools

julia> M = [1 2; 3 4]
2×2 Matrix{Int64}:
 1  2
 3  4

julia> ax = axes(M)
(Base.OneTo(2), Base.OneTo(2))

julia> for dims in (1, Int32(1), true, 2, Int32(2), (1,), (Int32(1),), (true,), (2,), (Int32(2),), (1,2), (Int32(1),Int32(2)), (true, Int16(2)), [1], Int32[1], Bool[1], [2], Int32[2], [1,2], Int32[1,2], [true, Int16(2)])
       println("dims::", typeof(dims)); @btime Base.reduced_indices($ax, $dims)
       end
dims::Int64
  8.557 ns (0 allocations: 0 bytes)
dims::Int32
  157.837 ns (5 allocations: 144 bytes)
dims::Bool
  156.296 ns (5 allocations: 144 bytes)
dims::Int64
  154.349 ns (4 allocations: 80 bytes)
dims::Int32
  157.170 ns (5 allocations: 144 bytes)
dims::Tuple{Int64}
  156.335 ns (5 allocations: 144 bytes)
dims::Tuple{Int32}
  157.251 ns (5 allocations: 144 bytes)
dims::Tuple{Bool}
  156.024 ns (5 allocations: 144 bytes)
dims::Tuple{Int64}
  156.496 ns (5 allocations: 144 bytes)
dims::Tuple{Int32}
  157.386 ns (5 allocations: 144 bytes)
dims::Tuple{Int64, Int64}
  157.364 ns (5 allocations: 144 bytes)
dims::Tuple{Int32, Int32}
  157.819 ns (5 allocations: 144 bytes)
dims::Tuple{Bool, Int16}
  177.620 ns (6 allocations: 160 bytes)
dims::Vector{Int64}
  156.095 ns (5 allocations: 144 bytes)
dims::Vector{Int32}
  157.387 ns (5 allocations: 144 bytes)
dims::Vector{Bool}
  157.614 ns (5 allocations: 144 bytes)
dims::Vector{Int64}
  156.340 ns (5 allocations: 144 bytes)
dims::Vector{Int32}
  157.545 ns (5 allocations: 144 bytes)
dims::Vector{Int64}
  156.789 ns (5 allocations: 144 bytes)
dims::Vector{Int32}
  158.867 ns (5 allocations: 144 bytes)
dims::Vector{Int16}
  158.285 ns (5 allocations: 144 bytes)
Benchmarks for this pr
julia> using BenchmarkTools

julia> M = [1 2; 3 4]
2×2 Matrix{Int64}:
 1  2
 3  4

julia> ax = axes(M)
(Base.OneTo(2), Base.OneTo(2))

julia> for dims in (1, Int32(1), true, 2, Int32(2), (1,), (Int32(1),), (true,), (2,), (Int32(2),), (1,2), (Int32(1),Int32(2)), (true, Int16(2)), [1], Int32[1], Bool[1], [2], Int32[2], [1,2], Int32[1,2], [true, Int16(2)])
       println("dims::", typeof(dims)); @btime Base.reduced_indices($ax, $dims)
       end
dims::Int64
  3.439 ns (0 allocations: 0 bytes)
dims::Int32
  3.199 ns (0 allocations: 0 bytes)
dims::Bool
  3.676 ns (0 allocations: 0 bytes)
dims::Int64
  3.203 ns (0 allocations: 0 bytes)
dims::Int32
  3.935 ns (0 allocations: 0 bytes)
dims::Tuple{Int64}
  3.197 ns (0 allocations: 0 bytes)
dims::Tuple{Int32}
  3.190 ns (0 allocations: 0 bytes)
dims::Tuple{Bool}
  3.192 ns (0 allocations: 0 bytes)
dims::Tuple{Int64}
  3.924 ns (0 allocations: 0 bytes)
dims::Tuple{Int32}
  3.204 ns (0 allocations: 0 bytes)
dims::Tuple{Int64, Int64}
  8.312 ns (0 allocations: 0 bytes)
dims::Tuple{Int32, Int32}
  8.312 ns (0 allocations: 0 bytes)
dims::Tuple{Bool, Int16}
  36.400 ns (2 allocations: 32 bytes)
dims::Vector{Int64}
  8.315 ns (0 allocations: 0 bytes)
dims::Vector{Int32}
  8.313 ns (0 allocations: 0 bytes)
dims::Vector{Bool}
  8.313 ns (0 allocations: 0 bytes)
dims::Vector{Int64}
  8.557 ns (0 allocations: 0 bytes)
dims::Vector{Int32}
  8.557 ns (0 allocations: 0 bytes)
dims::Vector{Int64}
  8.802 ns (0 allocations: 0 bytes)
dims::Vector{Int32}
  8.800 ns (0 allocations: 0 bytes)
dims::Vector{Int16}
  8.800 ns (0 allocations: 0 bytes)

Copy link
Contributor

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me. It preserves slightly weird behaviours like sum(M, dims=(1,1,3)), and seems to provide a real speedup for small arrays.

base/reducedim.jl Outdated Show resolved Hide resolved
base/reducedim.jl Outdated Show resolved Hide resolved
@lxvm
Copy link
Contributor Author

lxvm commented Jan 16, 2024

I've addressed all of the helpful comments, thanks, and I think this is complete. Luckily all of the tests passed!

@lxvm
Copy link
Contributor Author

lxvm commented Jan 30, 2024

Hi, I just wanted to bump this pr and ask if it looks good. Are the tests I added enough?

@vtjnash vtjnash removed the needs tests Unit tests are required for this change label Jan 30, 2024
@vtjnash vtjnash merged commit 4b1bbeb into JuliaLang:master Jan 30, 2024
7 of 8 checks passed
@vtjnash
Copy link
Sponsor Member

vtjnash commented Jan 30, 2024

Thanks! Seems good to me. It is remarkable to see how much more concise the code is now

@lxvm lxvm deleted the issue35199 branch January 31, 2024 01:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

reduced_indices is type unstable
4 participants