[hotfix] fix the bug of repeatedly storing param groups
Fridge003 committed Oct 20, 2023
1 parent b8e770c commit 5f19c1c
Showing 2 changed files with 10 additions and 9 deletions.
12 changes: 6 additions & 6 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -152,22 +152,22 @@ def save_sharded_optimizer(
         index_file = CheckpointIndexFile(checkpoint)

         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        param_groups = optimizer.get_param_groups_for_saving()
-        torch.save(param_groups, group_file_path)
+        if self.coordinator.is_master():
+            index_file.append_meta_data("param_groups", param_group_file)
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            param_groups = optimizer.get_param_groups_for_saving()
+            torch.save(param_groups, group_file_path)

         # States are broken into shards within max_shard_size.
         state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)

         # Save shards of optimizer states.
-        is_master = self.coordinator.is_master()
         total_size = save_state_dict_shards(
             sharded_state_dict=state_dict_shard,
             checkpoint=checkpoint,
             index_file=index_file,
             base_filename=states_name,
-            is_master=is_master,
+            is_master=self.coordinator.is_master(),
             use_safetensors=False,
         )

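This hunk fixes the bug in the commit title: before it, every rank in a multi-process job executed the param-group save, so each process redundantly wrote the same file and appended the same index metadata. Guarding the block with `self.coordinator.is_master()` restricts the write to the master rank; the same guard is applied to the low-level zero plugin below. A minimal sketch of the pattern outside ColossalAI, assuming a plain `torch.distributed` setup (the helper name and file name here are hypothetical, not ColossalAI's):

```python
import os

import torch
import torch.distributed as dist


def save_param_groups_master_only(optimizer: torch.optim.Optimizer, checkpoint_dir: str) -> None:
    # Only rank 0 writes the param-group file; without this guard, every
    # process would repeatedly store the same file, as in the fixed bug.
    is_master = not dist.is_initialized() or dist.get_rank() == 0
    if is_master:
        group_file_path = os.path.join(checkpoint_dir, "param_groups.bin")
        torch.save(optimizer.state_dict()["param_groups"], group_file_path)
    if dist.is_initialized():
        # Ensure the file is on disk before any other rank tries to read it.
        dist.barrier()
```

On a shared filesystem the duplicate writes are mostly wasted I/O, but they can also race with one another, which is why restricting checkpoint writes to a single rank is the standard pattern.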
7 changes: 4 additions & 3 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -121,9 +121,10 @@ def save_sharded_optimizer(
         index_file = CheckpointIndexFile(checkpoint)

         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(state_dict, group_file_path)
+        if self.coordinator.is_master():
+            index_file.append_meta_data("param_groups", param_group_file)
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            save_param_groups(state_dict, group_file_path)

         # Save shards of optimizer states.
         total_size = 0
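For context, `save_param_groups` comes from the plugin's checkpoint utilities, and per its call site above it takes an optimizer `state_dict` and a target path. A hedged sketch of what such a helper plausibly reduces to (an assumption for illustration, not the actual ColossalAI implementation):

```python
import torch


def save_param_groups(state_dict: dict, group_file_path: str) -> None:
    # Assumed behavior: persist only the "param_groups" entry of an
    # optimizer state_dict, matching the call site in the diff above.
    torch.save(state_dict["param_groups"], group_file_path)
```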
