Run garbage collection manually in train loop #509

Merged · 3 commits · Mar 20, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Rename `Olmo` to `OLMo` everywhere in the codebase
- Disabled automatic garbage collection during training; instead we run garbage collection manually at regular intervals to avoid ranks getting out of sync with their own GC.

### Removed

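In short, the PR turns off Python's automatic collector for the duration of training and triggers collections explicitly, so every rank pauses for GC at the same point instead of stalling at rank-specific times. A minimal sketch of that pattern, assuming an illustrative `train` loop and `gc_interval` (neither is the trainer's actual API):

```python
# Minimal sketch of manual, synchronized GC during training.
# `train_steps` and `gc_interval` are illustrative assumptions, not OLMo config names.
import gc


def train(train_steps: int, gc_interval: int = 100) -> None:
    gc_was_enabled = gc.isenabled()   # remember the caller's GC state
    gc.disable()                      # no automatic collections mid-step
    try:
        for step in range(train_steps):
            ...  # forward / backward / optimizer step
            if step % gc_interval == 0:
                gc.collect(1)         # explicit collection at a fixed, rank-aligned cadence
    finally:
        if gc_was_enabled:            # restore whatever the caller had
            gc.enable()
```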
12 changes: 9 additions & 3 deletions olmo/checkpoint.py
@@ -50,7 +50,13 @@
from .exceptions import OLMoCheckpointError
from .optim import Optimizer, fix_optim_state_dict
from .safetensors_util import safetensors_file_to_state_dict
from .torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size
from .torch_util import (
barrier,
gc_cuda,
get_fs_local_rank,
get_global_rank,
get_world_size,
)
from .util import (
_get_s3_client,
default_thread_count,
@@ -191,7 +197,7 @@ def load_fsdp_model_and_optim_state(
),
)
del model_state
torch.cuda.empty_cache()
gc_cuda()
load_fsdp_optim_state(fsdp_model, optim, optim_state["optim"])


@@ -212,7 +218,7 @@ def load_fsdp_optim_state(fsdp_model: FSDP, optim: Optimizer, optim_state: Dict[
v = state[k]
if isinstance(v, torch.Tensor):
state[k] = v.to(device="cpu")
torch.cuda.empty_cache()
gc_cuda()
optim.load_state_dict(fix_optim_state_dict(optim, flattened_osd))


7 changes: 7 additions & 0 deletions olmo/torch_util.py
@@ -1,3 +1,4 @@
import gc
import os
from typing import Optional, TypeVar

@@ -130,3 +131,9 @@ def synchronize_value(value: V, device: torch.device) -> V:

def synchronize_flag(flag: bool, device: torch.device) -> bool:
return synchronize_value(flag, device)


def gc_cuda():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
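
The new `gc_cuda()` helper pairs a full `gc.collect()` with `torch.cuda.empty_cache()`: collecting first drops tensors that are only kept alive by reference cycles, so the cache-emptying call can actually hand their CUDA blocks back to the driver. It replaces the bare `torch.cuda.empty_cache()` calls in the checkpoint-loading paths above. A usage sketch (the call site below is illustrative, not from this PR):

```python
# Illustrative call site for the new helper; `big` is a made-up variable.
import torch

from olmo.torch_util import gc_cuda

device = "cuda" if torch.cuda.is_available() else "cpu"
big = torch.zeros(4096, 4096, device=device)
del big      # drop the last Python reference first ...
gc_cuda()    # ... then collect and return cached CUDA blocks to the driver
```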
28 changes: 25 additions & 3 deletions olmo/train.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import cProfile
import gc
import logging
import math
import os
@@ -38,6 +39,7 @@
from .optim import Optimizer, Scheduler
from .torch_util import (
barrier,
gc_cuda,
get_fs_local_rank,
get_global_rank,
get_world_size,
@@ -136,6 +138,7 @@ class Trainer:
cur_train_loss: float = float("inf")
indices_file: Optional[TextIO] = None
_start_time: float = 0.0
_gc_init_state: bool = True
loss_fn: Callable[..., torch.Tensor] = field(default_factory=lambda: cross_entropy_loss) # type: ignore
last_sharded_checkpoint_step: Optional[int] = None
last_unsharded_checkpoint_step: Optional[int] = None
@@ -537,15 +540,19 @@ def restore_unsharded_checkpoint(
def save_checkpoint(
self, checkpoint_type: CheckpointType = CheckpointType.sharded
) -> Tuple[PathOrStr, Optional[PathOrStr]]:
result: Tuple[PathOrStr, Optional[PathOrStr]]
if checkpoint_type == CheckpointType.sharded:
return self.save_sharded_checkpoint()
result = self.save_sharded_checkpoint()
elif checkpoint_type == CheckpointType.unsharded:
return self.save_unsharded_checkpoint()
result = self.save_unsharded_checkpoint()
elif checkpoint_type == CheckpointType.sharded_ephemeral:
return self.save_ephemeral_checkpoint()
result = self.save_ephemeral_checkpoint()
else:
raise NotImplementedError(checkpoint_type)

gc_cuda()
return result

def restore_checkpoint(
self,
load_path: PathOrStr,
@@ -576,6 +583,8 @@ def restore_checkpoint(
elif checkpoint_type is not None:
raise NotImplementedError(checkpoint_type)

gc_cuda()

def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = CheckpointType.sharded):
if checkpoint_type == CheckpointType.sharded:
self.remove_sharded_checkpoint(idx=idx)
@@ -936,6 +945,10 @@ def fit(self):
self.cfg.stop_at = min(self.cfg.stop_at, self.global_step + self.cfg.stop_after)

self._start_time = time.time()
self._gc_init_state = gc.isenabled() # cache if garbage collection is enabled, reset on close.

# Disable automatic garbage collection, FSDP doesn't work well with it.
gc.disable()

if self.cfg.load_path is not None and self.global_step > 0 and self.cfg.eval_on_load:
eval_metrics = self.eval()
@@ -1141,6 +1154,9 @@ def on_trace_ready(p):
if stop_at is not None and self.global_step >= stop_at:
break

# Run generation 1 garbage collection.
gc.collect(1)

# Python Profiler stuff
# We do this now, at the bottom of this loop, so we capture the work of getting the next batch.
if python_profiler is not None:
@@ -1178,9 +1194,15 @@ def on_trace_ready(p):
log.info(f"Checkpoint saved to {checkpoint_path}")

def close(self, exit_code: int = 0) -> None:
gc_cuda()

if self.indices_file is not None:
self.indices_file.flush()
self.indices_file.close()
if self._gc_init_state:
gc.enable()
else:
gc.disable()
if wandb.run is not None:
wandb.finish(exit_code=exit_code, quiet=True)

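In the train loop the hook is `gc.collect(1)`, which scans only generations 0 and 1; the expensive full (generation-2) pass is left to `gc_cuda()` around checkpoint save/restore. A standalone demo of the difference, not trainer code:

```python
# Standalone demonstration of generation-targeted collection (not trainer code).
import gc


class Node:
    """Participates in a reference cycle, so only the cycle collector can free it."""

    def __init__(self) -> None:
        self.ref = self


gc.disable()
for _ in range(10_000):
    Node()                                 # each instance is an unreachable cycle
print(gc.get_count())                      # (gen0, gen1, gen2) counters; values vary
print("gen 0-1 pass:", gc.collect(1))      # cheap pass, like the per-iteration call in fit()
print("full pass:", gc.collect())          # generation-2 pass, like gc_cuda() at checkpoints
gc.enable()
```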