set betas and weight decay for optimizers
according to suggestions in #118 (comment)

ghstack-source-id: 357f0872cd1c9bad2c4c256d47adbd3f716a7651
Pull Request resolved: #123
wanchaol committed Mar 9, 2024
1 parent af221ce commit 5e36c74
Showing 1 changed file with 7 additions and 2 deletions.

train.py

@@ -65,9 +65,14 @@ def build_optimizer(model, job_config: JobConfig):
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
     if name == "Adam":
-        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+        # TODO: make the optimizer options configurable by toml/cmd args
+        optimizer = torch.optim.Adam(
+            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
+        )
     elif name == "AdamW":
-        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
+        optimizer = torch.optim.AdamW(
+            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
+        )
     else:
         raise NotImplementedError(f"optimizer {name} not added")
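The betas=(0.9, 0.95) and weight_decay=0.1 values are hard-coded by this change; the TODO notes that the optimizer options should eventually come from the toml/cmd-line config. As a rough follow-up sketch (not part of this commit), the hard-coded values could be read from the job config with defaults matching this change. The betas and weight_decay config fields and the getattr fallbacks below are assumptions, not existing JobConfig attributes.

# Sketch only: read the new optimizer hyperparameters from the job config,
# falling back to the values hard-coded in this commit.
# The `betas` and `weight_decay` config fields are hypothetical.
import torch


def build_optimizer(model, job_config):
    name = job_config.optimizer.name
    lr = job_config.optimizer.lr
    # Hypothetical config fields; default to this commit's hard-coded values.
    betas = getattr(job_config.optimizer, "betas", (0.9, 0.95))
    weight_decay = getattr(job_config.optimizer, "weight_decay", 0.1)
    if name == "Adam":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
        )
    elif name == "AdamW":
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
        )
    else:
        raise NotImplementedError(f"optimizer {name} not added")
    return optimizer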

