Commit

Merge branch 'version-dolma-flan-change' of github.com:allenai/OLMo into version-dolma-flan-change
IanMagnusson committed Aug 21, 2024
2 parents bd29493 + b189ca6 commit bd319be
Showing 5 changed files with 271 additions and 26 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

- Fixed conversion to HuggingFace model for DDP-trained models.
- Added support for remote source and destination for HuggingFace model conversion.

### Added

- Added support for document masking via flash-attn during training with `--data.generate_doc_lengths`.
4 changes: 4 additions & 0 deletions README.md
@@ -206,6 +206,10 @@ Note: passing CLI overrides like `--reset_trainer_state` is only necessary if yo

Additional tools for evaluating OLMo models are available at the [OLMo Eval](https://github.com/allenai/ai2-olmo-eval) repo.

## Debugging

See [Debugging](https://github.com/allenai/OLMo/blob/main/docs/NOTES.md#debugging).

## Citing

```bibtex
22 changes: 22 additions & 0 deletions docs/NOTES.md
@@ -102,3 +102,25 @@ outputs = model.generate(input_tensor, max_steps=3, beam_size=3)
best_generation = outputs.token_ids[0][0].tolist()
print(tokenizer.decode(best_generation))
```

## Debugging

### Finding the cause of hangs

Hangs in distributed training can have several different causes, including
bugs in user code, AMD/NVIDIA memory-allocation issues, or problems with the hardware setup.
These issues can be difficult to root-cause and even harder to fix.

One approach we use to find the cause of a hang in distributed training is to first identify which processes/nodes are hanging. The [scripts/pyspy_all_processes.sh](https://github.com/allenai/OLMo/blob/main/scripts/pyspy_all_processes.sh) script retrieves the Python state of the relevant python processes using `py-spy`. A process/node whose state differs from the others is likely the one experiencing the hang.
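
The sketch below illustrates the same idea in a few lines of Python. It is not the actual script: the use of `pgrep` to discover candidate processes is just one possible setup, while `py-spy dump --pid <pid>` is the standard py-spy command for printing a process's current Python stack.

```python
# Minimal sketch only -- not the actual scripts/pyspy_all_processes.sh.
# Assumes a POSIX system with `pgrep` available and `py-spy` installed.
import subprocess


def dump_python_stacks() -> None:
    # Find PIDs of running python processes (assumption: their command lines mention "python").
    result = subprocess.run(["pgrep", "-f", "python"], capture_output=True, text=True, check=False)
    for pid in result.stdout.split():
        print(f"=== py-spy dump for PID {pid} ===")
        # `py-spy dump --pid <pid>` prints the current Python stack of that process;
        # a process whose stack differs from its peers is a hang suspect.
        subprocess.run(["py-spy", "dump", "--pid", pid], check=False)


if __name__ == "__main__":
    dump_python_stacks()
```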

If a hang is suspected to be in GPU code, you can run `gcore <pid>` on a hanging process to capture a core dump, then open it with `gdb <corefile>` and check where the code is hanging from the C++ side (e.g., with `bt` or `thread apply all bt`). Code stuck in a GPU memory allocation (malloc) may be indicative of a hardware/setup issue rather than a problem in the training code.

### Comparing two models that should be identical

There are scenarios in which one might want to investigate why two models/setups that should be identical are yielding different results. A naive solution is to run both setups side-by-side and compare their results manually, which may not even be possible if you have just 1 GPU.

An alternative for comparing OLMo models is to run the training of both models with the `--module_outputs_save_steps=[<list of steps>]` config option. This causes OLMo to save a portion of the inputs & outputs of each OLMo submodule into a `traces/` folder at the step's save location. [scripts/compare_module_outputs.py](https://github.com/allenai/OLMo/blob/main/scripts/compare_module_outputs.py) can then be used to compare these inputs & outputs, hopefully isolating the issue to a subset of model modules; see the script for more details on its usage.
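
As a rough illustration of what such a comparison does (assuming, for simplicity, one tensor per trace file; the real `traces/` layout and the actual script are more involved):

```python
# Rough sketch of the trace comparison only; the real traces/ layout and
# scripts/compare_module_outputs.py may differ. Assumes one tensor per .pt file.
from pathlib import Path

import torch


def compare_traces(trace_dir_1: str, trace_dir_2: str) -> None:
    for file_1 in sorted(Path(trace_dir_1).glob("*.pt")):
        file_2 = Path(trace_dir_2) / file_1.name
        if not file_2.exists():
            print(f"{file_1.name}: missing in {trace_dir_2}")
            continue
        t1 = torch.load(file_1, map_location="cpu")
        t2 = torch.load(file_2, map_location="cpu")
        if t1.shape != t2.shape:
            print(f"{file_1.name}: shape mismatch {tuple(t1.shape)} vs {tuple(t2.shape)}")
            continue
        # Count elements that are not bitwise identical between the two runs.
        mismatched = (t1 != t2).sum().item()
        if mismatched:
            print(f"{file_1.name}: {mismatched} mis-matching elements")


# e.g. compare_traces("test_model/traces/step10", "test_model_2/traces/step10")
```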

When comparing different hardware or dependency setups, it is possible that model
state gets corrupted before the first forward pass of training. One can check for this
by running training with `--force_save_unsharded --dry_run --load_path=<original_model_path>`, which saves a checkpoint after the original model has loaded but before training starts. [scripts/compare_model_state.py](https://github.com/allenai/OLMo/blob/main/scripts/compare_model_state.py) can then be used to see whether any parameters differ between the two models.
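
A minimal sketch of that parameter comparison, assuming both checkpoints have been unsharded to `model.pt` state dicts (the real script may differ):

```python
# Sketch of the parameter comparison only; see scripts/compare_model_state.py
# for the real tool. Assumes two unsharded model.pt state dicts.
import torch


def compare_state_dicts(path_1: str, path_2: str) -> None:
    sd1 = torch.load(path_1, map_location="cpu")
    sd2 = torch.load(path_2, map_location="cpu")
    for key in sorted(set(sd1) | set(sd2)):
        if key not in sd1 or key not in sd2:
            print(f"{key}: present in only one checkpoint")
        elif sd1[key].shape != sd2[key].shape:
            print(f"{key}: shape mismatch {tuple(sd1[key].shape)} vs {tuple(sd2[key].shape)}")
        elif not torch.equal(sd1[key], sd2[key]):
            # Report the largest absolute difference for parameters that differ.
            max_diff = (sd1[key] - sd2[key]).abs().max().item()
            print(f"{key}: values differ (max abs diff {max_diff:.3e})")
```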
240 changes: 214 additions & 26 deletions hf_olmo/convert_olmo_to_hf.py
@@ -1,18 +1,52 @@
import argparse
import logging
import os
import re
import shutil
import tempfile
from hashlib import md5
from typing import Iterable, Optional
from urllib.parse import urlparse

import torch
from omegaconf import OmegaConf as om
from tqdm import tqdm

from hf_olmo.configuration_olmo import OLMoConfig
from hf_olmo.modeling_olmo import OLMoForCausalLM
from hf_olmo.tokenization_olmo_fast import OLMoTokenizerFast
from olmo import ModelConfig, Tokenizer
from olmo import ModelConfig, Tokenizer, TrainConfig
from olmo.checkpoint import build_sharded_checkpointer
from olmo.util import _get_s3_client

logger = logging.getLogger(__name__)

HF_FILENAMES = {
    "config.json",
    "pytorch_model.bin",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "tokenizer.json",
}


def longest_common_prefix(strs: Iterable[str]) -> str:
    """
    Finds the longest common prefix among a list of strings.
    """
    if not strs:
        return ""

    # Find the shortest string in the list
    shortest_str = min(strs, key=len)

    for i, char in enumerate(shortest_str):
        for other_str in strs:
            if other_str[i] != char:
                return shortest_str[:i]

    return shortest_str


def write_config(checkpoint_dir: str):
    # save config as HF config
@@ -36,8 +70,15 @@ def write_model(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
    old_model_path = os.path.join(checkpoint_dir, "model.pt")
    new_model_path = os.path.join(checkpoint_dir, "pytorch_model.bin")

    state_dict = torch.load(old_model_path)
    new_state_dict = {f"{OLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()}
    state_dict = torch.load(old_model_path, map_location="cpu")

    # this takes care of the case where the model was saved with a different prefix,
    # typically due to unsharding.
    common_prefix = longest_common_prefix(state_dict.keys())
    new_state_dict = {
        key.replace(common_prefix, f"{OLMoForCausalLM.base_model_prefix}.transformer."): val
        for key, val in state_dict.items()
    }
    torch.save(new_state_dict, new_model_path)

    if ignore_olmo_compatibility:
@@ -60,46 +101,150 @@ def write_tokenizer(checkpoint_dir: str):


def convert_checkpoint(checkpoint_dir: str, ignore_olmo_compatibility: bool = False):
print("Converting checkpoint to HF format...")
write_config(checkpoint_dir)

print("Saving model to checkpoint...")
write_model(checkpoint_dir, ignore_olmo_compatibility=ignore_olmo_compatibility)

print("Saving tokenizer to checkpoint...")
write_tokenizer(checkpoint_dir)

# Cannot remove it before writing the tokenizer
if ignore_olmo_compatibility:
os.remove(os.path.join(checkpoint_dir, "config.yaml"))


def download_remote_checkpoint_and_convert_to_hf(checkpoint_dir: str, local_dir: str):
    from cached_path import cached_path
def fix_tokenizer(checkpoint_dir: str, tokenizer_name_or_path: Optional[str] = None):
    path = os.path.join(checkpoint_dir, "config.yaml")
    conf = om.load(path)

    model_name = os.path.basename(checkpoint_dir)
    local_model_path = os.path.join(local_dir, model_name)
    os.makedirs(local_model_path, exist_ok=True)
    print("Saving tokenizer to checkpoint...")

    model_files = ["model.pt", "config.yaml"]  # , "optim.pt", "other.pt"]
    for filename in model_files:
        final_location = os.path.join(local_model_path, filename)
        if not os.path.exists(final_location):
            remote_file = os.path.join(checkpoint_dir, filename)
            logger.debug(f"Downloading file {filename}")
            cached_file = cached_path(remote_file)
            shutil.copy(cached_file, final_location)
            logger.debug(f"File at {final_location}")
    tokenizer_name_or_path = str(tokenizer_name_or_path or conf["tokenizer"]["identifier"])  # pyright: ignore

    try:
        if os.path.exists(tokenizer_name_or_path):
            Tokenizer.from_file(tokenizer_name_or_path)
        else:
            logger.info(f"File already present at {final_location}")
            Tokenizer.from_pretrained(tokenizer_name_or_path)
    except Exception as e:
        # the tokenizer is not valid
        logger.error(f"Invalid tokenizer: {tokenizer_name_or_path}. Error: {e}")
        raise e

    convert_checkpoint(local_model_path)
    return local_model_path
    conf["tokenizer"]["identifier"] = tokenizer_name_or_path  # pyright: ignore

    if tokenizer_name_or_path == "allenai/gpt-neox-olmo-dolma-v1_5" or tokenizer_name_or_path.endswith(
        "allenai_eleuther-ai-gpt-neox-20b-pii-special.json"
    ):
        conf["model"]["eos_token_id"] = 50279  # pyright: ignore

def fix_bad_tokenizer(checkpoint_dir: str):
    path = os.path.join(checkpoint_dir, "config.yaml")
    conf = om.load(path)
    conf["tokenizer"]["identifier"] = "allenai/gpt-neox-olmo-dolma-v1_5"
    conf["model"]["eos_token_id"] = 50279
    om.save(conf, path)


def download_s3_directory(bucket_name: str, prefix: str, local_dir: str, ignore: str | None = None):
    # Create S3 client
    s3_client = _get_s3_client("s3")

    re_ignore = re.compile(ignore) if ignore else None

    # List objects within the given prefix
    paginator = s3_client.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    # Create a list to hold all the files to download
    files_to_download = []
    for page in pages:
        for obj in page.get("Contents", []):
            if re_ignore and re_ignore.search(obj["Key"]):
                continue
            files_to_download.append(obj["Key"])

    # Initialize the progress bar
    for s3_key in tqdm(files_to_download, desc="Downloading files"):
        # Construct the full local path
        local_file_path = os.path.join(local_dir, os.path.relpath(s3_key, prefix))
        local_file_dir = os.path.dirname(local_file_path)

        # Ensure local directory exists
        if not os.path.exists(local_file_dir):
            os.makedirs(local_file_dir)

        # Download the file
        s3_client.download_file(bucket_name, s3_key, local_file_path)


def make_local_checkpoint(checkpoint_dir: str) -> str:
    parsed_dir = urlparse(checkpoint_dir)

    assert parsed_dir.scheme in ["s3", ""], "Only s3 and local paths are supported."

    if os.path.exists(checkpoint_dir):
        return checkpoint_dir

    temp_dir = os.path.join(tempfile.gettempdir(), md5(checkpoint_dir.encode()).hexdigest())
    if os.path.exists(temp_dir):
        return temp_dir
    try:
        os.makedirs(temp_dir, exist_ok=True)
        print(f"Downloading checkpoint to {temp_dir}...")
        download_s3_directory(
            bucket_name=parsed_dir.netloc,
            prefix=parsed_dir.path.lstrip("/"),
            local_dir=temp_dir,
            ignore=r"/(optim|train)/",
        )
    except Exception as e:
        logger.error(f"Error downloading checkpoint: {e}")
        shutil.rmtree(temp_dir)
        raise e

    return temp_dir


def upload_local_checkpoint(local_checkpoint_dir: str, destination_dir: str):
    if destination_dir == local_checkpoint_dir:
        return
    elif (parsed_url := urlparse(destination_dir)).scheme == "s3":
        s3_bucket_name = parsed_url.netloc
        s3_prefix = parsed_url.path[1:]

        local_paths = [
            os.path.join(root, post_fn)
            for root, _, files in os.walk(local_checkpoint_dir)
            for post_fn in files
            if os.path.basename(post_fn) in HF_FILENAMES
        ]
        dest_paths = [
            os.path.join(s3_prefix, os.path.relpath(local_path, local_checkpoint_dir))
            for local_path in local_paths
        ]

        s3_client = _get_s3_client("s3")
        for local_path, dest_path in tqdm(
            zip(local_paths, dest_paths), desc="Uploading files", total=len(local_paths)
        ):
            s3_client.upload_file(local_path, s3_bucket_name, dest_path)
    elif parsed_url.scheme == "":
        shutil.copytree(local_checkpoint_dir, destination_dir)
    else:
        raise ValueError(f"Unsupported destination: {destination_dir}. Only s3 and local paths are supported.")


def maybe_unshard(checkpoint_dir: str):
    if os.path.exists(os.path.join(checkpoint_dir, "model.pt")):
        return

    print(f"Unsharding {checkpoint_dir}...")
    train_config = TrainConfig.load(os.path.join(checkpoint_dir, "config.yaml"))
    checkpointer = build_sharded_checkpointer(train_config)
    model_state, _, _ = checkpointer.unshard_checkpoint(
        load_path=checkpoint_dir, load_optimizer_state=False, load_trainer_state=False
    )
    torch.save(model_state, os.path.join(checkpoint_dir, "model.pt"))


def main():
    parser = argparse.ArgumentParser(
        description="Adds a config.json to the checkpoint directory, and creates pytorch_model.bin, "
@@ -108,6 +253,13 @@ def main():
    parser.add_argument(
        "--checkpoint-dir",
        help="Location of OLMo checkpoint.",
        required=True,
    )

    parser.add_argument(
        "--destination-dir",
        help="Location to save the converted checkpoint; default is the same as the checkpoint-dir.",
        default=None,
    )

    parser.add_argument(
@@ -116,11 +268,47 @@ def main():
help="Ignore compatibility with the olmo codebase. "
"This will remove files that are needed specifically for olmo codebase, eg. config.yaml, etc.",
)
parser.add_argument(
"--logger-level",
default="warning",
help="Set the logger level.",
)

parser.add_argument(
"--tokenizer",
help="Override the tokenizer to use for the checkpoint.",
)
parser.add_argument(
"--keep-olmo-artifacts",
action="store_true",
help="Keep olmo-specific artifacts in the checkpoint.",
)

args = parser.parse_args()
fix_bad_tokenizer(args.checkpoint_dir)

args.destination_dir = args.destination_dir or args.checkpoint_dir
logging.basicConfig()
logger.setLevel(logging.getLevelName(args.logger_level.upper()))

local_checkpoint_dir = make_local_checkpoint(args.checkpoint_dir)
args.checkpoint_dir = local_checkpoint_dir
maybe_unshard(local_checkpoint_dir)

fix_tokenizer(checkpoint_dir=local_checkpoint_dir, tokenizer_name_or_path=args.tokenizer)
convert_checkpoint(args.checkpoint_dir, args.ignore_olmo_compatibility)

if not args.keep_olmo_artifacts:
print("Removing non-HF artifacts...")
os.remove(os.path.join(local_checkpoint_dir, "config.yaml"))
os.remove(os.path.join(local_checkpoint_dir, "model.pt"))
shutil.rmtree(os.path.join(local_checkpoint_dir, "optim"), ignore_errors=True)
shutil.rmtree(os.path.join(local_checkpoint_dir, "model"), ignore_errors=True)
shutil.rmtree(os.path.join(local_checkpoint_dir, "train"), ignore_errors=True)

upload_local_checkpoint(local_checkpoint_dir, args.destination_dir)

print(f"Converted checkpoint saved to {args.destination_dir}")


if __name__ == "__main__":
main()
28 changes: 28 additions & 0 deletions scripts/compare_module_outputs.py
@@ -1,3 +1,31 @@
"""
Script for comparing collected outputs of OLMo submodules from 2
different training run steps (of the same or different runs).
This script is useful for identifying where model activations start to differ
within 2 forward passes that should yield identical results. In turn, detecting
regressions can be a lot quicker/easier.
This script requires that traces containing submodule outputs have been collected
during training. The traces can be saved using
`--module_outputs_save_steps=[<list of step>]`. Be mindful that the saving takes
a lot of storage and is very slow, so collect traces sparingly. If comparing 2
training runs starting from the same checkpoint, a viable approach is to collect
the 2 steps after training resumes. The first step can be used to detect issues
in the forward pass, while if only the second step shows discrepancies then the
backward pass may be the cause of any issues.
Example usage (Aug 2024):
```
python scripts/compare_module_outputs.py test_model/traces/step10 test_model_2/traces/step10
```
If this model produces no output stating diffs (without `--verbose`), then the
outputs between the 2 models are identical. If `mis-matching wte elements: ...`
shows a non-zero value, then the input data of the 2 forward passes being compared
is likely different.
"""

import logging
from argparse import ArgumentParser
from pathlib import Path
