Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor scalar range getindex #50467

Merged
merged 10 commits into from
Aug 30, 2023
6 changes: 1 addition & 5 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1288,11 +1288,7 @@ end
# To avoid invalidations from multidimensional.jl: getindex(A::Array, i1::Union{Integer, CartesianIndex}, I::Union{Integer, CartesianIndex}...)
@propagate_inbounds getindex(A::Array, i1::Integer, I::Integer...) = A[to_indices(A, (i1, I...))...]

function unsafe_getindex(A::AbstractArray, I...)
@inline
@inbounds r = getindex(A, I...)
r
end
unsafe_getindex(A::AbstractArray, I...) = @inbounds getindex(A, I...)

struct CanonicalIndexError <: Exception
func::String
Expand Down
9 changes: 9 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2859,3 +2859,12 @@ function intersect(v::AbstractVector, r::AbstractRange)
return vectorfilter(T, _shrink_filter!(seen), common)
end
intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r)

# Here instead of range.jl for bootstrapping because `@propagate_inbounds` depends on Vectors.
@propagate_inbounds function getindex(v::AbstractRange, i::Integer)
if i isa Bool # Not via dispatch to avoid ambiguities
throw(ArgumentError("invalid index: $i of type Bool"))
else
_getindex(v, i)
end
end
63 changes: 21 additions & 42 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -910,11 +910,15 @@ function isassigned(r::AbstractRange, i::Integer)
firstindex(r) <= i <= lastindex(r)
end

# `_getindex` is like `getindex` but does not check if `i isa Bool`
function _getindex(v::AbstractRange, i::Integer)
LilithHafner marked this conversation as resolved.
Show resolved Hide resolved
@boundscheck checkbounds(v, i)
unsafe_getindex(v, i)
end

_in_unit_range(v::UnitRange, val, i::Integer) = i > 0 && val <= v.stop && val >= v.start

function getindex(v::UnitRange{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
function _getindex(v::UnitRange{T}, i::Integer) where T
val = convert(T, v.start + (i - oneunit(i)))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val
Expand All @@ -923,68 +927,38 @@ end
const OverflowSafe = Union{Bool,Int8,Int16,Int32,Int64,Int128,
UInt8,UInt16,UInt32,UInt64,UInt128}

function getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
function _getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
val = v.start + (i - oneunit(i))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val % T
end

function getindex(v::OneTo{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
@boundscheck ((i > 0) & (i <= v.stop)) || throw_boundserror(v, i)
convert(T, i)
end

function getindex(v::AbstractRange{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
@boundscheck checkbounds(v, i)
convert(T, first(v) + (i - oneunit(i))*step_hp(v))
end

let BitInteger64 = Union{Int8,Int16,Int32,Int64,UInt8,UInt16,UInt32,UInt64} # for bootstrapping
function checkbounds(::Type{Bool}, v::StepRange{<:BitInteger64, <:BitInteger64}, i::BitInteger64)
@inline
res = widemul(step(v), i-oneunit(i)) + first(v)
(0 < i) & ifelse(0 < step(v), res <= last(v), res >= last(v))
end
end

function getindex(r::Union{StepRangeLen,LinRange}, i::Integer)
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
@boundscheck checkbounds(r, i)
unsafe_getindex(r, i)
end

# This is separate to make it useful even when running with --check-bounds=yes
# unsafe_getindex is separate to make it useful even when running with --check-bounds=yes
LilithHafner marked this conversation as resolved.
Show resolved Hide resolved
# it assumes the index is inbounds but does not segfault even if the index is out of bounds.
# it does not check if the index isa bool.
unsafe_getindex(v::OneTo{T}, i::Integer) where T = convert(T, i)
unsafe_getindex(v::AbstractRange{T}, i::Integer) where T = convert(T, first(v) + (i - oneunit(i))*step_hp(v))
function unsafe_getindex(r::StepRangeLen{T}, i::Integer) where T
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = oftype(r.offset, i) - r.offset
T(r.ref + u*r.step)
end

function _getindex_hiprec(r::StepRangeLen, i::Integer) # without rounding by T
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = oftype(r.offset, i) - r.offset
r.ref + u*r.step
end

function unsafe_getindex(r::LinRange, i::Integer)
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
lerpi(i-oneunit(i), r.lendiv, r.start, r.stop)
end
unsafe_getindex(r::LinRange, i::Integer) = lerpi(i-oneunit(i), r.lendiv, r.start, r.stop)

function lerpi(j::Integer, d::Integer, a::T, b::T) where T
@inline
t = j/d # ∈ [0,1]
# compute approximately fma(t, b, -fma(t, a, a))
return T((1-t)*a + t*b)
end

# non-scalar indexing

getindex(r::AbstractRange, ::Colon) = copy(r)

function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integer}
Expand Down Expand Up @@ -1083,6 +1057,11 @@ function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
end
end

function _getindex_hiprec(r::StepRangeLen, i::Integer) # without rounding by T
u = oftype(r.offset, i) - r.offset
r.ref + u*r.step
end

function getindex(r::LinRange{T}, s::OrdinalRange{S}) where {T, S<:Integer}
@inline
@boundscheck checkbounds(r, s)
Expand Down
2 changes: 0 additions & 2 deletions base/twiceprecision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -476,8 +476,6 @@ end
# This assumes that r.step has already been split so that (0:len-1)*r.step.hi is exact
function unsafe_getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, i::Integer) where T
# Very similar to _getindex_hiprec, but optimized to avoid a 2nd call to add12
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
u = oftype(r.offset, i) - r.offset
shift_hi, shift_lo = u*r.step.hi, u*r.step.lo
x_hi, x_lo = add12(r.ref.hi, shift_hi)
Expand Down