Skip to content

Commit

Permalink
optimize the data structure from c++ to python to speed up sampling i…
Browse files Browse the repository at this point in the history
…n graph engine
  • Loading branch information
Liwb5 committed Nov 17, 2021
1 parent ca8c4f3 commit 06847bc
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 11 deletions.
32 changes: 29 additions & 3 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,18 +289,44 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
status.wait();
}
}
std::vector<std::vector<std::pair<uint64_t, float>>>

std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
GraphPyClient::batch_sample_neighbors(std::string name,
std::vector<uint64_t> node_ids,
int sample_size) {
int sample_size, bool return_weight,
bool return_edges) {
std::vector<std::vector<std::pair<uint64_t, float>>> v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status =
worker_ptr->batch_sample_neighbors(table_id, node_ids, sample_size, v);
status.wait();
}
return v;

// res.first[0]: neighbors (nodes)
// res.first[1]: slice index
// res.first[2]: src nodes
// res.second: edges weight
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
res.first.push_back({});
res.first.push_back({});
if (return_edges) res.first.push_back({});
for (size_t i = 0; i < v.size(); i++) {
for (size_t j = 0; j < v[i].size(); j++) {
res.first[0].push_back(v[i][j].first);
if (return_edges) res.first[2].push_back(node_ids[i]);
if (return_weight) res.second.push_back(v[i][j].second);
}
if (i == v.size() - 1) break;

if (i == 0) {
res.first[1].push_back(v[i].size());
} else {
res.first[1].push_back(v[i].size() + res.first[1].back());
}
}

return res;
}

void GraphPyClient::use_neighbors_sample_cache(std::string name,
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/distributed/service/graph_py_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,10 @@ class GraphPyClient : public GraphPyService {
int get_client_id() { return client_id; }
void set_client_id(int client_id) { this->client_id = client_id; }
void start_client();
std::vector<std::vector<std::pair<uint64_t, float>>> batch_sample_neighbors(
std::string name, std::vector<uint64_t> node_ids, int sample_size);
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>>
batch_sample_neighbors(std::string name, std::vector<uint64_t> node_ids,
int sample_size, bool return_weight,
bool return_edges);
std::vector<uint64_t> random_sample_nodes(std::string name, int server_index,
int sample_size);
std::vector<std::vector<std::string>> get_node_feat(
Expand Down
14 changes: 8 additions & 6 deletions paddle/fluid/distributed/test/graph_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -556,15 +556,17 @@ void RunBrpcPushSparse() {
ASSERT_EQ(count_item_nodes.size(), 12);
}

vs = client1.batch_sample_neighbors(std::string("user2item"),
std::vector<uint64_t>(1, 96), 4);
ASSERT_EQ(vs[0].size(), 3);
std::pair<std::vector<std::vector<uint64_t>>, std::vector<float>> res;
res = client1.batch_sample_neighbors(
std::string("user2item"), std::vector<uint64_t>(1, 96), 4, true, false);
ASSERT_EQ(res.first[0].size(), 3);
std::vector<uint64_t> node_ids;
node_ids.push_back(96);
node_ids.push_back(37);
vs = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4);
res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4,
true, false);

ASSERT_EQ(vs.size(), 2);
ASSERT_EQ(res.first[0].size(), 2);
std::vector<uint64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
ASSERT_EQ(nodes_ids.size(), 2);
ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) ||
Expand Down Expand Up @@ -693,4 +695,4 @@ void testGraphToBuffer() {
VLOG(0) << s1.get_feature(0);
}

TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
TEST(RunBrpcPushSparse, Run) { RunBrpcPushSparse(); }
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ void BindGraphPyClient(py::module* m) {
.def("start_client", &GraphPyClient::start_client)
.def("batch_sample_neighboors", &GraphPyClient::batch_sample_neighbors)
.def("batch_sample_neighbors", &GraphPyClient::batch_sample_neighbors)
.def("use_neighbors_sample_cache",
&GraphPyClient::use_neighbors_sample_cache)
.def("remove_graph_node", &GraphPyClient::remove_graph_node)
.def("random_sample_nodes", &GraphPyClient::random_sample_nodes)
.def("stop_server", &GraphPyClient::stop_server)
Expand Down

0 comments on commit 06847bc

Please sign in to comment.