Skip to content

Commit

Permalink
[GpuGraph] same first_node_type share start (xuewujiao#19)
Browse files Browse the repository at this point in the history
* same first nodetype share start

* strategy
  • Loading branch information
Thunderbrook authored Jun 8, 2022
1 parent 4ad7760 commit a136431
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr<phi::Allocation> d_walk) {
int cur_node_idx = cursor_ % node_type_len;
int node_type = first_node_type_[cur_node_idx];
auto &path = meta_path_[cur_node_idx];
size_t start = node_type_start_[cur_node_idx];
size_t start = node_type_start_[node_type];
// auto node_query_result = gpu_graph_ptr->query_node_list(
// gpuid_, node_type, start, once_sample_startid_len_);

Expand All @@ -474,10 +474,10 @@ int GraphDataGenerator::FillWalkBuf(std::shared_ptr<phi::Allocation> d_walk) {
int tmp_len = start + once_sample_startid_len_ > device_key_size
? device_key_size - start
: once_sample_startid_len_;
node_type_start_[cur_node_idx] = tmp_len + start;
node_type_start_[node_type] = tmp_len + start;
if (tmp_len == 0) {
finish_node_type_.insert(cur_node_idx);
if (finish_node_type_.size() == node_type_len) {
finish_node_type_.insert(node_type);
if (finish_node_type_.size() == node_type_start_.size()) {
break;
}
cursor_ += 1;
Expand Down Expand Up @@ -682,9 +682,7 @@ void GraphDataGenerator::SetConfig(
platform::errors::NotFound("(%s) is not found in node_to_id.", type));
VLOG(2) << "node_to_id[" << type << "] = " << iter->second;
first_node_type_.push_back(iter->second);
}
for (size_t i = 0; i < node_types.size(); i++) {
node_type_start_[i] = 0;
node_type_start_[iter->second] = 0;
}
meta_path_.resize(first_node_type_.size());
auto meta_paths = paddle::string::split_string<std::string>(meta_path, ";");
Expand Down

0 comments on commit a136431

Please sign in to comment.