
initial torch.compile support (inference only) #300

Merged
merged 11 commits on Jan 25, 2024
2 changes: 2 additions & 0 deletions mace/modules/blocks.py
@@ -13,6 +13,7 @@
from e3nn.util.jit import compile_mode

from mace.tools.scatter import scatter_sum
from mace.tools.compile import simplify_if_compile

from .irreps_tools import (
linear_out_irreps,
@@ -46,6 +47,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:  # [n_nodes, irreps] # [...
return self.linear(x) # [n_nodes, 1]


@simplify_if_compile
@compile_mode("script")
class NonLinearReadoutBlock(torch.nn.Module):
def __init__(
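
The only functional change to the readout blocks is the new `@simplify_if_compile` registration (the decorator is added in `mace/tools/compile.py` below), which marks `NonLinearReadoutBlock` for `torch.fx` tracing when a model is built via `compile.prepare`. A minimal sketch of the same registration pattern, using a hypothetical `TinyReadout` module that is not part of this PR:

import torch
import torch.nn as nn
from torch.fx import GraphModule

from mace.tools.compile import simplify_if_compile, simplify

@simplify_if_compile
class TinyReadout(nn.Module):  # hypothetical stand-in for NonLinearReadoutBlock
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # no data-dependent control flow, so torch.fx symbolic tracing succeeds
        return self.linear(torch.relu(x))

model = nn.Sequential(TinyReadout())
model = simplify(model)  # the registered child is replaced by its traced GraphModule
assert isinstance(model[0], GraphModule)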
4 changes: 2 additions & 2 deletions mace/modules/symmetric_contraction.py
@@ -226,8 +226,8 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
)
c_tensor = c_tensor + out
out = contract_features(c_tensor, x)
resize_shape = torch.prod(torch.tensor(out.shape[1:]))
return out.view(out.shape[0], resize_shape)

return out.view(out.shape[0], -1)

def U_tensors(self, nu: int):
return dict(self.named_buffers())[f"U_matrix_{nu}"]
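
The reshape change above replaces a flattened size computed via `torch.prod(torch.tensor(out.shape[1:]))` with `view(out.shape[0], -1)`, letting `view` infer the trailing dimension instead of materialising the shape as a tensor, which is friendlier to tracing and compilation. Both forms give the same result; a small equivalence check (not part of the PR, shapes chosen arbitrarily):

import torch

out = torch.randn(4, 3, 5)  # stand-in for the contraction output

# previous form: flattened feature size computed as a 0-dim tensor
resize_shape = torch.prod(torch.tensor(out.shape[1:]))
old = out.view(out.shape[0], resize_shape)

# new form: let view infer the trailing dimension
new = out.view(out.shape[0], -1)

assert old.shape == new.shape == (4, 15)
assert torch.equal(old, new)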
93 changes: 93 additions & 0 deletions mace/tools/compile.py
@@ -0,0 +1,93 @@
from contextlib import contextmanager
from functools import wraps
from typing import Callable, Tuple

import torch.nn as nn
import torch._dynamo as dynamo
from torch import autograd
from torch.fx import symbolic_trace
from e3nn import get_optimization_defaults, set_optimization_defaults

ModuleFactory = Callable[..., nn.Module]
TypeTuple = Tuple[type, ...]


@contextmanager
def disable_e3nn_codegen():
"""Context manager that disables the legacy PyTorch code generation used in e3nn."""
init_val = get_optimization_defaults()["jit_script_fx"]
set_optimization_defaults(jit_script_fx=False)
yield
set_optimization_defaults(jit_script_fx=init_val)


def prepare(func: ModuleFactory, allow_autograd: bool = True) -> ModuleFactory:
"""Function transform that prepares a MACE module for torch.compile

Args:
func (ModuleFactory): A function that creates an nn.Module
        allow_autograd (bool, optional): Force the inductor compiler to inline calls to
            `torch.autograd.grad`. Defaults to True.

Returns:
ModuleFactory: Decorated function that creates a torch.compile compatible module
"""
if allow_autograd:
dynamo.allow_in_graph(autograd.grad)
elif dynamo.allowed_functions.is_allowed(autograd.grad):
dynamo.disallow_in_graph(autograd.grad)

@wraps(func)
def wrapper(*args, **kwargs):
with disable_e3nn_codegen():
model = func(*args, **kwargs)

model = simplify(model)
return model

return wrapper


_SIMPLIFY_REGISTRY = set()


def simplify_if_compile(module: nn.Module) -> nn.Module:
"""Decorator to register a module for symbolic simplification

    The decorated module will be simplified using `torch.fx.symbolic_trace`.
This constrains the module to not have any dynamic control flow, see:

https://pytorch.org/docs/stable/fx.html#limitations-of-symbolic-tracing

Args:
module (nn.Module): the module to register

Returns:
nn.Module: registered module
"""
_SIMPLIFY_REGISTRY.add(module)
return module


def simplify(module: nn.Module) -> nn.Module:
"""Recursively searches for registered modules to simplify with
`torch.fx.symbolic_trace` to support compiling with the PyTorch Dynamo compiler.

    Modules are registered with the `simplify_if_compile` decorator.

Args:
module (nn.Module): the module to simplify

Returns:
nn.Module: the simplified module
"""
simplify_types = tuple(_SIMPLIFY_REGISTRY)

for name, child in module.named_children():
if isinstance(child, simplify_types):
traced = symbolic_trace(child)
setattr(module, name, traced)
else:
simplify(child)

return module
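
Putting the pieces together, `prepare` wraps a model factory so that e3nn's script-mode code generation is disabled while the model is constructed and any registered submodules are simplified with `torch.fx`; the returned module can then be passed to `torch.compile`. A minimal usage sketch, mirroring the tests added below (`create_model` is a hypothetical factory standing in for `create_mace`):

import torch
import torch.nn as nn

from mace.tools import compile as mace_compile

def create_model() -> nn.Module:
    # hypothetical factory; in practice this would build a MACE model,
    # as create_mace does in tests/test_compile.py below
    return nn.Sequential(nn.Linear(8, 16), nn.SiLU(), nn.Linear(16, 1))

model = mace_compile.prepare(create_model)()     # construct with e3nn codegen disabled, then fx-simplify
compiled = torch.compile(model, mode="default")  # hand the prepared module to Dynamo/Inductor
out = compiled(torch.randn(4, 8))                # inference-only usage, matching the scope of this PR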
3 changes: 0 additions & 3 deletions mace/tools/scatter.py
@@ -24,7 +24,6 @@ def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
return src


@torch.jit.script
def scatter_sum(
src: torch.Tensor,
index: torch.Tensor,
@@ -49,7 +48,6 @@ def scatter_sum(
return out.scatter_add_(dim, index, src)


@torch.jit.script
def scatter_std(
src: torch.Tensor,
index: torch.Tensor,
@@ -87,7 +85,6 @@ def scatter_std(
return out


@torch.jit.script
def scatter_mean(
src: torch.Tensor,
index: torch.Tensor,
3 changes: 2 additions & 1 deletion mace/tools/torch_geometric/__init__.py
@@ -2,5 +2,6 @@
from .data import Data
from .dataloader import DataLoader
from .dataset import Dataset
from .seed import seed_everything

__all__ = ["Batch", "Data", "Dataset", "DataLoader"]
__all__ = ["Batch", "Data", "Dataset", "DataLoader", "seed_everything"]
17 changes: 17 additions & 0 deletions mace/tools/torch_geometric/seed.py
@@ -0,0 +1,17 @@
import random

import numpy as np
import torch


def seed_everything(seed: int):
r"""Sets the seed for generating random numbers in :pytorch:`PyTorch`,
:obj:`numpy` and Python.

Args:
seed (int): The desired seed.
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
179 changes: 179 additions & 0 deletions tests/test_compile.py
@@ -0,0 +1,179 @@
from functools import wraps
from typing import Callable

import numpy as np
import pandas as pd
import pytest
import torch
import torch.nn.functional as F
from e3nn import o3
from scipy.spatial.transform import Rotation as R

from mace import data, modules, tools
from mace.tools import torch_geometric, compile

torch.set_default_dtype(torch.float64)
config = data.Configuration(
atomic_numbers=np.array([8, 1, 1]),
positions=np.array(
[
[0.0, -2.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
]
),
forces=np.array(
[
[0.0, -1.3, 0.0],
[1.0, 0.2, 0.0],
[0.0, 1.1, 0.3],
]
),
energy=-1.5,
charges=np.array([-2.0, 1.0, 1.0]),
dipole=np.array([-1.5, 1.5, 2.0]),
)
# Create the rotated environment
rot = R.from_euler("z", 60, degrees=True).as_matrix()
positions_rotated = np.array(rot @ config.positions.T).T
config_rotated = data.Configuration(
atomic_numbers=np.array([8, 1, 1]),
positions=positions_rotated,
forces=np.array(
[
[0.0, -1.3, 0.0],
[1.0, 0.2, 0.0],
[0.0, 1.1, 0.3],
]
),
energy=-1.5,
charges=np.array([-2.0, 1.0, 1.0]),
dipole=np.array([-1.5, 1.5, 2.0]),
)
table = tools.AtomicNumberTable([1, 8])
atomic_energies = np.array([1.0, 3.0], dtype=float)


def create_mace(device: str, seed: int = 1702):
torch_geometric.seed_everything(seed)

model_config = {
"r_max": 5,
"num_bessel": 8,
"num_polynomial_cutoff": 6,
"max_ell": 3,
"interaction_cls": modules.interaction_classes[
"RealAgnosticResidualInteractionBlock"
],
"interaction_cls_first": modules.interaction_classes[
"RealAgnosticResidualInteractionBlock"
],
"num_interactions": 2,
"num_elements": 2,
"hidden_irreps": o3.Irreps("128x0e + 128x1o"),
"MLP_irreps": o3.Irreps("16x0e"),
"gate": F.silu,
"atomic_energies": atomic_energies,
"avg_num_neighbors": 8,
"atomic_numbers": table.zs,
"correlation": 3,
"radial_type": "bessel",
}
model = modules.MACE(**model_config)
return model.to(device)


def create_batch(device: str):
atomic_data = data.AtomicData.from_config(config, z_table=table, cutoff=3.0)
atomic_data2 = data.AtomicData.from_config(
config_rotated, z_table=table, cutoff=3.0
)

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[atomic_data, atomic_data2],
batch_size=2,
shuffle=True,
drop_last=False,
)
batch = next(iter(data_loader))
batch = batch.to(device)
batch = batch.to_dict()
return batch


def time_func(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
torch._inductor.cudagraph_mark_step_begin()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
func(*args, **kwargs)
end.record()
torch.cuda.synchronize()
return start.elapsed_time(end) / 1000

return wrapper


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_mace(device):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip(reason="cuda is not available")

model_defaults = create_mace(device)
tmp_model = compile.prepare(create_mace)(device)
model_compiled = torch.compile(tmp_model, mode="default")

batch = create_batch(device)
output1 = model_defaults(batch, training=True)
output2 = model_compiled(batch, training=True)
assert torch.allclose(output1["energy"][0], output2["energy"][0])
assert torch.allclose(output2["energy"][0], output2["energy"][1])


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
@pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_inference_speedup(compile_mode, dtype):
torch.set_default_dtype(dtype)

    # PyTorch eager baseline
nruns = 16
batch = create_batch("cuda")
model = create_mace("cuda")
model = time_func(model)
t_eager = np.array([model(batch, training=False) for _ in range(nruns)])

print(f'Compiling using mode="{compile_mode}"')
torch.compiler.reset()
model = compile.prepare(create_mace)("cuda")
compiled = torch.compile(model, mode=compile_mode, fullgraph=True)
compiled = time_func(compiled)
t_compiled = np.array([compiled(batch, training=True) for _ in range(nruns)])

df = pd.DataFrame(
{
"eager": t_eager,
f"compile mode={compile_mode}": t_compiled,
"speedup": t_eager / t_compiled,
}
)
print(f"\n\n{df.to_string(index=False)}\n\n")

assert np.median(df["speedup"][-4:]) > 1, "Median compile speedup is less than 1"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
def test_graph_breaks():
import torch._dynamo as dynamo

batch = create_batch("cuda")
model = compile.prepare(create_mace)("cuda")
explanation = dynamo.explain(model)(batch, training=False)

# these clutter the output but might be useful for investigating graph breaks
explanation.ops_per_graph = None
explanation.out_guards = None
print(explanation)
assert explanation.graph_break_count == 0