From 5449bea43c431b68bdb0ffa51b81828ddd1f108b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 10 Aug 2023 12:57:22 +0800 Subject: [PATCH] [gemini] fix tensor storage cleaning in state dict collection --- colossalai/zero/gemini/chunk/__init__.py | 7 +++++-- colossalai/zero/gemini/gemini_optimizer.py | 9 +++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/colossalai/zero/gemini/chunk/__init__.py b/colossalai/zero/gemini/chunk/__init__.py index 6914d2dbef45..6fa3fab4512f 100644 --- a/colossalai/zero/gemini/chunk/__init__.py +++ b/colossalai/zero/gemini/chunk/__init__.py @@ -1,6 +1,9 @@ -from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState +from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState, free_storage from .manager import ChunkManager from .search_utils import classify_params_by_dp_degree, search_chunk_configuration from .utils import init_chunk_manager -__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager'] +__all__ = [ + 'Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager', + 'free_storage' +] diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 7d0db6b1fa23..1c665318d7f5 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -17,7 +17,7 @@ from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.utils import disposable, get_current_device, is_ddp_ignored -from .chunk import Chunk, ChunkManager +from .chunk import Chunk, ChunkManager, free_storage from .gemini_ddp import ZeroDDP __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer'] @@ -467,11 +467,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: continue self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset, shard_size) - - # Clean gathered states - for state_shard in gathered_state_shards: - del state_shard[0] - gc.collect() + # Clean gathered states + free_storage(state_shard[0]) # Reshape tensors if is_collector: