Skip to content

Commit

Permalink
[RPC][BUGFIX][BACKPORT-0.6] Fix bug in rpc ring buffer shrink
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 5, 2020
1 parent 70a5902 commit dfba7e1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
31 changes: 23 additions & 8 deletions src/support/ring_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,12 @@ class RingBuffer {
return ring_.size();
}
/*!
* Reserve capacity to be at least n.
* Will only increase capacity if n is bigger than current capacity.
* Reserve capacity to be at least n.
* Will only increase capacity if n is bigger than current capacity.
*
* The effect of Reserve only lasts before the next call to Reserve.
* Other functions in the ring buffer can also call into the reserve.
*
* \param n The size of capacity.
*/
void Reserve(size_t n) {
Expand All @@ -63,17 +67,28 @@ class RingBuffer {
size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
}
} else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) {
// shrink too large temporary buffer to avoid out of memory on some embedded devices
} else if (ring_.size() > n * 8 &&
ring_.size() > kInitCapacity) {
// shrink too large temporary buffer to
// avoid out of memory on some embedded devices
size_t old_bytes = bytes_available_;

std::vector<char> tmp(old_bytes);

Read(&tmp[0], old_bytes);
ring_.resize(kInitCapacity);
if (old_bytes != 0) {
Read(&tmp[0], old_bytes);
}

size_t new_size = kInitCapacity;
new_size = std::max(new_size, bytes_available_);
new_size = std::max(new_size, n);

ring_.resize(new_size);
ring_.shrink_to_fit();

memcpy(&ring_[0], &tmp[0], old_bytes);
if (old_bytes != 0) {
memcpy(&ring_[0], &tmp[0], old_bytes);
}

head_ptr_ = 0;
bytes_available_ = old_bytes;
}
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_runtime_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ def remote_array_func(y):
fremote(r_cpu)


def test_rpc_large_array_shrink():
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
ctx = remote.cpu(0)
a = tvm.nd.array(np.ones((5041, 720)).astype('float32'), ctx)
b = tvm.nd.array(np.ones((720, 192)).astype('float32'), ctx)


def test_rpc_echo():
def check(remote):
fecho = remote.get_function("testing.echo")
Expand Down Expand Up @@ -447,3 +455,4 @@ def target(host, port, device_key, timeout):
test_local_func()
test_rpc_tracker_register()
test_rpc_tracker_request()
test_rpc_large_array_shrink()

0 comments on commit dfba7e1

Please sign in to comment.