From f900010de53812b69239d0b9dfae5cd47fe97563 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 7 Jan 2020 13:38:11 +0100 Subject: [PATCH] Implement array resizing. Fix #547 --- src/array.jl | 33 +++++++++++++++++++++++++++++++++ test/base.jl | 23 +++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/src/array.jl b/src/array.jl index 6fc8d624..e6b7a00b 100644 --- a/src/array.jl +++ b/src/array.jl @@ -462,3 +462,36 @@ function Base.reverse(input::CuVector{T}, start=1, stop=length(input)) where {T} return output end + + +## resizing + +""" + resize!(a::CuVector, n::Int) + +Resize `a` to contain `n` elements. If `n` is smaller than the current collection length, +the first `n` elements will be retained. If `n` is larger, the new elements are not +guaranteed to be initialized. + +Several restrictions apply to which types of `CuArray`s can be resized: + +- the array should be backed by the memory pool, and not have been constructed with `unsafe_wrap` +- the array cannot be derived (view, reshape) from another array +- the array cannot have any derived arrays itself + +""" +function Base.resize!(A::CuVector{T}, n::Int) where T + A.parent === nothing || error("cannot resize derived CuArray") + A.refcount == 1 || error("cannot resize shared CuArray") + A.pooled || error("cannot resize wrapped CuArray") + + ptr = convert(CuPtr{T}, alloc(n * sizeof(T))) + m = min(length(A), n) + unsafe_copyto!(ptr, pointer(A), m) + + free(convert(CuPtr{Nothing}, pointer(A))) + A.dims = (n,) + A.ptr = ptr + + A +end diff --git a/test/base.jl b/test/base.jl index 72702767..4792d4f0 100644 --- a/test/base.jl +++ b/test/base.jl @@ -449,3 +449,26 @@ end y = exp.(x) @test y isa CuArray{Complex{Float32}} end + +@testset "resizing" begin + a = CuArray([1,2,3]) + + resize!(a, 3) + @test length(a) == 3 + @test Array(a) == [1,2,3] + + resize!(a, 5) + @test length(a) == 5 + @test Array(a)[1:3] == [1,2,3] + + resize!(a, 2) + @test length(a) == 2 + @test Array(a)[1:2] == [1,2] + + b = view(a, 1:2) + @test_throws ErrorException resize!(a, 2) + @test_throws ErrorException resize!(b, 2) + + c = unsafe_wrap(CuArray{Int}, pointer(b), 2) + @test_throws ErrorException resize!(c, 2) +end