Merge pull request #690 from allenai/shanea/trace-model-outputs-2
Add ability to save and compare sub-module outputs
2015aroras authored Aug 12, 2024
2 parents 4332c32 + a342777 commit 0bc7f6c
Showing 4 changed files with 204 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `model.embedding_layer_norm` configuration option for adding a LN to the embeddings.
- Added `model.emb_init_std` configuration option to override the standard deviation used to initialize the embeddings.
- Added `CosLinearEnvelope` scheduler, which is a pointwise product of a cosine schedule and a linear decay.
- Added ability to save outputs of submodules for debugging purposes.

### Changed

6 changes: 6 additions & 0 deletions olmo/config.py
@@ -1218,6 +1218,12 @@ class TrainConfig(BaseConfig):
Path to cache directory of HF datasets saved with `datasets.save_to_disk`.
"""

module_outputs_save_steps: Optional[List[int]] = None
"""
Outputs of model submodules are saved under `<save_folder>/traces/step<global_step>` during the provided steps.
Submodule outputs can be compared using `scripts/compare_module_outputs.py`.
"""

@property
def autocast_precision(self) -> torch.dtype:
if self.precision == "amp_bf16":
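As a usage illustration (not part of the diff): a minimal sketch of turning the new option on, assuming the usual `TrainConfig.load` entry point; the config path and step numbers are placeholders.

from olmo.config import TrainConfig

# Hypothetical config path; real runs load their own YAML.
cfg = TrainConfig.load("configs/official/OLMo-1B.yaml")
cfg.module_outputs_save_steps = [1, 10, 100]  # save sub-module outputs at these training steps
cfg.save_overwrite = True  # allow re-tracing into an existing save_folder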
58 changes: 58 additions & 0 deletions olmo/train.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import cProfile
import functools
import gc
import logging
import math
@@ -20,6 +21,8 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils
import torch.utils.hooks
import wandb
from packaging import version
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -650,6 +653,53 @@ def remove_checkpoint(self, idx: int = 0, checkpoint_type: CheckpointType = Chec
else:
raise NotImplementedError(checkpoint_type)

def _setup_module_output_save_hooks(self, micro_batch_idx: int) -> List[torch.utils.hooks.RemovableHandle]:
if (
self.cfg.module_outputs_save_steps is None
or self.global_step not in self.cfg.module_outputs_save_steps
):
return []

if micro_batch_idx != 0 or get_global_rank() != 0:
# Hook is currently only used on the first microbatch of rank 0
return []

trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
if trace_save_folder.exists():
if self.cfg.save_overwrite:
shutil.rmtree(trace_save_folder)
else:
raise OLMoConfigurationError(
f"Attempting to overwrite traces at step {self.global_step} without --save_overwrite"
)
trace_save_folder.mkdir(parents=True)

def trace_outputs_hook(
module_name: str, _: torch.nn.Module, args: Tuple[torch.Tensor, ...], output: torch.Tensor
) -> None:
if len(args) == 0:
log.info("No input args for module %s, output %s", module_name, output)

module_input = args[0] if len(args) > 0 else torch.tensor(())
trace_save_folder = Path(self.cfg.save_folder) / f"traces/step{self.global_step}"
trace_save_folder.mkdir(parents=True, exist_ok=True)

module_occurence_num = 0
while (
module_input_filepath := trace_save_folder / f"{module_name}_{module_occurence_num}_input.pt"
).exists():
module_occurence_num += 1
torch.save(module_input, module_input_filepath)

module_output_filepath = trace_save_folder / f"{module_name}_{module_occurence_num}_output.pt"
torch.save(output, module_output_filepath)

output_hooks = []
for module_name, module in self.model.named_modules(prefix="model"):
output_hooks.append(module.register_forward_hook(functools.partial(trace_outputs_hook, module_name)))

return output_hooks

def get_labels(self, batch: Dict[str, Any]) -> torch.Tensor:
# Labels are just input IDs shifted to the left (first item is ignored).
labels, label_mask, attention_mask, instance_mask = (
@@ -740,6 +790,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
if micro_batch_idx != num_micro_batches - 1:
grad_sync_context = self.dist_model.no_sync

# Register output hooks
output_hooks: List[torch.utils.hooks.RemovableHandle] = []
output_hooks += self._setup_module_output_save_hooks(micro_batch_idx)

with grad_sync_context():
with torch.autocast("cuda", enabled=True, dtype=self.cfg.autocast_precision):
# Run forward pass.
@@ -756,6 +810,10 @@ def train_batch(self, batch: Dict[str, Any]) -> Tuple[torch.Tensor, Optional[tor
# Run backward pass.
loss.backward()

# Remove output hooks
for hook in output_hooks:
hook.remove()

return ce_batch_loss, z_batch_loss

def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> Dict[str, float]:
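For context, a small self-contained sketch (not part of the diff) of the forward-hook pattern used above: each hook is bound to its module name with `functools.partial`, captures the module's output during a forward pass, and is removed once the traced step finishes. The toy model and `captured` dict are illustrative only.

import functools

import torch

captured = {}

def trace_hook(name, module, args, output):
    # Forward hooks receive (module, inputs, output); the module name is pre-bound via partial.
    captured[name] = output.detach()

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
handles = [
    m.register_forward_hook(functools.partial(trace_hook, name))
    for name, m in model.named_modules(prefix="model")
]

model(torch.randn(3, 4))  # populates `captured` with one tensor per module
for handle in handles:
    handle.remove()  # detach the hooks once the traced step is done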
139 changes: 139 additions & 0 deletions scripts/compare_module_outputs.py
@@ -0,0 +1,139 @@
import logging
from argparse import ArgumentParser
from pathlib import Path
from typing import List

import torch

logger = logging.getLogger(__name__)


def _get_module_names(checkpoint_traces_folder: Path) -> List[str]:
module_names = []
for trace_file in checkpoint_traces_folder.iterdir():
trace_file_name = trace_file.name
if trace_file_name.endswith("_input.pt"):
module_name = trace_file_name.removesuffix("_input.pt")
elif trace_file_name.endswith("_output.pt"):
module_name = trace_file_name.removesuffix("_output.pt")
else:
    logger.warning("Cannot get module name from file %s, skipping", trace_file_name)
    continue

module_names.append(module_name)

return module_names


def compare_module_output(
base_traces_folder: Path,
compare_traces_folder: Path,
module_name: str,
*,
include_non_tensor_outputs: bool = True,
verbose: bool = False,
):
base_module_input_path = base_traces_folder / f"{module_name}_input.pt"
base_module_output_path = base_traces_folder / f"{module_name}_output.pt"
compare_module_input_path = compare_traces_folder / f"{module_name}_input.pt"
compare_module_output_path = compare_traces_folder / f"{module_name}_output.pt"

map_location = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
base_input = torch.load(str(base_module_input_path), map_location=map_location)
compare_input = torch.load(str(compare_module_input_path), map_location=map_location)

if verbose or base_input.dtype != compare_input.dtype:
logger.info("%s input dtypes: %s %s", module_name, base_input.dtype, compare_input.dtype)
if verbose or base_input.shape != compare_input.shape:
logger.info("%s input shapes: %s %s", module_name, base_input.shape, compare_input.shape)
if (norm_diff := torch.linalg.vector_norm((compare_input - base_input).float()).item()) != 0.0 or verbose:
logger.info("%s input norm diff: %.6f", module_name, norm_diff)
if "wte" in module_name:
logger.info(
"%s mis-matching wte elements: %d",
module_name,
torch.sum(torch.logical_not(torch.eq(base_input, compare_input))),
)

base_output = torch.load(str(base_module_output_path), map_location=map_location)
compare_output = torch.load(str(compare_module_output_path), map_location=map_location)

if isinstance(base_output, torch.Tensor):
if verbose or base_output.dtype != compare_output.dtype:
logger.info("%s output dtypes: %s %s", module_name, base_output.dtype, compare_output.dtype)
if (
norm_diff := torch.linalg.vector_norm((compare_output - base_output).float()).item()
) != 0.0 or verbose:
logger.info("%s output norm diff: %.6f", module_name, norm_diff)
elif include_non_tensor_outputs:
logger.info("%s outputs: %s %s", module_name, base_output, compare_output)
else:
if verbose:
logger.info("Base output is type %s, skipping", type(base_output))


def compare_module_outputs(
base_traces_folder: Path,
compare_traces_folder: Path,
*,
include_non_tensor_outputs: bool = True,
verbose: bool = False,
):
base_modules = set(_get_module_names(base_traces_folder))
compare_modules = set(_get_module_names(compare_traces_folder))

base_only_modules = base_modules - compare_modules
if len(base_only_modules) > 0:
logger.info("Base-only modules: %s", ", ".join(base_only_modules))

compare_only_modules = compare_modules - base_modules
if len(compare_only_modules) > 0:
logger.info("Compare-only modules: %s", ", ".join(compare_only_modules))

common_modules = base_modules.intersection(compare_modules)
for module_name in sorted(common_modules):
compare_module_output(
base_traces_folder,
compare_traces_folder,
module_name,
include_non_tensor_outputs=include_non_tensor_outputs,
verbose=verbose,
)


def main():
logging.basicConfig(encoding="utf-8", level=logging.INFO)

parser = ArgumentParser()
parser.add_argument(
"base_model_traces_path",
type=Path,
help="Path where output traces of the base (i.e. reference) model are stored",
)
parser.add_argument(
"compare_model_traces_path",
type=Path,
help="Path where output traces of the compare (a.k.a new, different) model are stored",
)
parser.add_argument(
"--include_non_tensor_outputs",
action="store_true",
dest="include_non_tensor_outputs",
help="If set, compare module outputs that are not tensors",
)
parser.add_argument(
"--verbose",
action="store_true",
help="If set, show extra information",
)

args = parser.parse_args()
compare_module_outputs(
args.base_model_traces_path,
args.compare_model_traces_path,
include_non_tensor_outputs=args.include_non_tensor_outputs,
verbose=args.verbose,
)


if __name__ == "__main__":
main()
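A hedged usage sketch: after tracing the same step in two runs, the CLI can be invoked as `python scripts/compare_module_outputs.py <base_traces> <compare_traces> --verbose`, or the comparison can be driven programmatically. The folder names below are placeholders, and the import assumes `scripts/` is on `sys.path`.

from pathlib import Path

from compare_module_outputs import compare_module_outputs  # assumes scripts/ is importable

compare_module_outputs(
    Path("runs/base/traces/step10"),       # traces from the reference model
    Path("runs/candidate/traces/step10"),  # traces from the modified model
    include_non_tensor_outputs=False,
    verbose=True,
)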
