From 77e2ea2874ef26e791f55182c14475471f848dd7 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 5 May 2020 16:47:00 -0700 Subject: [PATCH] [RPC][BUGFIX][BACKPORT-0.6] Fix bug in rpc ring buffer shrink (#5516) --- src/support/ring_buffer.h | 30 ++++++++++++++++------- tests/python/unittest/test_runtime_rpc.py | 14 +++++++++++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index e6e3b04ec7a9..7a1bcb63f7fa 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -49,8 +49,12 @@ class RingBuffer { return ring_.size(); } /*! - * Reserve capacity to be at least n. - * Will only increase capacity if n is bigger than current capacity. + * Reserve capacity to be at least n. + * Will only increase capacity if n is bigger than current capacity. + * + * The effect of Reserve only lasts before the next call to Reserve. + * Other functions in the ring buffer can also call into the reserve. + * * \param n The size of capacity. */ void Reserve(size_t n) { @@ -63,19 +67,27 @@ class RingBuffer { size_t ncopy = head_ptr_ + bytes_available_ - old_size; memcpy(&ring_[0] + old_size, &ring_[0], ncopy); } - } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) { - // shrink too large temporary buffer to avoid out of memory on some embedded devices + } else if (ring_.size() > n * 8 && + ring_.size() > kInitCapacity) { + // shrink too large temporary buffer to + // avoid out of memory on some embedded devices + if (bytes_available_ != 0) { + // move existing bytes to the head. size_t old_bytes = bytes_available_; - std::vector tmp(old_bytes); - Read(&tmp[0], old_bytes); - ring_.resize(kInitCapacity); - ring_.shrink_to_fit(); memcpy(&ring_[0], &tmp[0], old_bytes); - head_ptr_ = 0; bytes_available_ = old_bytes; + } + // shrink the ring. + size_t new_size = kInitCapacity; + new_size = std::max(new_size, n); + new_size = std::max(new_size, bytes_available_); + + ring_.resize(new_size); + ring_.shrink_to_fit(); + head_ptr_ = 0; } } diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 091e9427a25a..4e7921b0ae7a 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -102,6 +102,19 @@ def remote_array_func(y): fremote(r_cpu) +def test_rpc_large_array(): + # testcase of large array creation + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + ctx = remote.cpu(0) + a_np = np.ones((5041, 720)).astype('float32') + b_np = np.ones((720, 192)).astype('float32') + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(b_np, ctx) + np.testing.assert_equal(a.asnumpy(), a_np) + np.testing.assert_equal(b.asnumpy(), b_np) + + def test_rpc_echo(): def check(remote): fecho = remote.get_function("testing.echo") @@ -447,3 +460,4 @@ def target(host, port, device_key, timeout): test_local_func() test_rpc_tracker_register() test_rpc_tracker_request() + test_rpc_large_array()