Skip to content

Commit

Permalink
[gemini] fix tensor storage cleaning in state dict collection
Browse files Browse the repository at this point in the history
  • Loading branch information
Fridge003 committed Aug 10, 2023
1 parent 458ae33 commit 5449bea
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
7 changes: 5 additions & 2 deletions colossalai/zero/gemini/chunk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState, free_storage
from .manager import ChunkManager
from .search_utils import classify_params_by_dp_degree, search_chunk_configuration
from .utils import init_chunk_manager

__all__ = ['Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager']
__all__ = [
'Chunk', 'ChunkManager', 'classify_params_by_dp_degree', 'search_chunk_configuration', 'init_chunk_manager',
'free_storage'
]
9 changes: 3 additions & 6 deletions colossalai/zero/gemini/gemini_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.utils import disposable, get_current_device, is_ddp_ignored

from .chunk import Chunk, ChunkManager
from .chunk import Chunk, ChunkManager, free_storage
from .gemini_ddp import ZeroDDP

__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
Expand Down Expand Up @@ -467,11 +467,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
continue
self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
shard_size)

# Clean gathered states
for state_shard in gathered_state_shards:
del state_shard[0]
gc.collect()
# Clean gathered states
free_storage(state_shard[0])

# Reshape tensors
if is_collector:
Expand Down

0 comments on commit 5449bea

Please sign in to comment.