From cd8ad65f5a2ce53f88d321b82dfbb5b198beb009 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Tue, 31 Oct 2023 14:48:01 +0800
Subject: [PATCH] [hotfix] fix the bug of repeatedly storing param group (#4951)

---
 colossalai/booster/plugin/gemini_plugin.py         | 12 ++++++------
 colossalai/booster/plugin/low_level_zero_plugin.py |  7 ++++---
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 20a931b816ea..d1a9bc2623a3 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -150,24 +150,24 @@ def save_sharded_optimizer(
         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
         index_file = CheckpointIndexFile(checkpoint)
+        index_file.append_meta_data("param_groups", param_group_file)
 
         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        param_groups = optimizer.get_param_groups_for_saving()
-        torch.save(param_groups, group_file_path)
+        if self.coordinator.is_master():
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            param_groups = optimizer.get_param_groups_for_saving()
+            torch.save(param_groups, group_file_path)
 
         # States are broken into shards within max_shard_size.
         state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
 
         # Save shards of optimizer states.
-        is_master = self.coordinator.is_master()
         total_size = save_state_dict_shards(
             sharded_state_dict=state_dict_shard,
             checkpoint=checkpoint,
             index_file=index_file,
             base_filename=states_name,
-            is_master=is_master,
+            is_master=self.coordinator.is_master(),
             use_safetensors=False,
         )
 
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index dc78fe8c094c..09343138f5ff 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -119,11 +119,12 @@ def save_sharded_optimizer(
         # Preparing file paths and index file.
         states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
         index_file = CheckpointIndexFile(checkpoint)
+        index_file.append_meta_data("param_groups", param_group_file)
 
         # Store the information of param groups to param_group_file.
-        index_file.append_meta_data("param_groups", param_group_file)
-        group_file_path = os.path.join(checkpoint, param_group_file)
-        save_param_groups(state_dict, group_file_path)
+        if self.coordinator.is_master():
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            save_param_groups(state_dict, group_file_path)
 
         # Save shards of optimizer states.
         total_size = 0
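
For context, below is a minimal standalone sketch of the pattern this hotfix enforces: the shared param-group file is written by the master rank only, so every rank no longer writes (and potentially clobbers) the same file. This is not the plugins' actual code; is_master() is a stand-in for self.coordinator.is_master(), and the default file name is illustrative.

import os
import torch
import torch.distributed as dist


def is_master() -> bool:
    # Treat a non-distributed run as "master"; otherwise only rank 0 qualifies.
    return not dist.is_initialized() or dist.get_rank() == 0


def save_param_groups_once(
    optimizer: torch.optim.Optimizer,
    checkpoint_dir: str,
    param_group_file: str = "param_groups.bin",  # illustrative default name
) -> None:
    """Write the optimizer param-group file from the master rank only.

    Without the guard, every rank writes the same file, which is redundant
    and can race on a shared filesystem.
    """
    if is_master():
        os.makedirs(checkpoint_dir, exist_ok=True)
        group_file_path = os.path.join(checkpoint_dir, param_group_file)
        torch.save(optimizer.state_dict()["param_groups"], group_file_path)
    # Keep ranks in sync so nobody tries to read the file before it exists.
    if dist.is_initialized():
        dist.barrier()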