
[AIR] Add distributed torch_geometric example #23580
Merged: 13 commits into ray-project:master, Apr 21, 2022

Conversation

amogkam (Contributor) commented Mar 30, 2022

Add example for distributed pytorch geometric (graph learning) with Ray AIR

This only showcases distributed training, but with data small enough that it can be loaded in by each training worker individually. Distributed data ingest is out of scope for this PR.

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(


# Disable distributed sampler since the train_loader has already been split above.
train_loader = train.torch.prepare_data_loader(train_loader, add_dist_sampler=False)
Reviewer (Contributor):

Dumb question: why do we do the split separately instead of combining it into prepare_data_loader?

amogkam (Author) replied Apr 4, 2022:

You need to use torch geometric's NeighborSampler to sample subgraphs from the overall graph, instead of the standard DistributedSampler.

return x.log_softmax(dim=-1)

@torch.no_grad()
def inference(self, x_all, subgraph_loader):
Reviewer (Contributor):

Is this planned to be used for predictor impl?

amogkam (Author):

Eventually yes, but the challenge for prediction is how to add "fresh data" to the graph to do inference on.

scaling_config={"num_workers": num_workers, "use_gpu": use_gpu},
)
result = trainer.fit()
print(result.metrics)
Reviewer (Contributor):

what does prediction look like?

amogkam (Author):

Prediction is not supported for now; we need to be able to add "fresh data" to the existing graph and then re-run the inference algorithm on the new data.

@@ -8,3 +8,10 @@ tblib
-f https://download.pytorch.org/whl/torch_stable.html
torch==1.9.0+cu111
torchvision==0.10.0+cu111

-f https://data.pyg.org/whl/torch-1.9.0+cu111.html
Reviewer (Contributor):

curious, what is this for?

amogkam (Author):

These are required dependencies for pytorch geometric
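For context, PyG's extra wheel index is typically paired with its compiled companion packages in the requirements file; a hedged sketch (the exact package set used by this example may differ):

```
# Wheels for the compiled PyG companion packages live on this index.
-f https://data.pyg.org/whl/torch-1.9.0+cu111.html
torch-scatter
torch-sparse
torch-geometric
```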

self.convs.append(SAGEConv(hidden_channels, out_channels))

def forward(self, x, adjs):
for i, (edge_index, _, size) in enumerate(adjs):
Reviewer (Member):

can you comment a bit about the format of this adjs matrix?
especially 1. what does size mean in this context? and 2. how do we make sure there are always enough hidden layers to handle the adjacency links in adjs?

amogkam (Author):

Added a comment here, but more information is in the torch geometric docs.

For 2, we pass a sizes list to the NeighborSampler, so the length of this list should match the number of layers in the model.
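That contract can be sketched with a minimal stand-in (plain Linear layers substitute for SAGEConv; each adj mimics NeighborSampler's (edge_index, e_id, size) output, all values hypothetical):

```python
import torch
import torch.nn as nn

num_layers = 2  # must equal len(sizes) passed to NeighborSampler
convs = nn.ModuleList([nn.Linear(8, 8) for _ in range(num_layers)])

def forward(x, adjs):
    # One (edge_index, e_id, size) entry per hop/layer; size is
    # (num_source_nodes, num_target_nodes) for that hop.
    assert len(adjs) == num_layers
    for i, (edge_index, _, size) in enumerate(adjs):
        x = x[: size[1]]        # target nodes come first in the batch
        x = convs[i](x)         # stand-in for SAGEConv aggregation
        if i != num_layers - 1:
            x = x.relu()
    return x

# Hypothetical adjs: 6 source nodes -> 4 targets, then 4 -> 2 targets.
adjs = [(None, None, (6, 4)), (None, None, (4, 2))]
out = forward(torch.randn(6, 8), adjs)
```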

x = F.relu(x)
xs.append(x.cpu())

x_all = torch.cat(xs, dim=0)
Reviewer (Member):

Looks a bit weird to me; I think I am just clueless.
If we overwrite the entire x_all here, we will only have features for the nodes that we scored with the last layer.
It feels more appropriate to update the corresponding entries of x_all from xs, rather than simply assigning x_all = ...?

amogkam (Author):

This is just the inference code so no weights updating. I think this works the same way as a standard feed-forward neural network. We only want the output of the last layer, and we don't care about the hidden states during inference.

Reviewer (Member):

Ok, I understand this now. subgraph_loader actually samples a subgraph for every single node in the graph.
So if there are n nodes in the graph, the inner loop will run n times; each time, we are essentially aggregating data from all neighboring nodes into this specific node.
So at the end, torch.cat(xs) gives us a new updated graph, since xs will contain data for every single node at that point.
Interesting design.
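The loop described above can be sketched without PyG (a toy stand-in: plain Linear layers, and simple batching over all nodes in place of subgraph_loader; all shapes are illustrative):

```python
import torch
import torch.nn as nn

num_nodes, in_dim, hid_dim, out_dim = 8, 4, 6, 3
layers = nn.ModuleList([nn.Linear(in_dim, hid_dim),
                        nn.Linear(hid_dim, out_dim)])

@torch.no_grad()
def inference(x_all, batch_size=3):
    for i, layer in enumerate(layers):
        xs = []
        # Stand-in for subgraph_loader: visit every node, in batches.
        for start in range(0, num_nodes, batch_size):
            x = layer(x_all[start:start + batch_size])
            if i != len(layers) - 1:
                x = x.relu()
            xs.append(x.cpu())
        # Overwriting x_all is fine: after this inner loop it holds the
        # current layer's output for ALL nodes, ready for the next layer.
        x_all = torch.cat(xs, dim=0)
    return x_all

out = inference(torch.randn(num_nodes, in_dim))
```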

@richardliaw richardliaw added this to the Ray AIR milestone Apr 8, 2022
Comment on lines +11 to +12

-f https://data.pyg.org/whl/torch-1.9.0+cu111.html
Reviewer (Contributor):

Do we need to make these changes to requirements_dl.txt (line 6 above)?

amogkam (Author):

Since there's only a GPU test, I think it should be fine for now

Reviewer (Contributor):

Oh but doesn't that make the instruction in line 6 no longer true? Do we actually want these to be in CPU docker as well? Alternative solution: move these above that line.

amogkam (Author):

Updated the comment to reflect the changes!

@@ -36,6 +36,6 @@
conditions: ["RAY_CI_ML_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/travis/upload_build_info.sh; fi }; trap cleanup EXIT
- DATA_PROCESSING_TESTING=1 TRAIN_TESTING=1 TUNE_TESTING=1 ./ci/travis/install-dependencies.sh
- DATA_PROCESSING_TESTING=1 TRAIN_TESTING=1 TUNE_TESTING=1 PYTHON=3.7 ./ci/travis/install-dependencies.sh
Reviewer (Contributor):

For my learning, is this needed?

amogkam (Author):

Torch geometric does not support python 3.6.

We could make a separate build just for 3.7, but I thought it would be better to upgrade everything to 3.7, since that's what we currently do for Tune anyway.

Reviewer (Contributor):

Oh wait isn't the default value 3.7?

amogkam (Author):

No it's 3.6 I believe.

Reviewer (Contributor):

Ah I believe it was updated for GPU images here

But similar to my comment on that PR, having it explicit makes sense (in case we change default version in the future)

amogkam (Author):

Ah, got it. Actually, earlier versions of torch geometric do support Python 3.6, but the later versions don't. In any case, it's fine to have this be explicit.

@amogkam amogkam requested a review from matthewdeng April 8, 2022 20:03
@amogkam amogkam added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Apr 20, 2022
@amogkam amogkam merged commit 732175e into ray-project:master Apr 21, 2022
@amogkam amogkam deleted the torch-geometric-examples branch April 21, 2022 16:48
gjoliver (Member) left a comment:

sorry about the delay, have a few minor questions/comments.

@@ -504,7 +504,8 @@ def _wait_for_batch(self, item):
# the tensor might be freed once it is no longer used by
# the creator stream.
for i in item:
    if isinstance(i, torch.Tensor):
        i.record_stream(curr_stream)
Reviewer (Member):

can you comment what may show up here as well, and why you need this if statement now?

amogkam (Author) replied Apr 21, 2022:

The pytorch dataloader can actually return a batch of anything. In all of our examples and tests so far, our data loaders return batches of tensors, but in this case the torch geometric data loader also returns the batch size, node ids, etc., which are not all tensors.
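A small sketch of why the guard is needed (hypothetical mixed batch, mimicking what a PyG loader yields; record_stream itself is CUDA-only, so here we only track which elements would qualify):

```python
import torch

def tensors_in_batch(item):
    # Only torch.Tensor elements support record_stream(); skip ints,
    # node-id lists, and other metadata the loader may return.
    return [i for i in item if isinstance(i, torch.Tensor)]

# Mixed batch like (batch_size, n_id, adjs) from a PyG NeighborSampler.
batch = (32, torch.arange(4), [(None, None, (4, 2))])
qualifying = tensors_in_batch(batch)
```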

# Use 10% of nodes for validation and 10% for testing.
fake_dataset = FakeDataset(transform=RandomNodeSplit(num_val=0.1, num_test=0.1))

def gen_dataset():
Reviewer (Member):

Feels a little unnecessary.
Why don't we simply return fake_dataset here, and below in the configuration say "dataset_fn": gen_fake_dataset?

amogkam (Author):

Good point 😅. Made a follow up PR here #24080!
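The reviewer's suggestion might look like the following hypothetical sketch (gen_fake_dataset and the config key are illustrative names, not the PR's actual code):

```python
# A top-level factory is cheap to ship to workers; each training worker
# then materializes its own copy of the dataset inside the train loop.
def gen_fake_dataset():
    # Stand-in for FakeDataset(transform=RandomNodeSplit(...)).
    return list(range(10))

train_loop_config = {"dataset_fn": gen_fake_dataset}

# Inside the per-worker training function:
dataset = train_loop_config["dataset_fn"]()
```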

def inference(self, x_all, subgraph_loader):
for i in range(self.num_layers):
xs = []
for batch_size, n_id, adj in subgraph_loader:
Reviewer (Member):

Actually, reading this again now, I am still a bit curious how a user should use this inference call.
This will only work if subgraph_loader iterates through all nodes in a graph, so:

  1. how does a user construct such a subgraph loader?
  2. is it really a common case that someone would want to score an entire graph?

amogkam (Author):

I think the intent is to use this just for validation and testing, not for actual live predictions.

We will need to figure out the inference/prediction story later. This was copied over from the torch geometric example, but let me rename it to "test" to make this clearer.
