From 7cb11d501550377dd26092f2c09f83b2be5c9f20 Mon Sep 17 00:00:00 2001 From: Simon Kornblith Date: Wed, 3 Sep 2014 23:34:40 -0400 Subject: [PATCH] =?UTF-8?q?Safer,=20extensible=20=EF=B9=ABinbounds?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- base/abstractarray.jl | 4 +++ base/array.jl | 69 ++++++++++++++++++++----------------- base/base.jl | 21 ++++++++++-- base/bitarray.jl | 4 ++- base/exports.jl | 2 ++ base/multidimensional.jl | 10 ------ base/number.jl | 2 ++ base/range.jl | 16 ++++----- src/julia-syntax.scm | 73 +++++++++++++++++++++++++--------------- 9 files changed, 122 insertions(+), 79 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index cdfa0abe14e6c..7029f7ba875ed 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -375,6 +375,8 @@ imag{T<:Real}(x::AbstractArray{T}) = zero(x) getindex(t::AbstractArray, i::Real) = error("indexing not defined for ", typeof(t)) +unsafe_getindex(args...) = getindex(args...) + # linear indexing with a single multi-dimensional index function getindex(A::AbstractArray, I::AbstractArray) x = similar(A, size(I)) @@ -441,6 +443,8 @@ setindex!(t::AbstractArray, x, i::Real) = error("setindex! not defined for ",typeof(t)) setindex!(t::AbstractArray, x) = throw(MethodError(setindex!, (t, x))) +unsafe_setindex!(args...) = setindex!(args...) + ## Indexing: handle more indices than dimensions if "extra" indices are 1 # Don't require vector/matrix subclasses to implement more than 1/2 indices, diff --git a/base/array.jl b/base/array.jl index fcaa3009a2c9e..6df8691b28fcc 100644 --- a/base/array.jl +++ b/base/array.jl @@ -241,21 +241,25 @@ collect(itr) = collect(eltype(itr), itr) ## Indexing: getindex ## -getindex(a::Array) = arrayref(a,1) - -getindex(A::Array, i0::Real) = arrayref(A,to_index(i0)) -getindex(A::Array, i0::Real, i1::Real) = arrayref(A,to_index(i0),to_index(i1)) -getindex(A::Array, i0::Real, i1::Real, i2::Real) = - arrayref(A,to_index(i0),to_index(i1),to_index(i2)) -getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real) = - arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3)) -getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) = - arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4)) -getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) = - arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5)) - -getindex(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) = - arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5),to_index(I)...) +for (getindexfn, transform) in ((:getindex, x->x), (:unsafe_getindex, x->:(@boundscheck false return $x))) + @eval begin + $getindexfn(a::Array) = $(transform(:(arrayref(a,1)))) + + $getindexfn(A::Array, i0::Real) = $(transform(:(arrayref(A,to_index(i0))))) + $getindexfn(A::Array, i0::Real, i1::Real) = $(transform(:(arrayref(A,to_index(i0),to_index(i1))))) + $getindexfn(A::Array, i0::Real, i1::Real, i2::Real) = + $(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2))))) + $getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real) = + $(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3))))) + $getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) = + $(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4))))) + $getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) = + $(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5))))) + + $getindexfn(A::Array, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) = + $(transform(:(arrayref(A,to_index(i0),to_index(i1),to_index(i2),to_index(i3),to_index(i4),to_index(i5),to_index(I)...)))) + end +end # Fast copy using copy! for UnitRange function getindex(A::Array, I::UnitRange{Int}) @@ -302,21 +306,26 @@ getindex(A::Array, I::AbstractArray{Bool}) = getindex_bool_1d(A, I) ## Indexing: setindex! ## -setindex!{T}(A::Array{T}, x) = arrayset(A, convert(T,x), 1) - -setindex!{T}(A::Array{T}, x, i0::Real) = arrayset(A, convert(T,x), to_index(i0)) -setindex!{T}(A::Array{T}, x, i0::Real, i1::Real) = - arrayset(A, convert(T,x), to_index(i0), to_index(i1)) -setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real) = - arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2)) -setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real) = - arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3)) -setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) = - arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4)) -setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) = - arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5)) -setindex!{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) = - arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5), to_index(I)...) + +for (setindexfn, transform) in ((:setindex!, x->x), (:unsafe_setindex!, x->:(@boundscheck false return $x))) + @eval begin + $setindexfn{T}(A::Array{T}, x) = $(transform(:(arrayset(A, convert(T,x), 1)))) + + $setindexfn{T}(A::Array{T}, x, i0::Real) = $(transform(:(arrayset(A, convert(T,x), to_index(i0))))) + $setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real) = + $(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1))))) + $setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real) = + $(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2))))) + $setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real) = + $(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3))))) + $setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real) = + $(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4))))) + $setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real) = + $(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5))))) + $setindexfn{T}(A::Array{T}, x, i0::Real, i1::Real, i2::Real, i3::Real, i4::Real, i5::Real, I::Real...) = + $(transform(:(arrayset(A, convert(T,x), to_index(i0), to_index(i1), to_index(i2), to_index(i3), to_index(i4), to_index(i5), to_index(I)...)))) + end +end function setindex!{T<:Real}(A::Array, x, I::AbstractVector{T}) for i in I diff --git a/base/base.jl b/base/base.jl index 05b8c2a9e20ed..16345d75bc6e8 100644 --- a/base/base.jl +++ b/base/base.jl @@ -206,8 +206,25 @@ macro boundscheck(yesno,blk) $(Expr(:boundscheck,:pop))) end -macro inbounds(blk) - :(@boundscheck false $(esc(blk))) +function rewrite_ref(getindexfn, setindexfn, ast::Expr) + if ast.head === :ref + ast = Expr(:custom_ref, getindexfn, setindexfn, ast.args...) + end + + args = ast.args + for i = 1:arraylen(args) + arg = arrayref(args, i) + if isa(arg, Expr) + arrayset(args, rewrite_ref(getindexfn, setindexfn, arg), i) + end + end + + return ast +end +rewrite_ref(getindexfn, setindexfn, x) = x + +macro inbounds(ex) + esc(rewrite_ref(:unsafe_getindex, :unsafe_setindex!, ex)) end macro label(name::Symbol) diff --git a/base/bitarray.jl b/base/bitarray.jl index 0c5dfbf6e6198..de82077a07779 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -342,8 +342,9 @@ end ## Indexing: getindex ## function unsafe_bitgetindex(Bc::Vector{Uint64}, i::Int) - return (Bc[@_div64(i-1)+1] & (uint64(1)<<@_mod64(i-1))) != 0 + return @inbounds (Bc[@_div64(i-1)+1] & (uint64(1)<<@_mod64(i-1))) != 0 end +unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind) function getindex(B::BitArray, i::Int) 1 <= i <= length(B) || throw(BoundsError()) @@ -408,6 +409,7 @@ function unsafe_bitsetindex!(Bc::Array{Uint64}, x::Bool, i::Int) end end end +unsafe_setindex!(v::BitArray, x::Bool, ind::Int) = (Base.unsafe_bitsetindex!(v.chunks, x, ind); v) setindex!(B::BitArray, x) = setindex!(B, convert(Bool,x), 1) diff --git a/base/exports.jl b/base/exports.jl index e70f03a2b3d0c..1daaaf55346d8 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -776,6 +776,8 @@ export union!, union, unique, + unsafe_getindex, + unsafe_setindex!, values, ∈, ∉, diff --git a/base/multidimensional.jl b/base/multidimensional.jl index c25e3984a31fe..d99b56899ed9e 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -5,16 +5,6 @@ nothing end -unsafe_getindex(v::Real, ind::Int) = v -unsafe_getindex(v::Range, ind::Int) = first(v) + (ind-1)*step(v) -unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind) -unsafe_getindex(v::AbstractArray, ind::Int) = v[ind] -unsafe_getindex(v, ind::Real) = unsafe_getindex(v, to_index(ind)) - -unsafe_setindex!{T}(v::AbstractArray{T}, x::T, ind::Int) = (v[ind] = x; v) -unsafe_setindex!(v::BitArray, x::Bool, ind::Int) = (Base.unsafe_bitsetindex!(v.chunks, x, ind); v) -unsafe_setindex!{T}(v::AbstractArray{T}, x::T, ind::Real) = unsafe_setindex!(v, x, to_index(ind)) - # Version that uses cartesian indexing for src @ngenerate N typeof(dest) function _getindex!(dest::Array, src::AbstractArray, I::NTuple{N,Union(Int,AbstractVector)}...) checksize(dest, I...) diff --git a/base/number.jl b/base/number.jl index 16de2c217f376..32674b3cdee99 100644 --- a/base/number.jl +++ b/base/number.jl @@ -15,6 +15,8 @@ getindex(x::Number) = x getindex(x::Number, i::Integer) = i == 1 ? x : throw(BoundsError()) getindex(x::Number, I::Integer...) = all([i == 1 for i in I]) ? x : throw(BoundsError()) getindex(x::Number, I::Real...) = getindex(x, to_index(i)...) +unsafe_getindex(x::Number, i::Real) = x +unsafe_getindex(x::Number, i::Real...) = x first(x::Number) = x last(x::Number) = x diff --git a/base/range.jl b/base/range.jl index fdfd451b45a14..9dc52b068731e 100644 --- a/base/range.jl +++ b/base/range.jl @@ -252,15 +252,15 @@ done(r::UnitRange, i) = i==oftype(i,r.stop)+1 ## indexing -getindex(r::Range, i::Real) = getindex(r, to_index(i)) +unsafe_getindex(r::Range, i::Real) = getindex(r, to_index(i)) +unsafe_getindex{T}(r::Range{T}, i::Integer) = + oftype(T, first(r) + (i-1)*step(r)) +unsafe_getindex{T}(r::FloatRange{T}, i::Integer) = + oftype(T, (r.start + (i-1)*r.step)/r.divisor) function getindex{T}(r::Range{T}, i::Integer) 1 <= i <= length(r) || error(BoundsError) - oftype(T, first(r) + (i-1)*step(r)) -end -function getindex{T}(r::FloatRange{T}, i::Integer) - 1 <= i <= length(r) || error(BoundsError) - oftype(T, (r.start + (i-1)*r.step)/r.divisor) + unsafe_getindex(r, i) end function getindex(r::UnitRange, s::UnitRange{Int}) @@ -509,7 +509,7 @@ function vcat{T}(r::Range{T}) a = Array(T,n) i = 1 for x in r - @inbounds a[i] = x + @boundscheck false a[i] = x i += 1 end return a @@ -523,7 +523,7 @@ function vcat{T}(rs::Range{T}...) i = 1 for r in rs for x in r - @inbounds a[i] = x + @boundscheck false a[i] = x i += 1 end end diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index ac8062d4ad791..be5eb3083c210 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -182,12 +182,13 @@ `(= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs)))))) (define (expand-update-operator op lhs rhs . declT) - (cond ((and (pair? lhs) (eq? (car lhs) 'ref)) + (cond ((and (pair? lhs) (or (eq? (car lhs) 'ref) (eq? (car lhs) 'custom_ref))) ;; expand indexing inside op= first, to remove "end" and ":" (let* ((ex (partially-expand-ref lhs)) (stmts (butlast (cdr ex))) (refex (last (cdr ex))) - (nuref `(ref ,(caddr refex) ,@(cdddr refex)))) + (nuref `(,@(if (eq? (car lhs) 'custom_ref) `(custom_ref ,(cadr lhs) ,(caddr lhs)) `(ref)) + ,(caddr refex) ,@(cdddr refex)))) `(block ,@stmts ,(expand-update-operator- op nuref rhs declT)))) ((and (pair? lhs) (eq? (car lhs) '|::|)) @@ -202,8 +203,8 @@ (define (dotop? o) (and (symbol? o) (eqv? (string.char (string o) 0) #\.))) (define (partially-expand-ref e) - (let ((a (cadr e)) - (idxs (cddr e))) + (let ((a (if (eq? (car e) 'custom_ref) (cadddr e) (cadr e))) + (idxs (if (eq? (car e) 'custom_ref) (cddddr e) (cddr e)))) (let* ((reuse (and (pair? a) (contains (lambda (x) (or (eq? x 'end) @@ -217,7 +218,7 @@ (new-idxs stuff) (process-indexes arr idxs) `(block ,@(append stmts stuff) - (call getindex ,arr ,@new-idxs)))))) + (call ,(if (eq? (car e) 'custom_ref) (cadr e) 'getindex) ,arr ,@new-idxs)))))) ;; accumulate a series of comparisons, with the given "and" constructor, ;; exit criteria, and "take" function that consumes part of a list, @@ -312,6 +313,11 @@ ;; inside ref only replace within the first argument (list* 'ref (replace-end (cadr ex) a n tuples last) (cddr ex))) + ((eq? (car ex) 'custom_ref) + ;; inside custom_ref only replace within the third argument + (list* 'custom_ref (cadr ex) (caddr ex) + (replace-end (cadddr ex) a n tuples last) + (cddddr ex))) (else (cons (car ex) (map (lambda (x) (replace-end x a n tuples last)) @@ -1534,6 +1540,28 @@ e ((get expand-table (car e) map-expand-forms) e))) +(define (expand-setindex a idxs rhs setindexfn) + (let* ((reuse (and (pair? a) + (contains (lambda (x) + (or (eq? x 'end) + (and (pair? x) + (eq? (car x) ':)))) + idxs))) + (arr (if reuse (gensy) a)) + (stmts (if reuse `((= ,arr ,(expand-forms a))) '()))) + (let* ((rrhs (and (pair? rhs) (not (quoted? rhs)))) + (r (if rrhs (gensy) rhs)) + (rini (if rrhs `((= ,r ,(expand-forms rhs))) '()))) + (receive + (new-idxs stuff) (process-indexes arr idxs) + `(block + ,@stmts + ,.(map expand-forms stuff) + ,@rini + ,(expand-forms + `(call ,setindexfn ,arr ,r ,@new-idxs)) + ,r))))) + (define expand-table (table 'quote identity @@ -1611,29 +1639,14 @@ ((ref) ;; (= (ref a . idxs) rhs) - (let ((a (cadr (cadr e))) - (idxs (cddr (cadr e))) - (rhs (caddr e))) - (let* ((reuse (and (pair? a) - (contains (lambda (x) - (or (eq? x 'end) - (and (pair? x) - (eq? (car x) ':)))) - idxs))) - (arr (if reuse (gensy) a)) - (stmts (if reuse `((= ,arr ,(expand-forms a))) '()))) - (let* ((rrhs (and (pair? rhs) (not (quoted? rhs)))) - (r (if rrhs (gensy) rhs)) - (rini (if rrhs `((= ,r ,(expand-forms rhs))) '()))) - (receive - (new-idxs stuff) (process-indexes arr idxs) - `(block - ,@stmts - ,.(map expand-forms stuff) - ,@rini - ,(expand-forms - `(call setindex! ,arr ,r ,@new-idxs)) - ,r)))))) + (expand-setindex (cadr (cadr e)) (cddr (cadr e)) + (caddr e) 'setindex!) + ) + + ((custom_ref) + (expand-setindex (cadddr (cadr e)) (cddddr (cadr e)) + (caddr e) (caddr (cadr e))) + ) ((|::|) ;; (= (|::| x T) rhs) @@ -1676,6 +1689,10 @@ (lambda (e) (expand-forms (partially-expand-ref e))) + 'custom_ref + (lambda (e) + (expand-forms (partially-expand-ref e))) + 'curly (lambda (e) (expand-forms