From 6d8e7a534573bb4e8c16e9cd1445245576c74c62 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 8 Jun 2022 03:37:41 +0800 Subject: [PATCH 1/5] search and fill slot_feature --- .../ps/table/common_graph_table.cc | 75 +++++++- .../distributed/ps/table/common_graph_table.h | 16 +- .../distributed/ps/table/graph/graph_node.h | 22 +++ paddle/fluid/framework/data_feed.cu | 167 ++++++++++++++++++ paddle/fluid/framework/data_feed.h | 5 + paddle/fluid/framework/data_set.cc | 14 +- .../fleet/heter_ps/graph_gpu_ps_table.h | 16 +- .../fleet/heter_ps/graph_gpu_ps_table_inl.cu | 165 ++++++++++++++++- .../fleet/heter_ps/graph_gpu_wrapper.cu | 34 +++- .../fleet/heter_ps/graph_gpu_wrapper.h | 8 +- paddle/fluid/platform/flags.cc | 19 ++ paddle/fluid/pybind/fleet_py.cc | 5 +- 12 files changed, 517 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index 06f86ba705135..7c87e8a066a08 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -45,7 +45,7 @@ int32_t GraphTable::Load_to_ssd(const std::string &path, } paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( - int ntype_id, std::vector &node_ids, int slot_num) { + std::vector &node_ids, int slot_num) { std::vector> bags(task_pool_size_); for (auto x : node_ids) { int location = x % shard_num % task_pool_size_; @@ -63,7 +63,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( std::vector feature_ids; for (size_t j = 0; j < bags[i].size(); j++) { // TODO use FEATURE_TABLE instead - Node *v = find_node(1, ntype_id, bags[i][j]); + Node *v = find_node(1, bags[i][j]); x.node_id = bags[i][j]; if (v == NULL) { x.feature_size = 0; @@ -85,10 +85,6 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( } x.feature_size = total_feature_size; node_fea_array[i].push_back(x); - VLOG(2) << "node_fea_array[i].size() = [" - << node_fea_array[i].size() << "]"; - VLOG(2) << "feature_array[i].size() = [" << feature_array[i].size() - << "]"; } } return 0; @@ -102,8 +98,7 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea( tot_len += feature_array[i].size(); } VLOG(0) << "Loaded feature table on cpu, feature_list_size[" << tot_len - << "] node_ids_size[" << node_ids.size() << "] ntype_id[" << ntype_id - << "]"; + << "] node_ids_size[" << node_ids.size() << "]"; res.init_on_cpu(tot_len, (unsigned int)node_ids.size(), slot_num); unsigned int offset = 0, ind = 0; for (int i = 0; i < task_pool_size_; i++) { @@ -1240,6 +1235,24 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge, return 0; } + +Node *GraphTable::find_node(int type_id, uint64_t id) { + size_t shard_id = id % shard_num; + if (shard_id >= shard_end || shard_id < shard_start) { + return nullptr; + } + Node *node = nullptr; + size_t index = shard_id - shard_start; + auto &search_shards = type_id == 0 ? 
edge_shards : feature_shards; + for (auto& search_shard: search_shards) { + PADDLE_ENFORCE_NOT_NULL(search_shard[index]); + node = search_shard[index]->find_node(id); + if (node != nullptr) { + break; + } + } + return node; +} Node *GraphTable::find_node(int type_id, int idx, uint64_t id) { size_t shard_id = id % shard_num; @@ -1537,6 +1550,30 @@ std::pair GraphTable::parse_feature( return std::make_pair(-1, ""); } +std::vector> GraphTable::get_all_id(int type_id, int slice_num) { + std::vector> res(slice_num); + auto &search_shards = type_id == 0 ? edge_shards : feature_shards; + std::vector>> tasks; + for (int idx = 0; idx < search_shards.size(); idx++) { + for (int j = 0; j < search_shards[idx].size(); j++) { + tasks.push_back(_shards_task_pool[j % task_pool_size_]->enqueue( + [&search_shards, idx, j]() -> std::vector { + return search_shards[idx][j]->get_all_id(); + })); + } + } + for (size_t i = 0; i < tasks.size(); ++i) { + tasks[i].wait(); + } + for (size_t i = 0; i < tasks.size(); i++) { + auto ids = tasks[i].get(); + for (auto &id : ids) { + res[(uint64_t)(id) % slice_num].push_back(id); + } + } + return res; +} + std::vector> GraphTable::get_all_id(int type_id, int idx, int slice_num) { std::vector> res(slice_num); @@ -1559,6 +1596,28 @@ std::vector> GraphTable::get_all_id(int type_id, int idx, } return res; } + +std::vector> GraphTable::get_all_feature_ids(int type_id, int idx, + int slice_num) { + std::vector> res(slice_num); + auto &search_shards = type_id == 0 ? edge_shards[idx] : feature_shards[idx]; + std::vector>> tasks; + for (int i = 0; i < search_shards.size(); i++) { + tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( + [&search_shards, i]() -> std::vector { + return search_shards[i]->get_all_feature_ids(); + })); + } + for (size_t i = 0; i < tasks.size(); ++i) { + tasks[i].wait(); + } + for (size_t i = 0; i < tasks.size(); i++) { + auto ids = tasks[i].get(); + for (auto &id : ids) res[id % slice_num].push_back(id); + } + return res; +} + int32_t GraphTable::pull_graph_list(int type_id, int idx, int start, int total_size, std::unique_ptr &buffer, diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index 4c3f6c824e0dc..2ec2ba93f4498 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -70,6 +70,16 @@ class GraphShard { } return res; } + std::vector get_all_feature_ids() { + // TODO by huwei02, dedup + std::vector total_res; + for (int i = 0; i < (int)bucket.size(); i++) { + std::vector res; + res.push_back(bucket[i]->get_feature_ids(&res)); + total_res.insert(total_res.end(), res.begin(), res.end()); + } + return total_res; + } GraphNode *add_graph_node(uint64_t id); GraphNode *add_graph_node(Node *node); FeatureNode *add_feature_node(uint64_t id); @@ -475,8 +485,11 @@ class GraphTable : public Table { int32_t load_edges(const std::string &path, bool reverse, const std::string &edge_type); + std::vector> get_all_id(int type, int slice_num); std::vector> get_all_id(int type, int idx, int slice_num); + std::vector> get_all_feature_ids(int type, int idx, + int slice_num); int32_t load_nodes(const std::string &path, std::string node_type); int32_t add_graph_node(int idx, std::vector &id_list, @@ -486,6 +499,7 @@ class GraphTable : public Table { int32_t get_server_index_by_id(uint64_t id); Node *find_node(int type_id, int idx, uint64_t id); + Node *find_node(int type_id, uint64_t id); virtual int32_t 
Pull(TableContext &context) { return 0; } virtual int32_t Push(TableContext &context) { return 0; } @@ -561,7 +575,7 @@ class GraphTable : public Table { virtual paddle::framework::GpuPsCommGraph make_gpu_ps_graph( int idx, std::vector ids); virtual paddle::framework::GpuPsCommGraphFea make_gpu_ps_graph_fea( - int ntype_id, std::vector &node_ids, int slot_num); + std::vector &node_ids, int slot_num); int32_t Load_to_ssd(const std::string &path, const std::string ¶m); int64_t load_graph_to_memory_from_ssd(int idx, std::vector &ids); int32_t make_complementary_graph(int idx, int64_t byte_size); diff --git a/paddle/fluid/distributed/ps/table/graph/graph_node.h b/paddle/fluid/distributed/ps/table/graph/graph_node.h index 7d9dee0294a6a..017cd28ed0705 100644 --- a/paddle/fluid/distributed/ps/table/graph/graph_node.h +++ b/paddle/fluid/distributed/ps/table/graph/graph_node.h @@ -50,6 +50,9 @@ class Node { virtual void to_buffer(char *buffer, bool need_feature); virtual void recover_from_buffer(char *buffer); virtual std::string get_feature(int idx) { return std::string(""); } + virtual int get_feature_ids(std::vector *res) const { + return 0; + } virtual int get_feature_ids(int slot_idx, std::vector *res) const { return 0; } @@ -102,6 +105,25 @@ class FeatureNode : public Node { } } + virtual int get_feature_ids(std::vector *res) const { + PADDLE_ENFORCE_NOT_NULL(res); + res->clear(); + errno = 0; + for (auto& feature_item: feature) { + const char *feat_str = feature_item.c_str(); + auto fields = paddle::string::split_string(feat_str, " "); + char *head_ptr = NULL; + for (auto &field : fields) { + PADDLE_ENFORCE_EQ(field.empty(), false); + uint64_t feasign = strtoull(field.c_str(), &head_ptr, 10); + PADDLE_ENFORCE_EQ(field.c_str() + field.length(), head_ptr); + res->push_back(feasign); + } + } + PADDLE_ENFORCE_EQ(errno, 0); + return 0; + } + virtual int get_feature_ids(int slot_idx, std::vector *res) const { PADDLE_ENFORCE_NOT_NULL(res); res->clear(); diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index b801e55f44252..df9835d0b227f 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -20,11 +20,14 @@ limitations under the License. 
*/ #include #include #include +#include #include "cub/cub.cuh" #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h" +DECLARE_int32(batch_num); + namespace paddle { namespace framework { @@ -169,6 +172,45 @@ int GraphDataGenerator::AcquireInstance(BufState *state) { return 0; } +// TODO opt +__global__ void GraphFillFeatureKernel(int64_t *id_tensor, int *fill_ins_num, + int64_t *walk, int64_t *feature, int *row, int central_word, + int step, int len, int col_num, int slot_num) { + __shared__ int64_t local_key[CUDA_NUM_THREADS * 2]; + __shared__ int local_num; + __shared__ int global_num; + + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.x == 0) { + local_num = 0; + } + __syncthreads(); + if (idx < len) { + int src = row[idx] * col_num + central_word; + if (walk[src] != 0 && walk[src + step] != 0) { + size_t dst = atomicAdd(&local_num, 1); + for (int i = 0; i < slot_num; ++i) { + local_key[dst * 2 * slot_num + i * 2] = feature[src * slot_num + i]; + local_key[dst * 2 * slot_num + i * 2 + 1] = feature[(src + step) * slot_num + i]; + } + } + } + + if (threadIdx.x == 0) { + global_num = atomicAdd(fill_ins_num, local_num); + } + __syncthreads(); + + if (threadIdx.x < local_num) { + for (int i = 0; i < slot_num; ++i) { + id_tensor[(global_num * 2 + 2 * threadIdx.x) * slot_num + i] + = local_key[(2 * threadIdx.x) * slot_num + i]; + id_tensor[(global_num * 2 + 2 * threadIdx.x + 1) * slot_num + i] = + local_key[(2 * threadIdx.x + 1) * slot_num + i]; + } + } +} + __global__ void GraphFillIdKernel(int64_t *id_tensor, int *fill_ins_num, int64_t *walk, int *row, int central_word, int step, int len, int col_num) { @@ -205,6 +247,12 @@ __global__ void GraphFillIdKernel(int64_t *id_tensor, int *fill_ins_num, } } +__global__ void GraphFillSlotLodKernel(int64_t *id_tensor, int len) { + CUDA_KERNEL_LOOP(idx, len) { + id_tensor[idx] = idx; + } +} + int GraphDataGenerator::FillInsBuf() { if (ins_buf_pair_len_ >= batch_size_) { return batch_size_; @@ -227,6 +275,27 @@ int GraphDataGenerator::FillInsBuf() { // return -1; //} } + + if (slot_num_ > 0) { + FillFeatureBuf(d_walk_, d_feature_); + if (debug_mode_) { + int len = buf_size_ > 5000? 5000: buf_size_; + uint64_t h_walk[len]; + cudaMemcpy(h_walk, d_walk_->ptr(), len * sizeof(uint64_t), + cudaMemcpyDeviceToHost); + uint64_t h_feature[len * slot_num_]; + cudaMemcpy(h_feature, d_feature_->ptr(), len * slot_num_ * sizeof(uint64_t), + cudaMemcpyDeviceToHost); + for(int i = 0; i < len; ++i) { + std::stringstream ss; + for (int j = 0; j < slot_num_; ++j) { + ss << h_feature[i * slot_num_ + j] << " "; + } + VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i << "] = " << (uint64_t)h_walk[i] + << " feature[" << i * slot_num_ << ".." 
<< (i + 1) * slot_num_ << "] = " << ss.str(); + } + } + } } int64_t *walk = reinterpret_cast(d_walk_->ptr()); @@ -242,6 +311,19 @@ int GraphDataGenerator::FillInsBuf() { int h_pair_num; cudaMemcpyAsync(&h_pair_num, d_pair_num, sizeof(int), cudaMemcpyDeviceToHost, stream_); + + int64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); + if (slot_num_ > 0) { + int64_t *feature = reinterpret_cast(d_feature_->ptr()); + cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream_); + int len = buf_state_.len; + VLOG(2) << "feature_buf start[" << ins_buf_pair_len_ * 2 * slot_num_ << "] len[" << len << "]"; + GraphFillFeatureKernel<<>>( + feature_buf + ins_buf_pair_len_ * 2 * slot_num_, d_pair_num, walk, feature, + random_row + buf_state_.cursor, buf_state_.central_word, + window_step_[buf_state_.step], len, walk_len_, slot_num_); + } + cudaStreamSynchronize(stream_); ins_buf_pair_len_ += h_pair_num; @@ -255,11 +337,27 @@ int GraphDataGenerator::FillInsBuf() { VLOG(2) << "h_ins_buf[" << xx << "]: " << h_ins_buf[xx]; } delete[] h_ins_buf; + + int64_t h_feature_buf[(batch_size_ * 2 * 2) * slot_num_]; + cudaMemcpy(h_feature_buf, feature_buf, (batch_size_ * 2 * 2) * slot_num_ * sizeof(int64_t), + cudaMemcpyDeviceToHost); + for (int xx = 0; xx < (batch_size_ * 2 * 2) * slot_num_; xx++) { + VLOG(2) << "h_feature_buf[" << xx << "]: " << h_feature_buf[xx]; + } } return ins_buf_pair_len_; } + +int times = 0; int GraphDataGenerator::GenerateBatch() { + times += 1; + VLOG(0) << "Begin batch " << times; + if (times > FLAGS_batch_num) { + VLOG(0) << "close batch"; + return 0; + } + platform::CUDADeviceGuard guard(gpuid_); int res = 0; while (ins_buf_pair_len_ < batch_size_) { @@ -282,6 +380,18 @@ int GraphDataGenerator::GenerateBatch() { feed_vec_[1]->mutable_data({total_instance}, this->place_); clk_tensor_ptr_ = feed_vec_[2]->mutable_data({total_instance}, this->place_); + + int64_t* slot_tensor_ptr_[slot_num_]; + int64_t* slot_lod_tensor_ptr_[slot_num_]; + if (slot_num_ > 0) { + for (int i = 0; i < slot_num_; ++i) { + slot_tensor_ptr_[i] = + feed_vec_[3 + 2 * i]->mutable_data({total_instance, 1}, this->place_); + slot_lod_tensor_ptr_[i] = + feed_vec_[3 + 2 * i + 1]->mutable_data({total_instance + 1}, this->place_); + } + } + VLOG(2) << "total_instance: " << total_instance << ", ins_buf_pair_len = " << ins_buf_pair_len_; int64_t *ins_buf = reinterpret_cast(d_ins_buf_->ptr()); @@ -294,15 +404,57 @@ int GraphDataGenerator::GenerateBatch() { GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); + if (slot_num_ > 0) { + int64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); + for (int i = 0; i < slot_num_; ++i) { + int feature_buf_offset = (ins_buf_pair_len_ * 2 - total_instance) * slot_num_ + i * 2; + // TODO huwei02 opt + for (int j = 0; j < total_instance; j += 2) { + VLOG(2) << "slot_tensor[" << i << "][" << j << "] <- feature_buf[" << feature_buf_offset + j * 8 << "]"; + VLOG(2) << "slot_tensor[" << i << "][" << j + 1 << "] <- feature_buf[" << feature_buf_offset + j * 8 + 1 << "]"; + cudaMemcpyAsync(slot_tensor_ptr_[i] + j, &feature_buf[feature_buf_offset + j * 8], sizeof(int64_t) * 2, + cudaMemcpyDeviceToDevice, stream_); + } + GraphFillSlotLodKernel<<>>( + slot_lod_tensor_ptr_[i], total_instance + 1); + } + } + offset_.clear(); offset_.push_back(0); offset_.push_back(total_instance); LoD lod{offset_}; feed_vec_[0]->set_lod(lod); + if (slot_num_ > 0) { + for (int i = 0; i < slot_num_; ++i) { + feed_vec_[3 + 2 * i]->set_lod(lod); + } + } ins_buf_pair_len_ -= total_instance / 2; 
cudaStreamSynchronize(stream_); + + if (debug_mode_) { + int64_t h_slot_tensor[slot_num_][total_instance]; + int64_t h_slot_lod_tensor[slot_num_][total_instance + 1]; + for (int i = 0; i < slot_num_; ++i) { + cudaMemcpy(h_slot_tensor[i], slot_tensor_ptr_[i], total_instance * sizeof(int64_t), + cudaMemcpyDeviceToHost); + int len = total_instance > 5000? 5000: total_instance; + for(int j = 0; j < len; ++j) { + VLOG(2) << "gpu[" << gpuid_ << "] slot_tensor[" << i <<"][" << j << "] = " << h_slot_tensor[i][j]; + } + + cudaMemcpy(h_slot_lod_tensor[i], slot_lod_tensor_ptr_[i], (total_instance + 1) * sizeof(int64_t), + cudaMemcpyDeviceToHost); + len = total_instance + 1 > 5000? 5000: total_instance + 1; + for(int j = 0; j < len; ++j) { + VLOG(2) << "gpu[" << gpuid_ << "] slot_lod_tensor[" << i <<"][" << j << "] = " << h_slot_lod_tensor[i][j]; + } + } + } + return 1; } @@ -425,7 +577,17 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, uint64_t *walk, cur_sampleidx2row_ = 1 - cur_sampleidx2row_; } +int GraphDataGenerator::FillFeatureBuf(std::shared_ptr d_walk, + std::shared_ptr d_feature) { + platform::CUDADeviceGuard guard(gpuid_); + + auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); + int ret = gpu_graph_ptr->get_feature_of_nodes(gpuid_, d_walk, d_feature, buf_size_, slot_num_); + return ret; +} + int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) { + VLOG(0) << "begin FillWalkBuf"; platform::CUDADeviceGuard guard(gpuid_); size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_; //////// @@ -579,6 +741,7 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, platform::DeviceContextPool::Instance().Get(place)) ->stream(); feed_vec_ = feed_vec; + slot_num_ = (feed_vec_.size() - 3) / 2; // d_device_keys_.resize(h_device_keys_.size()); VLOG(2) << "h_device_keys size: " << h_device_keys_.size(); @@ -612,6 +775,8 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, jump_rows_ = 0; d_walk_ = memory::AllocShared(place_, buf_size_ * sizeof(uint64_t)); cudaMemsetAsync(d_walk_->ptr(), 0, buf_size_ * sizeof(uint64_t), stream_); + d_feature_ = memory::AllocShared(place_, buf_size_ * slot_num_ * sizeof(uint64_t)); + cudaMemsetAsync(d_feature_->ptr(), 0, buf_size_ * sizeof(uint64_t), stream_); d_sample_keys_ = memory::AllocShared(place_, once_max_sample_keynum * sizeof(uint64_t)); @@ -638,6 +803,8 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, ins_buf_pair_len_ = 0; d_ins_buf_ = memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(int64_t)); + d_feature_buf_ = + memory::AllocShared(place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(int64_t)); d_pair_num_ = memory::AllocShared(place_, sizeof(int)); cudaStreamSynchronize(stream_); diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 2fc0d242198cc..9c44de182e158 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -895,6 +895,8 @@ class GraphDataGenerator { int AcquireInstance(BufState* state); int GenerateBatch(); int FillWalkBuf(std::shared_ptr d_walk); + int FillFeatureBuf(std::shared_ptr d_walk, + std::shared_ptr d_feature); void FillOneStep(uint64_t* start_ids, uint64_t* walk, int len, NeighborSampleResult& sample_res, int cur_degree, int step, int* len_per_row); @@ -929,6 +931,7 @@ class GraphDataGenerator { std::vector> d_device_keys_; std::shared_ptr d_walk_; + std::shared_ptr d_feature_; std::shared_ptr d_len_per_row_; std::shared_ptr d_random_row_; 
// @@ -942,6 +945,7 @@ class GraphDataGenerator { std::unordered_map node_type_start_; std::shared_ptr d_ins_buf_; + std::shared_ptr d_feature_buf_; std::shared_ptr d_pair_num_; int ins_buf_pair_len_; // size of a d_walk buf @@ -950,6 +954,7 @@ class GraphDataGenerator { std::vector window_step_; BufState buf_state_; int batch_size_; + int slot_num_; int shuffle_seed_; int debug_mode_; std::vector first_node_type_; diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 8e3462edfde88..b7a6f28e0b30b 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -462,7 +462,7 @@ void DatasetImpl::LoadIntoMemory() { auto& type_total_key = graph_all_type_total_keys_[cnt]; type_total_key.resize(thread_num_); for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) { - VLOG(2) << "type: " << node_idx << ", gpu_graph_device_keys[" << i + VLOG(2) << "node type: " << node_idx << ", gpu_graph_device_keys[" << i << "] = " << gpu_graph_device_keys[i].size(); for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) { gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]); @@ -475,6 +475,18 @@ void DatasetImpl::LoadIntoMemory() { } cnt++; } + for (auto& iter : node_to_id) { + int node_idx = iter.second; + auto gpu_graph_device_keys = + gpu_graph_ptr->get_all_feature_ids(1, node_idx, thread_num_); + for (size_t i = 0; i < gpu_graph_device_keys.size(); i++) { + VLOG(2) << "node type: " << node_idx << ", gpu_graph_device_keys[" << i + << "] = " << gpu_graph_device_keys[i].size(); + for (size_t j = 0; j < gpu_graph_device_keys[i].size(); j++) { + gpu_graph_total_keys_.push_back(gpu_graph_device_keys[i][j]); + } + } + } // FIX: trick for iterate edge table for (auto& iter : edge_to_id) { int edge_idx = iter.second; diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h index b1ff5852d5985..c4231cb7beb8b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h @@ -26,24 +26,25 @@ namespace framework { enum GraphTableType { EDGE_TABLE, FEATURE_TABLE }; class GpuPsGraphTable : public HeterComm { public: - int get_table_offset(int gpu_id, GraphTableType type, int idx) { + int get_table_offset(int gpu_id, GraphTableType type, int idx) const { int type_id = type; return gpu_id * (graph_table_num_ + feature_table_num_) + type_id * graph_table_num_ + idx; } GpuPsGraphTable(std::shared_ptr resource, int topo_aware, - int graph_table_num, int feature_table_num) + int graph_table_num) : HeterComm(1, resource) { load_factor_ = 0.25; rw_lock.reset(new pthread_rwlock_t()); this->graph_table_num_ = graph_table_num; - this->feature_table_num_ = feature_table_num; + this->feature_table_num_ = 1; gpu_num = resource_->total_device(); memset(global_device_map, -1, sizeof(global_device_map)); for (auto &table : tables_) { delete table; table = NULL; } + int feature_table_num = 1; tables_ = std::vector( gpu_num * (graph_table_num + feature_table_num), NULL); for (int i = 0; i < gpu_num; i++) { @@ -108,7 +109,7 @@ class GpuPsGraphTable : public HeterComm { // } } void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id, int idx); - void build_graph_fea_on_single_gpu(GpuPsCommGraphFea &g, int gpu_id, int idx); + void build_graph_fea_on_single_gpu(GpuPsCommGraphFea &g, int gpu_id); void clear_graph_info(int gpu_id, int index); void clear_graph_info(int index); void clear_feature_info(int gpu_id, int 
index); @@ -125,10 +126,15 @@ class GpuPsGraphTable : public HeterComm { NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int idx, uint64_t *key, int sample_size, int len, bool cpu_query_switch); + + int get_feature_of_nodes(int gpu_id, + std::shared_ptr d_walk, + std::shared_ptr d_offset, int size, int slot_num); + NodeQueryResult query_node_list(int gpu_id, int idx, int start, int query_size); void display_sample_res(void *key, void *val, int len, int sample_len); - void move_neighbor_sample_result_to_source_gpu(int gpu_id, int gpu_num, + void move_result_to_source_gpu(int gpu_id, int gpu_num, int sample_size, int *h_left, int *h_right, uint64_t *src_sample_res, diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu index 342b1ac09add5..f423a33abe349 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu @@ -68,6 +68,38 @@ __global__ void copy_buffer_ac_to_final_place( } } +__global__ void get_features_kernel(GpuPsCommGraphFea graph, int64_t* node_offset_array, + int* actual_size, uint64_t* feature, int slot_num, int n) { + int idx = blockIdx.x * blockDim.y + threadIdx.y; + if (idx < n) { + int node_offset = node_offset_array[idx]; + int offset = idx * slot_num; + if (node_offset == -1) { + for (int k = 0; k < slot_num; ++ k) { + feature[offset + k] = 0; + } + actual_size[idx] = slot_num; + return; + } + + GpuPsGraphFeaNode* node = &(graph.node_list[node_offset]); + uint64_t* feature_start = &(graph.feature_list[node->feature_offset]); + uint8_t* slot_id_start = &(graph.slot_id_list[node->feature_offset]); + int m = 0; + for (int k = 0; k < slot_num; ++k) { + if (m >= node->feature_size || k < slot_id_start[m]) { + feature[offset + k] = 0; + } else if (k == slot_id_start[m]) { + feature[offset + k] = feature_start[m]; + ++m; + } else { + assert(0); + } + } + actual_size[idx] = slot_num; + } +} + template __global__ void neighbor_sample_kernel(GpuPsCommGraph graph, int64_t* node_index, int* actual_size, @@ -182,7 +214,7 @@ void GpuPsGraphTable::display_sample_res(void* key, void* val, int len, } } -void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu( +void GpuPsGraphTable::move_result_to_source_gpu( int start_index, int gpu_num, int sample_size, int* h_left, int* h_right, uint64_t* src_sample_res, int* actual_sample_size) { int shard_len[gpu_num]; @@ -238,6 +270,17 @@ __global__ void fill_dvalues(uint64_t* d_shard_vals, uint64_t* d_vals, } } +__global__ void fill_dvalues(uint64_t* d_shard_vals, uint64_t* d_vals, + int* d_shard_actual_sample_size, + int* idx, int sample_size, int len) { + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < len) { + for (int j = 0; j < sample_size; j++) { + d_vals[idx[i] * sample_size + j] = d_shard_vals[i * sample_size + j]; + } + } +} + __global__ void fill_actual_vals(uint64_t* vals, uint64_t* actual_vals, int* actual_sample_size, int* cumsum_actual_sample_size, @@ -258,7 +301,8 @@ __global__ void node_query_example(GpuPsCommGraph graph, int start, int size, } } -void GpuPsGraphTable::clear_feature_info(int gpu_id, int idx) { +void GpuPsGraphTable::clear_feature_info(int gpu_id) { + int idx = 0; if (idx >= feature_table_num_) return; int offset = get_table_offset(gpu_id, GraphTableType::FEATURE_TABLE, idx); if (offset < tables_.size()) { @@ -283,9 +327,6 @@ void GpuPsGraphTable::clear_feature_info(int gpu_id, int idx) { 
cudaFree(graph.node_list); } } -void GpuPsGraphTable::clear_feature_info(int idx) { - for (int i = 0; i < gpu_num; i++) clear_feature_info(i, idx); -} void GpuPsGraphTable::clear_graph_info(int gpu_id, int idx) { if (idx >= graph_table_num_) return; @@ -314,8 +355,9 @@ In this function, memory is allocated on each gpu to save the graphs, gpu i saves the ith graph from cpu_graph_list */ void GpuPsGraphTable::build_graph_fea_on_single_gpu(GpuPsCommGraphFea& g, - int gpu_id, int ntype_id) { - clear_feature_info(gpu_id, ntype_id); + int gpu_id) { + clear_feature_info(gpu_id); + int ntype_id = 0; platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); @@ -686,7 +728,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2( } cudaStreamSynchronize(resource_->remote_stream(i, gpu_id)); } - move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size, + move_result_to_source_gpu(gpu_id, total_gpu, sample_size, h_left, h_right, d_shard_vals_ptr, d_shard_actual_sample_size_ptr); fill_dvalues<<>>( @@ -856,6 +898,113 @@ NodeQueryResult GpuPsGraphTable::query_node_list(int gpu_id, int idx, int start, cudaStreamSynchronize(resource_->remote_stream(gpu_id, gpu_id)); return result; } + +int GpuPsGraphTable::get_feature_of_nodes(int gpu_id, std::shared_ptr d_nodes, + std::shared_ptr d_feature, int node_num, int slot_num) { + if (node_num == 0) { + return -1; + } + + platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id)); + platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id)); + int total_gpu = resource_->total_device(); + auto stream = resource_->local_stream(gpu_id, 0); + + auto d_left = memory::Alloc(place, total_gpu * sizeof(int)); + auto d_right = memory::Alloc(place, total_gpu * sizeof(int)); + int* d_left_ptr = reinterpret_cast(d_left->ptr()); + int* d_right_ptr = reinterpret_cast(d_right->ptr()); + + cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream); + cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream); + // + auto d_idx = memory::Alloc(place, node_num * sizeof(int)); + int* d_idx_ptr = reinterpret_cast(d_idx->ptr()); + + auto d_shard_keys = memory::Alloc(place, node_num * sizeof(uint64_t)); + uint64_t* d_shard_keys_ptr = reinterpret_cast(d_shard_keys->ptr()); + auto d_shard_vals = memory::Alloc(place, slot_num * node_num * sizeof(uint64_t)); + uint64_t* d_shard_vals_ptr = reinterpret_cast(d_shard_vals->ptr()); + auto d_shard_actual_size = memory::Alloc(place, node_num * sizeof(int)); + int* d_shard_actual_size_ptr = reinterpret_cast(d_shard_actual_size->ptr()); + + uint64_t* key = (uint64_t*)d_nodes->ptr(); + split_input_to_shard((uint64_t*)(key), d_idx_ptr, node_num, d_left_ptr, d_right_ptr, gpu_id); + + heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, node_num, stream); + cudaStreamSynchronize(stream); + + int h_left[total_gpu]; // NOLINT + cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost); + int h_right[total_gpu]; // NOLINT + cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int), cudaMemcpyDeviceToHost); + for (int i = 0; i < total_gpu; ++i) { + int shard_len = h_left[i] == -1 ? 
0 : h_right[i] - h_left[i] + 1; + if (shard_len == 0) { + continue; + } + create_storage(gpu_id, i, shard_len * sizeof(uint64_t), + shard_len * slot_num * sizeof(uint64_t) + shard_len * sizeof(int64_t) + + sizeof(int) * (shard_len + shard_len % 2)); + } + + walk_to_dest(gpu_id, total_gpu, h_left, h_right, (uint64_t*)(d_shard_keys_ptr), NULL); + + for (int i = 0; i < total_gpu; ++i) { + if (h_left[i] == -1) { + continue; + } + int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1; + auto& node = path_[gpu_id][i].nodes_.back(); + cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(int64_t), node.in_stream); + cudaStreamSynchronize(node.in_stream); + platform::CUDADeviceGuard guard(resource_->dev_id(i)); + // If not found, val is -1. + int table_offset = get_table_offset(i, GraphTableType::FEATURE_TABLE, 0); + tables_[table_offset]->get(reinterpret_cast(node.key_storage), + reinterpret_cast(node.val_storage), + h_right[i] - h_left[i] + 1, + resource_->remote_stream(i, gpu_id)); + + int offset = i * feature_table_num_; + auto graph = gpu_graph_fea_list_[offset]; + int64_t* val_array = reinterpret_cast(node.val_storage); + int* actual_size_array = (int*)(val_array + shard_len); + uint64_t* feature_array = (uint64_t*)(actual_size_array + shard_len + shard_len % 2); + dim3 grid((shard_len - 1) / dim_y + 1); + dim3 block(1, dim_y); + get_features_kernel<<remote_stream(i, gpu_id)>>>( + graph, val_array, actual_size_array, feature_array, slot_num, shard_len); + } + + for (int i = 0; i < total_gpu; ++i) { + if (h_left[i] == -1) { + continue; + } + cudaStreamSynchronize(resource_->remote_stream(i, gpu_id)); + } + + move_result_to_source_gpu(gpu_id, total_gpu, slot_num, h_left, h_right, + d_shard_vals_ptr, d_shard_actual_size_ptr); + + int grid_size = (node_num - 1) / block_size_ + 1; + uint64_t* result = (uint64_t*)d_feature->ptr(); + fill_dvalues<<>>(d_shard_vals_ptr, result, + d_shard_actual_size_ptr, d_idx_ptr, slot_num, node_num); + + for (int i = 0; i < total_gpu; ++i) { + int shard_len = h_left[i] == -1 ? 
0 : h_right[i] - h_left[i] + 1; + if (shard_len == 0) { + continue; + } + destroy_storage(gpu_id, i); + } + + cudaStreamSynchronize(stream); + + return 0; +} + } }; #endif diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu index 09d89fb49f1ac..e329d4e44ba13 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu @@ -27,12 +27,27 @@ void GraphGpuWrapper::set_device(std::vector ids) { device_id_mapping.push_back(device_id); } } + +std::vector> GraphGpuWrapper::get_all_id(int type, + int slice_num) { + return ((GpuPsGraphTable *)graph_table) + ->cpu_graph_table_->get_all_id(type, slice_num); +} + std::vector> GraphGpuWrapper::get_all_id(int type, int idx, int slice_num) { return ((GpuPsGraphTable *)graph_table) ->cpu_graph_table_->get_all_id(type, idx, slice_num); } + +std::vector> GraphGpuWrapper::get_all_feature_ids(int type, + int idx, + int slice_num) { + return ((GpuPsGraphTable *)graph_table) + ->cpu_graph_table_->get_all_feature_ids(type, idx, slice_num); +} + void GraphGpuWrapper::set_up_types(std::vector &edge_types, std::vector &node_types) { id_to_edge = edge_types; @@ -53,6 +68,10 @@ void GraphGpuWrapper::set_up_types(std::vector &edge_types, void GraphGpuWrapper::set_feature_separator(std::string ch) { feature_separator_ = ch; + if (graph_table != nullptr) { + ((GpuPsGraphTable *)graph_table) + ->cpu_graph_table_->set_feature_separator(feature_separator_); + } } void GraphGpuWrapper::make_partitions(int idx, int64_t byte_size, @@ -160,7 +179,7 @@ void GraphGpuWrapper::init_service() { std::make_shared(device_id_mapping); resource->enable_p2p(); GpuPsGraphTable *g = - new GpuPsGraphTable(resource, 1, id_to_edge.size(), id_to_feature.size()); + new GpuPsGraphTable(resource, 1, id_to_edge.size()); g->init_cpu_table(table_proto); g->cpu_graph_table_->set_feature_separator(feature_separator_); graph_table = (char *)g; @@ -191,12 +210,12 @@ void GraphGpuWrapper::upload_batch(int ntype_id, VLOG(0) << "begin make_gpu_ps_graph_fea, node_ids[" << i << "]_size[" << node_ids[i].size() << "]"; GpuPsCommGraphFea sub_graph = g->cpu_graph_table_->make_gpu_ps_graph_fea( - ntype_id, node_ids[i], slot_num); + node_ids[i], slot_num); // sub_graph.display_on_cpu(); VLOG(0) << "begin build_graph_fea_on_single_gpu, node_ids[" << i << "]_size[" << node_ids[i].size() << "]"; - g->build_graph_fea_on_single_gpu(sub_graph, i, ntype_id); + g->build_graph_fea_on_single_gpu(sub_graph, i); sub_graph.release_on_cpu(); @@ -212,6 +231,15 @@ NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample_v3( ->graph_neighbor_sample_v3(q, cpu_switch); } +int GraphGpuWrapper::get_feature_of_nodes(int gpu_id, + std::shared_ptr d_walk, + std::shared_ptr d_offset, uint32_t size, int slot_num) const { + platform::CUDADeviceGuard guard(gpu_id); + PADDLE_ENFORCE_NOT_NULL(graph_table); + return ((GpuPsGraphTable *)graph_table) + ->get_feature_of_nodes(gpu_id, d_walk, d_offset, size, slot_num); +} + NeighborSampleResult GraphGpuWrapper::graph_neighbor_sample( int gpu_id, uint64_t *device_keys, int walk_degree, int len) { platform::CUDADeviceGuard guard(gpu_id); diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h index e140eac5d9c79..992188e172be7 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h @@ -36,7 
+36,7 @@ class GraphGpuWrapper { void set_up_types(std::vector& edge_type, std::vector& node_type); void upload_batch(int etype_id, std::vector>& ids); - void upload_batch(int ntype_id, std::vector>& ids, + void upload_batch(std::vector>& ids, int slot_num); void add_table_feat_conf(std::string table_name, std::string feat_name, std::string feat_dtype, int feat_shape); @@ -51,8 +51,11 @@ class GraphGpuWrapper { void make_complementary_graph(int idx, int64_t byte_size); void set_search_level(int level); void init_search_level(int level); + std::vector> get_all_id(int type, int slice_num); std::vector> get_all_id(int type, int idx, int slice_num); + std::vector> get_all_feature_ids(int type, int idx, + int slice_num); NodeQueryResult query_node_list(int gpu_id, int idx, int start, int query_size); NeighborSampleResult graph_neighbor_sample_v3(NeighborSampleQuery q, @@ -63,6 +66,9 @@ class GraphGpuWrapper { std::vector& key, int sample_size); void set_feature_separator(std::string ch); + int get_feature_of_nodes(int gpu_id, + std::shared_ptr d_walk, + std::shared_ptr d_offset, uint32_t size, int slot_num) const; std::unordered_map edge_to_id, feature_to_id; std::vector id_to_feature, id_to_edge; diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index 054a804e6b38e..f2b5b2907c049 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -88,6 +88,25 @@ PADDLE_DEFINE_EXPORTED_bool( "input and output must be half precision) and recurrent neural networks " "(RNNs)."); +/** + * CUDA related FLAG + * Name: FLAGS_selected_gpus + * Since Version: 1.3.0 + * Value Range: integer list separated by comma, default empty list + * Example: FLAGS_selected_gpus=0,1,2,3,4,5,6,7 to train or predict with 0~7 gpu + * cards + * Note: A list of device ids separated by comma, like: 0,1,2,3 + */ +PADDLE_DEFINE_EXPORTED_int32( + batch_num, 0, + "A list of device ids separated by comma, like: 0,1,2,3. " + "This option is useful when doing multi process training and " + "each process have only one device (GPU). If you want to use " + "all visible devices, set this to empty string. 
NOTE: the "
+    "reason of doing this is that we want to use P2P communication"
+    "between GPU devices, use CUDA_VISIBLE_DEVICES can only use"
+    "share-memory only.");
+
 /**
  * CUDA related FLAG
  * Name: FLAGS_selected_gpus
diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc
index 9cf6ae9e9981f..e2f4feebf9e3a 100644
--- a/paddle/fluid/pybind/fleet_py.cc
+++ b/paddle/fluid/pybind/fleet_py.cc
@@ -353,9 +353,10 @@ void BindGraphGpuWrapper(py::module* m) {
            py::overload_cast>&>(
                &GraphGpuWrapper::upload_batch))
       .def("upload_batch",
-           py::overload_cast>&, int>(
+           py::overload_cast>&, int>(
                &GraphGpuWrapper::upload_batch))
-      .def("get_all_id", &GraphGpuWrapper::get_all_id)
+      .def("get_all_id", py::overload_cast(&GraphGpuWrapper::get_all_id))
+      .def("get_all_id", py::overload_cast(&GraphGpuWrapper::get_all_id))
       .def("load_next_partition", &GraphGpuWrapper::load_next_partition)
       .def("make_partitions", &GraphGpuWrapper::make_partitions)
       .def("make_complementary_graph",

From 7621775a775d192f6bf78187c84ad464e7f777cd Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 8 Jun 2022 14:09:54 +0800
Subject: [PATCH 2/5] search and fill slot_feature, fix compile error

---
 paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
index e329d4e44ba13..31d603985c4dc 100644
--- a/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
+++ b/paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
@@ -201,8 +201,7 @@ void GraphGpuWrapper::upload_batch(int idx,
 }
 
 // feature table
-void GraphGpuWrapper::upload_batch(int ntype_id,
-                                   std::vector> &node_ids,
+void GraphGpuWrapper::upload_batch(std::vector> &node_ids,
                                    int slot_num) {
   debug_gpu_memory_info("upload_batch feature start");
   GpuPsGraphTable *g = (GpuPsGraphTable *)graph_table;

From 070ffa77bc8351e6192424173a517b5171bc20b5 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 8 Jun 2022 15:16:14 +0800
Subject: [PATCH 3/5] search and fill slot_feature, rename 8 as slot_num_

---
 paddle/fluid/framework/data_feed.cu | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index df9835d0b227f..8290cda5cd1bd 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -410,10 +410,12 @@ int GraphDataGenerator::GenerateBatch() {
       int feature_buf_offset = (ins_buf_pair_len_ * 2 - total_instance) * slot_num_ + i * 2;  // TODO huwei02 opt
       for (int j = 0; j < total_instance; j += 2) {
-        VLOG(2) << "slot_tensor[" << i << "][" << j << "] <- feature_buf[" << feature_buf_offset + j * 8 << "]";
-        VLOG(2) << "slot_tensor[" << i << "][" << j + 1 << "] <- feature_buf[" << feature_buf_offset + j * 8 + 1 << "]";
-        cudaMemcpyAsync(slot_tensor_ptr_[i] + j, &feature_buf[feature_buf_offset + j * 8], sizeof(int64_t) * 2,
-          cudaMemcpyDeviceToDevice, stream_);
+        VLOG(2) << "slot_tensor[" << i << "][" << j << "] <- feature_buf["
+                << feature_buf_offset + j * slot_num_ << "]";
+        VLOG(2) << "slot_tensor[" << i << "][" << j + 1 << "] <- feature_buf["
+                << feature_buf_offset + j * slot_num_ + 1 << "]";
+        cudaMemcpyAsync(slot_tensor_ptr_[i] + j, &feature_buf[feature_buf_offset + j * slot_num_],
+                        sizeof(int64_t) * 2, cudaMemcpyDeviceToDevice, stream_);
       }
       GraphFillSlotLodKernel<<>>(
           slot_lod_tensor_ptr_[i], total_instance + 1);
     }

From d60e789981b768631dc7045b8d306705637f40ac Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 8 Jun 2022 17:56:16 +0800
Subject: [PATCH 4/5] remove debug code

---
 paddle/fluid/framework/data_feed.cu | 10 ----------
 paddle/fluid/platform/flags.cc      | 19 -------------------
 2 files changed, 29 deletions(-)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index 63e0ce5a5e71e..9bd23c94d96b9 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -26,8 +26,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
 #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
 
-DECLARE_int32(batch_num);
-
 namespace paddle {
 namespace framework {
 
@@ -349,15 +347,7 @@ int GraphDataGenerator::FillInsBuf() {
 }
 
 
-int times = 0;
 int GraphDataGenerator::GenerateBatch() {
-  times += 1;
-  VLOG(0) << "Begin batch " << times;
-  if (times > FLAGS_batch_num) {
-    VLOG(0) << "close batch";
-    return 0;
-  }
-
   platform::CUDADeviceGuard guard(gpuid_);
   int res = 0;
   while (ins_buf_pair_len_ < batch_size_) {
diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc
index f2b5b2907c049..054a804e6b38e 100644
--- a/paddle/fluid/platform/flags.cc
+++ b/paddle/fluid/platform/flags.cc
@@ -88,25 +88,6 @@ PADDLE_DEFINE_EXPORTED_bool(
     "input and output must be half precision) and recurrent neural networks "
     "(RNNs).");
 
-/**
- * CUDA related FLAG
- * Name: FLAGS_selected_gpus
- * Since Version: 1.3.0
- * Value Range: integer list separated by comma, default empty list
- * Example: FLAGS_selected_gpus=0,1,2,3,4,5,6,7 to train or predict with 0~7 gpu
- * cards
- * Note: A list of device ids separated by comma, like: 0,1,2,3
- */
-PADDLE_DEFINE_EXPORTED_int32(
-    batch_num, 0,
-    "A list of device ids separated by comma, like: 0,1,2,3. "
-    "This option is useful when doing multi process training and "
-    "each process have only one device (GPU). If you want to use "
-    "all visible devices, set this to empty string. NOTE: the "
-    "reason of doing this is that we want to use P2P communication"
-    "between GPU devices, use CUDA_VISIBLE_DEVICES can only use"
-    "share-memory only.");
-
 /**
  * CUDA related FLAG
  * Name: FLAGS_selected_gpus

From d397050cf983604ed1053ad888639ddd3d10f7a6 Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 8 Jun 2022 18:48:46 +0800
Subject: [PATCH 5/5] remove debug code

---
 paddle/fluid/framework/data_feed.cu | 1 -
 1 file changed, 1 deletion(-)

diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu
index 9bd23c94d96b9..1814fa44da62c 100644
--- a/paddle/fluid/framework/data_feed.cu
+++ b/paddle/fluid/framework/data_feed.cu
@@ -579,7 +579,6 @@ int GraphDataGenerator::FillFeatureBuf(std::shared_ptr d_walk,
 }
 
 int GraphDataGenerator::FillWalkBuf(std::shared_ptr d_walk) {
-  VLOG(0) << "begin FillWalkBuf";
   platform::CUDADeviceGuard guard(gpuid_);
   size_t once_max_sample_keynum = walk_degree_ * once_sample_startid_len_;
   ////////
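
Usage note: a minimal host-side sketch of how the slot-feature pull added by this series can be driven. It only strings together calls that already appear in the patches (memory::AllocShared, GraphGpuWrapper::GetInstance, get_feature_of_nodes); the helper name PullSlotFeatures, the exact include set, and the assumption that the std::shared_ptr argument type elided from the patch text is the shared allocation returned by memory::AllocShared are illustrative, not part of the change. The output layout follows get_features_kernel: slot_num uint64_t feasigns per node, node-major, zero-filled where a node has no value for a slot.

// Sketch only, under the assumptions stated above.
#include <cstdint>
#include <vector>

#include <cuda_runtime.h>

#include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/place.h"

std::vector<uint64_t> PullSlotFeatures(int gpu_id,
                                       const std::vector<uint64_t>& h_nodes,
                                       int slot_num) {
  paddle::platform::CUDADeviceGuard guard(gpu_id);
  auto place = paddle::platform::CUDAPlace(gpu_id);
  int node_num = static_cast<int>(h_nodes.size());

  // Device buffer holding the node ids to look up (the "d_walk" argument of
  // get_feature_of_nodes in FillFeatureBuf above).
  auto d_nodes =
      paddle::memory::AllocShared(place, node_num * sizeof(uint64_t));
  cudaMemcpy(d_nodes->ptr(), h_nodes.data(), node_num * sizeof(uint64_t),
             cudaMemcpyHostToDevice);

  // Output buffer: slot_num feasigns per node, zero-filled for missing slots,
  // as written by get_features_kernel in graph_gpu_ps_table_inl.cu.
  auto d_features = paddle::memory::AllocShared(
      place, static_cast<size_t>(node_num) * slot_num * sizeof(uint64_t));

  auto gpu_graph_ptr = paddle::framework::GraphGpuWrapper::GetInstance();
  gpu_graph_ptr->get_feature_of_nodes(gpu_id, d_nodes, d_features, node_num,
                                      slot_num);

  // Copy the node-major [node_num x slot_num] feature table back to the host.
  std::vector<uint64_t> h_features(static_cast<size_t>(node_num) * slot_num);
  cudaMemcpy(h_features.data(), d_features->ptr(),
             h_features.size() * sizeof(uint64_t), cudaMemcpyDeviceToHost);
  return h_features;
}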