diff --git a/train.py b/train.py
index 15f0e493..2f9a130c 100644
--- a/train.py
+++ b/train.py
@@ -65,9 +65,14 @@ def build_optimizer(model, job_config: JobConfig):
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
     if name == "Adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+        # TODO: make the optimizer options configurable by toml/cmd args
+        optimizer = torch.optim.Adam(
+            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
+        )
     elif name == "AdamW":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+        optimizer = torch.optim.AdamW(
+            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
+        )
     else:
         raise NotImplementedError(f"optimizer {name} not added")
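
The TODO in the diff notes that the new betas/weight_decay values are hard-coded. A minimal sketch of how the follow-up could look, assuming hypothetical betas and weight_decay fields on job_config.optimizer (only name and lr exist in the diff), with the diff's hard-coded values kept as defaults:

import torch


def build_optimizer(model, job_config):
    name = job_config.optimizer.name
    lr = job_config.optimizer.lr
    # Hypothetical config fields; fall back to the values hard-coded in the diff.
    betas = getattr(job_config.optimizer, "betas", (0.9, 0.95))
    weight_decay = getattr(job_config.optimizer, "weight_decay", 0.1)

    # Dispatch on the optimizer name, mirroring the if/elif chain in the diff.
    optimizer_cls = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW}.get(name)
    if optimizer_cls is None:
        raise NotImplementedError(f"optimizer {name} not added")
    return optimizer_cls(
        model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
    )

Using getattr with defaults keeps the sketch backward-compatible with configs that only define name and lr, so existing toml files would keep working unchanged.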