Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Safer, extensible ﹫inbounds #8227

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,8 @@ imag{T<:Real}(x::AbstractArray{T}) = zero(x)

getindex(t::AbstractArray, i::Real) = error("indexing not defined for ", typeof(t))

unsafe_getindex(args...) = getindex(args...)

# linear indexing with a single multi-dimensional index
function getindex(A::AbstractArray, I::AbstractArray)
x = similar(A, size(I))
Expand Down Expand Up @@ -441,6 +443,8 @@ setindex!(t::AbstractArray, x, i::Real) =
error("setindex! not defined for ",typeof(t))
setindex!(t::AbstractArray, x) = throw(MethodError(setindex!, (t, x)))

unsafe_setindex!(args...) = setindex!(args...)

## Indexing: handle more indices than dimensions if "extra" indices are 1

# Don't require vector/matrix subclasses to implement more than 1/2 indices,
Expand Down
69 changes: 39 additions & 30 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,25 @@ collect(itr) = collect(eltype(itr), itr)

## Indexing: getindex ##

getindex(a::Array) = arrayref(a,1)

getindex(A::Array, i0::Real) = arrayref(A,to_index(i0))
getindex(A::Array, i0::Real, i1::Real) = arrayref(A,to_index(i0),to_index(i1))
getindex(A::Array, i0::Real, i1::Real, i2::Real) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2))
getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3))
getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4))
getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5))

getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) =
arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5),to_index(I)...)
for (getindexfn, transform) in ((:getindex, x->x), (:unsafe_getindex, x->:(@boundscheck false return $x)))
@eval begin
$getindexfn(a::Array) = $(transform(:(arrayref(a,1))))

$getindexfn(A::Array, i0::Real) = $(transform(:(arrayref(A,to_index(i0)))))
$getindexfn(A::Array, i0::Real, i1::Real) = $(transform(:(arrayref(A,to_index(i0),to_index(i1)))))
$getindexfn(A::Array, i0::Real, i1::Real, i2::Real) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2)))))
$getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3)))))
$getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4)))))
$getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5)))))

$getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) =
$(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5),to_index(I)...))))
end
end

# Fast copy using copy! for UnitRange
function getindex(A::Array, I::UnitRange{Int})
Expand Down Expand Up @@ -302,21 +306,26 @@ getindex(A::Array, I::AbstractArray{Bool}) = getindex_bool_1d(A, I)


## Indexing: setindex! ##
setindex!{T}(A::Array{T}, x) = arrayset(A, convert(T,x), 1)

setindex!{T}(A::Array{T}, x, i0::Real) = arrayset(A, convert(T,x), to_index(i0))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5))
setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) =
arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5), to_index(I)...)

for (setindexfn, transform) in ((:setindex!, x->x), (:unsafe_setindex!, x->:(@boundscheck false return $x)))
@eval begin
$setindexfn{T}(A::Array{T}, x) = $(transform(:(arrayset(A, convert(T,x), 1))))

$setindexfn{T}(A::Array{T}, x, i0::Real) = $(transform(:(arrayset(A, convert(T,x), to_index(i0)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5)))))
$setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) =
$(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5), to_index(I)...))))
end
end

function setindex!{T<:Real}(A::Array, x, I::AbstractVector{T})
for i in I
Expand Down
21 changes: 19 additions & 2 deletions base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,25 @@ macro boundscheck(yesno,blk)
$(Expr(:boundscheck,:pop)))
end

macro inbounds(blk)
:(@boundscheck false $(esc(blk)))
function rewrite_ref(getindexfn, setindexfn, ast::Expr)
if ast.head === :ref
ast = Expr(:custom_ref, getindexfn, setindexfn, ast.args...)
end

args = ast.args
for i = 1:arraylen(args)
arg = arrayref(args, i)
if isa(arg, Expr)
arrayset(args, rewrite_ref(getindexfn, setindexfn, arg), i)
end
end

return ast
end
rewrite_ref(getindexfn, setindexfn, x) = x

macro inbounds(ex)
esc(rewrite_ref(:unsafe_getindex, :unsafe_setindex!, ex))
end

macro label(name::Symbol)
Expand Down
4 changes: 3 additions & 1 deletion base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,9 @@ end
## Indexing: getindex ##

function unsafe_bitgetindex(Bc::Vector{Uint64}, i::Int)
return (Bc[@_div64(i-1)+1] & (uint64(1)<<@_mod64(i-1))) != 0
return @inbounds (Bc[@_div64(i-1)+1] & (uint64(1)<<@_mod64(i-1))) != 0
Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great if this does no longer cause the function to return nothing

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup:

julia> x = [1, 2];

julia> @inbounds x[1]
1

end
unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind)

function getindex(B::BitArray, i::Int)
1 <= i <= length(B) || throw(BoundsError())
Expand Down Expand Up @@ -408,6 +409,7 @@ function unsafe_bitsetindex!(Bc::Array{Uint64}, x::Bool, i::Int)
end
end
end
unsafe_setindex!(v::BitArray, x::Bool, ind::Int) = (Base.unsafe_bitsetindex!(v.chunks, x, ind); v)

setindex!(B::BitArray, x) = setindex!(B, convert(Bool,x), 1)

Expand Down
2 changes: 2 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,8 @@ export
union!,
union,
unique,
unsafe_getindex,
unsafe_setindex!,
values,
∈,
∉,
Expand Down
10 changes: 0 additions & 10 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@
nothing
end

unsafe_getindex(v::Real, ind::Int) = v
unsafe_getindex(v::Range, ind::Int) = first(v) + (ind-1)*step(v)
unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind)
unsafe_getindex(v::AbstractArray, ind::Int) = v[ind]
unsafe_getindex(v, ind::Real) = unsafe_getindex(v, to_index(ind))

unsafe_setindex!{T}(v::AbstractArray{T}, x::T, ind::Int) = (v[ind] = x; v)
unsafe_setindex!(v::BitArray, x::Bool, ind::Int) = (Base.unsafe_bitsetindex!(v.chunks, x, ind); v)
unsafe_setindex!{T}(v::AbstractArray{T}, x::T, ind::Real) = unsafe_setindex!(v, x, to_index(ind))

# Version that uses cartesian indexing for src
@ngenerate N typeof(dest) function _getindex!(dest::Array, src::AbstractArray, I::NTuple{N,Union(Int,AbstractVector)}...)
checksize(dest, I...)
Expand Down
2 changes: 2 additions & 0 deletions base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ getindex(x::Number) = x
getindex(x::Number, i::Integer) = i == 1 ? x : throw(BoundsError())
getindex(x::Number, I::Integer...) = all([i == 1 for i in I]) ? x : throw(BoundsError())
getindex(x::Number, I::Real...) = getindex(x, to_index(i)...)
unsafe_getindex(x::Number, i::Real) = x
unsafe_getindex(x::Number, i::Real...) = x
first(x::Number) = x
last(x::Number) = x

Expand Down
16 changes: 8 additions & 8 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,15 @@ done(r::UnitRange, i) = i==oftype(i,r.stop)+1

## indexing

getindex(r::Range, i::Real) = getindex(r, to_index(i))
unsafe_getindex(r::Range, i::Real) = getindex(r, to_index(i))
unsafe_getindex{T}(r::Range{T}, i::Integer) =
oftype(T, first(r) + (i-1)*step(r))
unsafe_getindex{T}(r::FloatRange{T}, i::Integer) =
oftype(T, (r.start + (i-1)*r.step)/r.divisor)

function getindex{T}(r::Range{T}, i::Integer)
1 <= i <= length(r) || error(BoundsError)
oftype(T, first(r) + (i-1)*step(r))
end
function getindex{T}(r::FloatRange{T}, i::Integer)
1 <= i <= length(r) || error(BoundsError)
oftype(T, (r.start + (i-1)*r.step)/r.divisor)
unsafe_getindex(r, i)
end

function getindex(r::UnitRange, s::UnitRange{Int})
Expand Down Expand Up @@ -509,7 +509,7 @@ function vcat{T}(r::Range{T})
a = Array(T,n)
i = 1
for x in r
@inbounds a[i] = x
@boundscheck false a[i] = x
i += 1
end
return a
Expand All @@ -523,7 +523,7 @@ function vcat{T}(rs::Range{T}...)
i = 1
for r in rs
for x in r
@inbounds a[i] = x
@boundscheck false a[i] = x
i += 1
end
end
Expand Down
73 changes: 45 additions & 28 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@
`(= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs))))))

(define (expand-update-operator op lhs rhs . declT)
(cond ((and (pair? lhs) (eq? (car lhs) 'ref))
(cond ((and (pair? lhs) (or (eq? (car lhs) 'ref) (eq? (car lhs) 'custom_ref)))
;; expand indexing inside op= first, to remove "end" and ":"
(let* ((ex (partially-expand-ref lhs))
(stmts (butlast (cdr ex)))
(refex (last (cdr ex)))
(nuref `(ref ,(caddr refex) ,@(cdddr refex))))
(nuref `(,@(if (eq? (car lhs) 'custom_ref) `(custom_ref ,(cadr lhs) ,(caddr lhs)) `(ref))
,(caddr refex) ,@(cdddr refex))))
`(block ,@stmts
,(expand-update-operator- op nuref rhs declT))))
((and (pair? lhs) (eq? (car lhs) '|::|))
Expand All @@ -202,8 +203,8 @@
(define (dotop? o) (and (symbol? o) (eqv? (string.char (string o) 0) #\.)))

(define (partially-expand-ref e)
(let ((a (cadr e))
(idxs (cddr e)))
(let ((a (if (eq? (car e) 'custom_ref) (cadddr e) (cadr e)))
(idxs (if (eq? (car e) 'custom_ref) (cddddr e) (cddr e))))
(let* ((reuse (and (pair? a)
(contains (lambda (x)
(or (eq? x 'end)
Expand All @@ -217,7 +218,7 @@
(new-idxs stuff) (process-indexes arr idxs)
`(block
,@(append stmts stuff)
(call getindex ,arr ,@new-idxs))))))
(call ,(if (eq? (car e) 'custom_ref) (cadr e) 'getindex) ,arr ,@new-idxs))))))

;; accumulate a series of comparisons, with the given "and" constructor,
;; exit criteria, and "take" function that consumes part of a list,
Expand Down Expand Up @@ -312,6 +313,11 @@
;; inside ref only replace within the first argument
(list* 'ref (replace-end (cadr ex) a n tuples last)
(cddr ex)))
((eq? (car ex) 'custom_ref)
;; inside custom_ref only replace within the third argument
(list* 'custom_ref (cadr ex) (caddr ex)
(replace-end (cadddr ex) a n tuples last)
(cddddr ex)))
(else
(cons (car ex)
(map (lambda (x) (replace-end x a n tuples last))
Expand Down Expand Up @@ -1534,6 +1540,28 @@
e
((get expand-table (car e) map-expand-forms) e)))

(define (expand-setindex a idxs rhs setindexfn)
(let* ((reuse (and (pair? a)
(contains (lambda (x)
(or (eq? x 'end)
(and (pair? x)
(eq? (car x) ':))))
idxs)))
(arr (if reuse (gensy) a))
(stmts (if reuse `((= ,arr ,(expand-forms a))) '())))
(let* ((rrhs (and (pair? rhs) (not (quoted? rhs))))
(r (if rrhs (gensy) rhs))
(rini (if rrhs `((= ,r ,(expand-forms rhs))) '())))
(receive
(new-idxs stuff) (process-indexes arr idxs)
`(block
,@stmts
,.(map expand-forms stuff)
,@rini
,(expand-forms
`(call ,setindexfn ,arr ,r ,@new-idxs))
,r)))))

(define expand-table
(table
'quote identity
Expand Down Expand Up @@ -1611,29 +1639,14 @@

((ref)
;; (= (ref a . idxs) rhs)
(let ((a (cadr (cadr e)))
(idxs (cddr (cadr e)))
(rhs (caddr e)))
(let* ((reuse (and (pair? a)
(contains (lambda (x)
(or (eq? x 'end)
(and (pair? x)
(eq? (car x) ':))))
idxs)))
(arr (if reuse (gensy) a))
(stmts (if reuse `((= ,arr ,(expand-forms a))) '())))
(let* ((rrhs (and (pair? rhs) (not (quoted? rhs))))
(r (if rrhs (gensy) rhs))
(rini (if rrhs `((= ,r ,(expand-forms rhs))) '())))
(receive
(new-idxs stuff) (process-indexes arr idxs)
`(block
,@stmts
,.(map expand-forms stuff)
,@rini
,(expand-forms
`(call setindex! ,arr ,r ,@new-idxs))
,r))))))
(expand-setindex (cadr (cadr e)) (cddr (cadr e))
(caddr e) 'setindex!)
)

((custom_ref)
(expand-setindex (cadddr (cadr e)) (cddddr (cadr e))
(caddr e) (caddr (cadr e)))
)

((|::|)
;; (= (|::| x T) rhs)
Expand Down Expand Up @@ -1676,6 +1689,10 @@
(lambda (e)
(expand-forms (partially-expand-ref e)))

'custom_ref
(lambda (e)
(expand-forms (partially-expand-ref e)))

'curly
(lambda (e)
(expand-forms
Expand Down