Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Merge pull request #557 from JuliaGPU/tb/resize
Browse files Browse the repository at this point in the history
Implement array resizing.
  • Loading branch information
maleadt authored Jan 7, 2020
2 parents 398e563 + f900010 commit 3259594
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,36 @@ function Base.reverse(input::CuVector{T}, start=1, stop=length(input)) where {T}

return output
end


## resizing

"""
resize!(a::CuVector, n::Int)
Resize `a` to contain `n` elements. If `n` is smaller than the current collection length,
the first `n` elements will be retained. If `n` is larger, the new elements are not
guaranteed to be initialized.
Several restrictions apply to which types of `CuArray`s can be resized:
- the array should be backed by the memory pool, and not have been constructed with `unsafe_wrap`
- the array cannot be derived (view, reshape) from another array
- the array cannot have any derived arrays itself
"""
function Base.resize!(A::CuVector{T}, n::Int) where T
A.parent === nothing || error("cannot resize derived CuArray")
A.refcount == 1 || error("cannot resize shared CuArray")
A.pooled || error("cannot resize wrapped CuArray")

ptr = convert(CuPtr{T}, alloc(n * sizeof(T)))
m = min(length(A), n)
unsafe_copyto!(ptr, pointer(A), m)

free(convert(CuPtr{Nothing}, pointer(A)))
A.dims = (n,)
A.ptr = ptr

A
end
23 changes: 23 additions & 0 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,26 @@ end
y = exp.(x)
@test y isa CuArray{Complex{Float32}}
end

@testset "resizing" begin
a = CuArray([1,2,3])

resize!(a, 3)
@test length(a) == 3
@test Array(a) == [1,2,3]

resize!(a, 5)
@test length(a) == 5
@test Array(a)[1:3] == [1,2,3]

resize!(a, 2)
@test length(a) == 2
@test Array(a)[1:2] == [1,2]

b = view(a, 1:2)
@test_throws ErrorException resize!(a, 2)
@test_throws ErrorException resize!(b, 2)

c = unsafe_wrap(CuArray{Int}, pointer(b), 2)
@test_throws ErrorException resize!(c, 2)
end

0 comments on commit 3259594

Please sign in to comment.