Skip to content

Commit

Permalink
solve int overflow and alloc,free (PaddlePaddle#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
Thunderbrook authored and zmxdream committed Jun 23, 2022
1 parent d4dd749 commit 62e7b47
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class HeterComm {

void init_path();

void create_storage(int start_index, int end_index, int keylen, int vallen);
void create_storage(int start_index, int end_index, size_t keylen, size_t vallen);
void destroy_storage(int start_index, int end_index);
void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
KeyType* src_key, GradType* src_val);
Expand Down
20 changes: 10 additions & 10 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,20 +168,20 @@ void HeterComm<KeyType, ValType, GradType>::init_path() {
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::create_storage(int start_index,
int end_index,
int keylen,
int vallen) {
size_t keylen,
size_t vallen) {
auto& allocator = allocators_[start_index];
auto& nodes = path_[start_index][end_index].nodes_;
for (size_t i = 0; i < nodes.size(); ++i) {
platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num));
allocator->DeviceAllocate(
PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceAllocate(
resource_->dev_id(nodes[i].gpu_num),
(void**)&(nodes[i].key_storage), // NOLINT
keylen, resource_->remote_stream(nodes[i].gpu_num, start_index));
allocator->DeviceAllocate(
keylen, resource_->remote_stream(nodes[i].gpu_num, start_index)));
PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceAllocate(
resource_->dev_id(nodes[i].gpu_num),
(void**)&(nodes[i].val_storage), // NOLINT
vallen, resource_->remote_stream(nodes[i].gpu_num, start_index));
vallen, resource_->remote_stream(nodes[i].gpu_num, start_index)));

nodes[i].key_bytes_len = keylen;
nodes[i].val_bytes_len = vallen;
Expand All @@ -196,10 +196,10 @@ void HeterComm<KeyType, ValType, GradType>::destroy_storage(int start_index,
for (size_t i = 0; i < nodes.size(); ++i) {
platform::CUDADeviceGuard guard(resource_->dev_id(nodes[i].gpu_num));

allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num),
nodes[i].key_storage);
allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num),
nodes[i].val_storage);
PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num),
nodes[i].key_storage));
PADDLE_ENFORCE_GPU_SUCCESS(allocator->DeviceFree(resource_->dev_id(nodes[i].gpu_num),
nodes[i].val_storage));
}
}

Expand Down

0 comments on commit 62e7b47

Please sign in to comment.