
add fused to config
wz337 committed May 22, 2024
1 parent 3bd14ec · commit e8461b4
Showing 3 changed files with 19 additions and 2 deletions.
9 changes: 9 additions & 0 deletions test_runner.py
@@ -96,6 +96,15 @@ def build_test_list(args):
             ],
             "Checkpoint Integration Test - Save Model Weights Only bf16",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    f"--optimizer.name Adam --optimizer.fused --job.dump_folder {args.output_dir}/fused_adamw/",
+                    f"--optimizer.name AdamW --optimizer.fused --job.dump_folder {args.output_dir}/fused_adamw/",
+                ]
+            ],
+            "Fused Optimizer Test",
+        ),
     ]
     return integration_tests_flavors

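For context, each OverrideDefinitions entry above pairs a list of override-flag strings with a test name, and the runner appends those flags to a base training command. Below is a rough sketch of that flow; the names base_cmd and output_dir are illustrative stand-ins, not the runner's actual variables, and the real command assembly lives elsewhere in test_runner.py.

    # Illustrative sketch only: how the override strings added above could be
    # appended to a base training command. May differ from the real runner.
    output_dir = "/tmp/integration_tests"  # stand-in for args.output_dir
    base_cmd = "./run_llama_train.sh"      # assumed launch script

    override_args = [
        f"--optimizer.name Adam --optimizer.fused --job.dump_folder {output_dir}/fused_adamw/",
        f"--optimizer.name AdamW --optimizer.fused --job.dump_folder {output_dir}/fused_adamw/",
    ]

    for override in override_args:
        print(f"{base_cmd} {override}")  # one training invocation per override string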
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -155,6 +155,12 @@ def __init__(self):
         self.parser.add_argument(
             "--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
         )
+        self.parser.add_argument(
+            "--optimizer.fused",
+            default=False,
+            action="store_true",
+            help="Whether the fused implementation (CUDA only) is used.",
+        )

         # training configs
         self.parser.add_argument(
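Because the option name contains a dot, plain argparse stores the parsed value under the literal attribute name "optimizer.fused", which must be read with getattr. A minimal sketch of the new flag's behavior using only the standard library (this mimics, but is not, the real JobConfig parser, which resolves dotted names into config sections):

    # Minimal sketch, assuming plain argparse; torchtitan's JobConfig wraps this.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--optimizer.fused",
        default=False,
        action="store_true",
        help="Whether the fused implementation (CUDA only) is used.",
    )

    args = parser.parse_args(["--optimizer.fused"])
    print(getattr(args, "optimizer.fused"))  # True

    args = parser.parse_args([])
    print(getattr(args, "optimizer.fused"))  # False (the default)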
6 changes: 4 additions & 2 deletions train.py
@@ -88,14 +88,16 @@ def build_optimizer(model, job_config: JobConfig):
     # build optimizer
     name = job_config.optimizer.name
     lr = job_config.optimizer.lr
+    fused = job_config.optimizer.fused
+    # when fused = False, foreach = True by default.
     if name == "Adam":
         # TODO: make the optimizer options configurable by toml/cmd args
         optimizer = torch.optim.Adam(
-            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1, foreach=True
+            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1, fused=fused
         )
     elif name == "AdamW":
         optimizer = torch.optim.AdamW(
-            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1, foreach=True
+            model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1, fused=fused
         )
     else:
         raise NotImplementedError(f"Optimizer {name} not added.")
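The net effect of this change is that the optimizer's update kernel becomes selectable from the config. A minimal standalone sketch of the two resulting code paths (assumes a CUDA device, since PyTorch's fused Adam/AdamW implementation is CUDA-only):

    # Minimal sketch of the two paths build_optimizer can now take.
    import torch
    import torch.nn as nn

    model = nn.Linear(16, 16).cuda()  # fused optimizers require CUDA tensors

    # With --optimizer.fused: parameter updates run in a single fused CUDA kernel.
    opt_fused = torch.optim.AdamW(
        model.parameters(), lr=8e-4, betas=(0.9, 0.95), weight_decay=0.1, fused=True
    )

    # Without the flag (fused=False): PyTorch falls back to its default path,
    # which on CUDA is the multi-tensor "foreach" implementation.
    opt_default = torch.optim.AdamW(
        model.parameters(), lr=8e-4, betas=(0.9, 0.95), weight_decay=0.1, fused=False
    )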
