
initial torch.compile support (inference only) #300

Merged: 11 commits into ACEsuit:compile on Jan 25, 2024

Conversation

@hatemhelal (Contributor) commented on Jan 18, 2024

This PR adds some test cases that use torch.compile, plus a new module mace.tools.compile that contains helper utilities for MACE compatibility with torch.compile.

Some of the changes needed include:

  • Remove the torch.jit.script annotations from the scatter-reduce implementations. This is necessary because compiled script functions are not compatible with the inductor backend.
  • Disable the TorchScript compilation baked into the e3nn project by invoking e3nn.set_optimization_defaults(jit_script_fx=False) before creating the model instance. This is managed with the disable_e3nn_codegen context manager (sketched after this list).
  • To minimize graph breaks, it appears to be necessary to apply a symbolic simplifier that converts the e3nn Irreps objects into simpler Python types that the inductor compiler knows how to handle. This is accomplished by decorating classes with simplify_if_compile.
  • All of this is tied together by the prepare function, which creates the model without e3nn codegen and applies the symbolic-tracing simplification to registered modules (usage shown below).
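For concreteness, here is a minimal sketch of the pieces named above. It assumes e3nn exposes get_optimization_defaults() alongside the set_optimization_defaults() call quoted in the list; the real implementations live in mace.tools.compile, the MACE import path is assumed, and model_config is a hypothetical stand-in for actual hyperparameters:

```python
from contextlib import contextmanager

import e3nn
import torch

from mace.modules import MACE           # assumed import path for the model class
from mace.tools.compile import prepare


@contextmanager
def disable_e3nn_codegen():
    """Temporarily turn off e3nn's built-in TorchScript codegen (sketch only)."""
    initial = e3nn.get_optimization_defaults()["jit_script_fx"]
    e3nn.set_optimization_defaults(jit_script_fx=False)
    try:
        yield
    finally:
        e3nn.set_optimization_defaults(jit_script_fx=initial)


# prepare() builds the model with e3nn codegen disabled and applies the
# symbolic-tracing simplification to registered modules.
model_config: dict = {}  # placeholder; fill with real MACE hyperparameters
model = prepare(MACE)(**model_config)
compiled_model = torch.compile(model, backend="inductor")
```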

Note that the remaining graph break in the compiled inference model is due to using autograd to evaluate the forces. This might be possible to fix, but I expect it would be easier to do in a separate PR.
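For context, the forces are obtained by differentiating the predicted energy with respect to the atomic positions, roughly as below (the batch keys and output schema are illustrative, not the exact MACE interface):

```python
import torch

# `model` and `batch` as in the sketch above; key names are illustrative.
batch["positions"].requires_grad_(True)       # positions must track gradients
energy = model(batch)["energy"].sum()         # scalar total energy
forces = -torch.autograd.grad(energy, batch["positions"])[0]
# This autograd call is what the inductor backend cannot trace through,
# hence the remaining graph break.
```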

@hatemhelal (Contributor, Author) commented on Jan 23, 2024

I experimented with the different torch.compile mode options and measured the following approximate speedups over native eager mode:

| mode            | speedup (A10G) |
|-----------------|----------------|
| default         | 1.4            |
| reduce-overhead | 1.6            |
| max-autotune    | 1.6            |
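These modes map directly to the mode argument of torch.compile, e.g.:

```python
import torch

# reduce-overhead enables CUDA graphs to amortize kernel-launch overhead;
# max-autotune additionally searches for faster generated-kernel configurations.
compiled_model = torch.compile(model, mode="reduce-overhead")
```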

To get these values I ran:

```
pytest -s -k test_inference_speedup
```

Note: I still need to investigate the following warning:

```
skipping cudagraphs due to complex input striding
```

It appears for the non-default modes, and I suspect it comes from how the input batches are created in the test case.
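If the striding hypothesis holds, one possible workaround would be to normalize the batch tensors to contiguous memory before compiling; this is purely an assumption, not something verified in this PR:

```python
import torch

def make_batch_contiguous(batch: dict) -> dict:
    # Hypothetical mitigation: cudagraphs are skipped when inputs have
    # unusual strides, so forcing contiguous tensors in the test's batch
    # construction may let the reduce-overhead/max-autotune modes use them.
    return {k: v.contiguous() if torch.is_tensor(v) else v for k, v in batch.items()}
```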

@ilyes319 changed the base branch from main to compile on Jan 25, 2024 at 17:41
@ilyes319 merged commit db75a72 into ACEsuit:compile on Jan 25, 2024