Skip to content

Commit

Permalink
Merge pull request #400 from allenai/shanea/storage-cleaner-wandb
Browse files Browse the repository at this point in the history
[Storage cleaner] Add wandb path implementation
  • Loading branch information
2015aroras authored Dec 15, 2023
2 parents 5bdccc3 + b7a3c66 commit 685d11b
Showing 1 changed file with 109 additions and 4 deletions.
113 changes: 109 additions & 4 deletions scripts/storage_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
import botocore.exceptions as boto_exceptions
import google.cloud.storage as gcs
import torch
import wandb
from cached_path import add_scheme_client, cached_path, set_cache_dir
from cached_path.schemes import S3Client
from google.api_core.exceptions import NotFound
from omegaconf import OmegaConf as om
from rich.progress import Progress, TaskID, track

from olmo import util
Expand Down Expand Up @@ -946,8 +948,108 @@ def unshard_run_checkpoints(run_path: str, checkpoints_dest_dir: str, config: Un
_unshard_checkpoints(storage, run_dir_or_archive, checkpoints_dest_dir, config)


def _get_wandb_runs_from_wandb_dir(storage: StorageAdapter, wandb_dir: str, run_config: TrainConfig) -> List:
# For some reason, we often have a redundant nested wandb directory. Step into it here.
nested_wandb_dir = os.path.join(wandb_dir, "wandb/")
if storage.is_dir(nested_wandb_dir):
wandb_dir = nested_wandb_dir

# Wandb run directory names are stored in format <run>-<timestamp>-<id>
# https://docs.wandb.ai/guides/track/save-restore#examples-of-wandbsave
dir_names = storage.list_dirs(wandb_dir)
wandb_run_dir_names = [dir_name for dir_name in dir_names if dir_name.startswith("run")]
if len(wandb_run_dir_names) == 0:
log.warning("No wandb run directories found in wandb dir %s", wandb_dir)
return []

wandb_ids = [dir_name.split("-")[2] for dir_name in wandb_run_dir_names if dir_name.count("-") >= 2]

log.debug("Wandb ids: %s", wandb_ids)

assert run_config.wandb is not None
api: wandb.Api = wandb.Api()
return [api.run(path=f"{run_config.wandb.entity}/{run_config.wandb.project}/{id}") for id in wandb_ids]


def _get_wandb_path_from_run(wandb_run) -> str:
return "/".join(wandb_run.path)


def _get_wandb_runs_from_train_config(config: TrainConfig) -> List:
assert config.wandb is not None

run_filters = {
"display_name": config.wandb.name,
}
if config.wandb.group is not None:
run_filters["group"] = config.wandb.group

log.debug("Wandb entity/project: %s/%s", config.wandb.entity, config.wandb.project)
log.debug("Wandb filters: %s", run_filters)

api = wandb.Api()
return api.runs(path=f"{config.wandb.entity}/{config.wandb.project}", filters=run_filters)


def _are_equal_configs(wandb_config: TrainConfig, train_config: TrainConfig) -> bool:
return wandb_config.asdict(exclude=["wandb"]) == train_config.asdict(exclude=["wandb"])


def _get_wandb_config(wandb_run) -> TrainConfig:
local_storage = LocalFileSystemAdapter()
temp_file = local_storage.create_temp_file(suffix=".yaml")

om.save(config=wandb_run.config, f=temp_file)
wandb_config = TrainConfig.load(temp_file)

return wandb_config


def _get_matching_wandb_runs(wandb_runs, training_run_dir: str) -> List:
config_path = os.path.join(training_run_dir, CONFIG_YAML)
local_config_path = cached_path(config_path)
train_config = TrainConfig.load(local_config_path)

return [
wandb_run for wandb_run in wandb_runs if _are_equal_configs(_get_wandb_config(wandb_run), train_config)
]


def _get_wandb_path(run_dir: str) -> str:
raise NotImplementedError()
run_dir_storage = _get_storage_adapter_for_path(run_dir)

config_path = os.path.join(run_dir, CONFIG_YAML)
if not run_dir_storage.is_file(config_path):
raise FileNotFoundError("No config file found in run dir, cannot get wandb path")

local_config_path = cached_path(config_path)
config = TrainConfig.load(local_config_path, validate_paths=False)

if config.wandb is None or config.wandb.entity is None or config.wandb.project is None:
raise ValueError(f"Run at {run_dir} has missing wandb config, cannot get wandb run path")

wandb_runs = []

wandb_dir = os.path.join(run_dir, "wandb/")
if run_dir_storage.is_dir(wandb_dir):
wandb_runs += _get_wandb_runs_from_wandb_dir(run_dir_storage, wandb_dir, config)

wandb_runs += _get_wandb_runs_from_train_config(config)

# Remove duplicate wandb runs based on run path, and wandb runs that do not match our run.
wandb_runs = list({_get_wandb_path_from_run(wandb_run): wandb_run for wandb_run in wandb_runs}.values())
wandb_matching_runs = _get_matching_wandb_runs(wandb_runs, run_dir)

if len(wandb_matching_runs) == 0:
raise RuntimeError(f"Failed to find any wandb runs for {run_dir}. Run might no longer exist")

if len(wandb_matching_runs) > 1:
wandb_run_urls = [wandb_run.url for wandb_run in wandb_matching_runs]
raise RuntimeError(
f"Found {len(wandb_matching_runs)} runs matching run dir {run_dir}, cannot determine correct run: {wandb_run_urls}"
)

return _get_wandb_path_from_run(wandb_matching_runs[0])


def _append_wandb_path(
Expand All @@ -961,9 +1063,11 @@ def _append_wandb_path(

if _is_archive(run_dir_or_archive, run_dir_or_archive_storage) and append_archive_extension:
archive_extension = "".join(Path(run_dir_or_archive).suffixes)
wandb_path = wandb_path + archive_extension
relative_wandb_path = wandb_path + archive_extension
else:
relative_wandb_path = wandb_path + "/"

return os.path.join(base_dir, wandb_path)
return os.path.join(base_dir, relative_wandb_path)


def _copy(src_path: str, dest_path: str, temp_dir: str):
Expand Down Expand Up @@ -1045,7 +1149,7 @@ def _move_run(src_storage: StorageAdapter, run_dir_or_archive: str, dest_dir: st

src_move_path, dest_move_path = _get_src_and_dest_for_copy(src_storage, run_dir_or_archive, dest_dir, config)

if src_move_path == dest_move_path:
if src_move_path.rstrip("/") == dest_move_path.rstrip("/"):
# This could be a valid scenario if the user is, for example, trying to
# append wandb path to runs and this run has the right wandb path already.
log.info("Source and destination move paths are both %s, skipping", src_move_path)
Expand All @@ -1068,6 +1172,7 @@ def _move_run(src_storage: StorageAdapter, run_dir_or_archive: str, dest_dir: st
def move_run(run_path: str, dest_dir: str, config: MoveRunConfig):
storage = _get_storage_adapter_for_path(run_path)
run_dir_or_archive = _format_dir_or_archive_path(storage, run_path)
dest_dir = f"{dest_dir}/" if not dest_dir.endswith("/") else dest_dir
_move_run(storage, run_dir_or_archive, dest_dir, config)


Expand Down

0 comments on commit 685d11b

Please sign in to comment.