Merge pull request #477 from RokasEl/schedule-free
Support for schedulefree optimizer
ilyes319 authored Jun 21, 2024
2 parents 7842e99 + c90223c commit 3e6eb77
Showing 6 changed files with 169 additions and 7 deletions.
14 changes: 13 additions & 1 deletion mace/cli/run_train.py
@@ -558,15 +558,27 @@ def run(args: argparse.Namespace) -> None:
],
lr=args.lr,
amsgrad=args.amsgrad,
betas=(args.beta, 0.999),
)

optimizer: torch.optim.Optimizer
if args.optimizer == "adamw":
optimizer = torch.optim.AdamW(**param_options)
elif args.optimizer == "schedulefree":
try:
from schedulefree import adamw_schedulefree
except ImportError as exc:
raise ImportError(
"`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`"
) from exc
_param_options = {k: v for k, v in param_options.items() if k != "amsgrad"}
optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options)
else:
optimizer = torch.optim.Adam(**param_options)

logger = tools.MetricsLogger(directory=args.results_dir, tag=tag + "_train")
logger = tools.MetricsLogger(
directory=args.results_dir, tag=tag + "_train"
) # pylint: disable=E1123

lr_scheduler = LRScheduler(optimizer, args)

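For reference, a minimal standalone sketch of the selection logic added above; the single parameter group and the toy model are stand-ins for the real parameter groups run_train.py builds, and the schedulefree branch assumes the package is installed.

import torch

def build_optimizer(name: str, params, lr: float = 0.01, beta: float = 0.9):
    # Mirrors the branch above: AdamW and Adam accept amsgrad, AdamWScheduleFree does not.
    param_options = dict(params=params, lr=lr, betas=(beta, 0.999), amsgrad=True)
    if name == "adamw":
        return torch.optim.AdamW(**param_options)
    if name == "schedulefree":
        from schedulefree import adamw_schedulefree  # optional dependency

        param_options.pop("amsgrad")  # AdamWScheduleFree has no amsgrad argument
        return adamw_schedulefree.AdamWScheduleFree(**param_options)
    return torch.optim.Adam(**param_options)

model = torch.nn.Linear(4, 1)  # toy stand-in for a MACE model
opt = build_optimizer("schedulefree", model.parameters())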
8 changes: 7 additions & 1 deletion mace/tools/arg_parser.py
@@ -446,7 +446,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
help="Optimizer for parameter optimization",
type=str,
default="adam",
choices=["adam", "adamw"],
choices=["adam", "adamw", "schedulefree"],
)
parser.add_argument(
"--beta",
help="Beta parameter for the optimizer",
type=float,
default=0.9,
)
parser.add_argument("--batch_size", help="batch size", type=int, default=10)
parser.add_argument(
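A standalone sketch of how the two new flags parse; it mirrors the diff with a fresh ArgumentParser instead of importing MACE's full parser, so the repository's other required arguments stay out of the way.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer",
    help="Optimizer for parameter optimization",
    type=str,
    default="adam",
    choices=["adam", "adamw", "schedulefree"],
)
parser.add_argument(
    "--beta",
    help="Beta parameter for the optimizer",
    type=float,
    default=0.9,
)
args = parser.parse_args(["--optimizer", "schedulefree", "--beta", "0.95"])
print(args.optimizer, args.beta)  # schedulefree 0.95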
21 changes: 17 additions & 4 deletions mace/tools/scripts_utils.py
@@ -167,10 +167,18 @@ def radial_to_transform(radial):
"num_interactions": model.num_interactions.item(),
"num_elements": len(model.atomic_numbers),
"hidden_irreps": o3.Irreps(str(model.products[0].linear.irreps_out)),
"MLP_irreps": o3.Irreps(str(model.readouts[-1].hidden_irreps)),
"gate": model.readouts[-1] # pylint: disable=protected-access
.non_linearity._modules["acts"][0]
.f,
"MLP_irreps": (
o3.Irreps(str(model.readouts[-1].hidden_irreps))
if model.num_interactions.item() > 1
else 1
),
"gate": (
model.readouts[-1] # pylint: disable=protected-access
.non_linearity._modules["acts"][0]
.f
if model.num_interactions.item() > 1
else None
),
"atomic_energies": model.atomic_energies_fn.atomic_energies.cpu().numpy(),
"avg_num_neighbors": model.interactions[0].avg_num_neighbors,
"atomic_numbers": model.atomic_numbers,
@@ -373,6 +381,9 @@ def custom_key(key):
class LRScheduler:
def __init__(self, optimizer, args) -> None:
self.scheduler = args.scheduler
        self._optimizer_type = (
            args.optimizer
        )  # ScheduleFree does not need an LR scheduler, but the checkpoint handler expects one.
if args.scheduler == "ExponentialLR":
self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
optimizer=optimizer, gamma=args.lr_scheduler_gamma
Expand All @@ -387,6 +398,8 @@ def __init__(self, optimizer, args) -> None:
raise RuntimeError(f"Unknown scheduler: '{args.scheduler}'")

def step(self, metrics=None, epoch=None): # pylint: disable=E1123
if self._optimizer_type == "schedulefree":
            return  # In principle, a schedulefree optimizer can be combined with an LR scheduler, but the paper suggests it is not necessary.
if self.scheduler == "ExponentialLR":
self.lr_scheduler.step(epoch=epoch)
elif self.scheduler == "ReduceLROnPlateau":
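A short sketch of the early return added to LRScheduler.step(); the SimpleNamespace below stands in for MACE's parsed arguments and is only an assumption of this example (it assumes mace is installed).

from types import SimpleNamespace

import torch
from mace.tools.scripts_utils import LRScheduler

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
args = SimpleNamespace(
    optimizer="schedulefree", scheduler="ExponentialLR", lr_scheduler_gamma=0.9
)
scheduler = LRScheduler(optimizer, args)
scheduler.step()  # returns immediately for schedulefree runs, so the LR is untouched
print(optimizer.param_groups[0]["lr"])  # still 0.01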
5 changes: 4 additions & 1 deletion mace/tools/train.py
@@ -175,7 +175,8 @@ def train(
# Train
if distributed:
train_sampler.set_epoch(epoch)

if "ScheduleFree" in type(optimizer).__name__:
optimizer.train()
train_one_epoch(
model=model,
loss_fn=loss_fn,
@@ -201,6 +202,8 @@
param_context = (
ema.average_parameters() if ema is not None else nullcontext()
)
if "ScheduleFree" in type(optimizer).__name__:
optimizer.eval()
with param_context:
valid_loss, eval_metrics = evaluate(
model=model_to_evaluate,
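The optimizer.train()/optimizer.eval() calls exist because a ScheduleFree optimizer evaluates at an averaged iterate and therefore has a mode, just like the model. A minimal sketch of the pattern with a toy model, assuming schedulefree is installed:

import torch
from schedulefree import adamw_schedulefree

model = torch.nn.Linear(4, 1)
optimizer = adamw_schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

for _ in range(3):
    # Training phase: both the model and the optimizer go into train mode.
    model.train()
    optimizer.train()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation phase: optimizer.eval() moves the weights to the averaged point
    # before the model is evaluated or checkpointed.
    model.eval()
    optimizer.eval()
    with torch.no_grad():
        _ = model(torch.randn(8, 4)).mean()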
1 change: 1 addition & 0 deletions setup.cfg
@@ -50,3 +50,4 @@ dev =
pre-commit
pytest
pylint
schedulefree = schedulefree
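For context, the added line defines an optional extra. The [options.extras_require] header shown below is an assumption about the surrounding layout; only the final line comes from this commit.

[options.extras_require]
dev =
    pre-commit
    pytest
    pylint
schedulefree = schedulefree

With the extra in place, pip install mace-torch[schedulefree] pulls in the optional dependency, matching the hint in the ImportError message added to run_train.py above.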
127 changes: 127 additions & 0 deletions tests/test_schedulefree.py
@@ -0,0 +1,127 @@
from unittest.mock import MagicMock
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from e3nn import o3

from mace import data, modules, tools
from mace.tools import torch_geometric, scripts_utils
import tempfile

try:
import schedulefree
except ImportError:
pytest.skip(
"Skipping schedulefree tests due to ImportError", allow_module_level=True
)

torch.set_default_dtype(torch.float64)

table = tools.AtomicNumberTable([6])
atomic_energies = np.array([1.0], dtype=float)
cutoff = 5.0


def create_mace(device: str, seed: int = 1702):
torch_geometric.seed_everything(seed)

model_config = {
"r_max": cutoff,
"num_bessel": 8,
"num_polynomial_cutoff": 6,
"max_ell": 3,
"interaction_cls": modules.interaction_classes[
"RealAgnosticResidualInteractionBlock"
],
"interaction_cls_first": modules.interaction_classes[
"RealAgnosticResidualInteractionBlock"
],
"num_interactions": 2,
"num_elements": 1,
"hidden_irreps": o3.Irreps("8x0e + 8x1o"),
"MLP_irreps": o3.Irreps("16x0e"),
"gate": F.silu,
"atomic_energies": atomic_energies,
"avg_num_neighbors": 8,
"atomic_numbers": table.zs,
"correlation": 3,
"radial_type": "bessel",
}
model = modules.MACE(**model_config)
return model.to(device)


def create_batch(device: str):
from ase import build

size = 2
atoms = build.bulk("C", "diamond", a=3.567, cubic=True)
atoms_list = [atoms.repeat((size, size, size))]
print("Number of atoms", len(atoms_list[0]))

configs = [data.config_from_atoms(atoms) for atoms in atoms_list]
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(config, z_table=table, cutoff=cutoff)
for config in configs
],
batch_size=1,
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader))
batch = batch.to(device)
batch = batch.to_dict()
return batch


def do_optimization_step(
model,
optimizer,
device,
):
batch = create_batch(device)
model.train()
optimizer.train()
optimizer.zero_grad()
output = model(batch, training=True, compute_force=False)
loss = output["energy"].mean()
loss.backward()
optimizer.step()
model.eval()
optimizer.eval()



@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_can_load_checkpoint(device):
model = create_mace(device)
optimizer = schedulefree.adamw_schedulefree.AdamWScheduleFree(model.parameters())
args = MagicMock()
args.optimizer = "schedulefree"
args.scheduler = "ExponentialLR"
args.lr_scheduler_gamma = 0.9
lr_scheduler = scripts_utils.LRScheduler(optimizer, args)
with tempfile.TemporaryDirectory() as d:
checkpoint_handler = tools.CheckpointHandler(
directory=d, keep=False, tag="schedulefree"
)
for _ in range(10):
do_optimization_step(model, optimizer, device)
batch = create_batch(device)
output = model(batch)
energy = output["energy"].detach().cpu().numpy()

state = tools.CheckpointState(
model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
)
checkpoint_handler.save(state, epochs=0, keep_last=False)
checkpoint_handler.load_latest(
state=tools.CheckpointState(model, optimizer, lr_scheduler),
swa=False,
)
batch = create_batch(device)
output = model(batch)
new_energy = output["energy"].detach().cpu().numpy()
assert np.allclose(energy, new_energy, atol=1e-9)

