Commit a422c2e

Your Name committed Oct 15, 2024 · 1 parent ca8ebbd
Showing 4 changed files with 23 additions and 6 deletions.
benchmarks/lightning/main.py (1 addition & 2 deletions)

@@ -4,10 +4,9 @@
import argparse
import os

# FIXME this is HPU only
os.environ["PT_HPU_LAZY_MODE"] = str(int(int(os.getenv("WORLD_SIZE", -1)) <= 0))

from habana_frameworks.torch import hpu; hpu.init()

import torch
import torch.nn.functional as F
import lightning as L
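The one-liner above folds a conditional into string arithmetic: PT_HPU_LAZY_MODE ends up "1" only when WORLD_SIZE is unset or non-positive, i.e. for single-process runs, and "0" under a distributed launcher. A minimal sketch of the same logic spelled out (assuming, as the placement before the habana_frameworks import suggests, that the variable must be set before that import):

import os

# WORLD_SIZE is exported by distributed launchers; default to -1 when absent.
world_size = int(os.getenv("WORLD_SIZE", -1))

if world_size <= 0:
    os.environ["PT_HPU_LAZY_MODE"] = "1"  # single process: enable HPU lazy mode
else:
    os.environ["PT_HPU_LAZY_MODE"] = "0"  # distributed run: eager mode
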
benchmarks/llm/recipes/lora_finetune_single_device.py (5 additions & 1 deletion)

@@ -529,7 +529,8 @@ def train(self) -> None:
"""
The core training loop.
"""

import torchcompat.core as accelerator

if self._model_compile:
log.info(
"NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
@@ -580,10 +581,13 @@ def train(self) -> None:
loss = self._loss_fn(logits, labels) / self._gradient_accumulation_steps
running_loss += loss
loss.backward()
accelerator.mark_step()

# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
self._optimizer.step()
accelerator.mark_step()

self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
# Update the number of steps when the weights are updated
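In HPU lazy mode, operations accumulate into a device graph that executes only when the step is marked, so the loop flushes once after loss.backward() and once after self._optimizer.step(). A sketch of what the torchcompat shim is assumed to provide (its actual implementation may differ); habana_frameworks.torch.core.mark_step() is the underlying Habana call:

def mark_step():
    """Flush the lazily accumulated device graph; no-op off HPU."""
    try:
        import habana_frameworks.torch.core as htcore
        htcore.mark_step()  # dispatch the pending HPU graph
    except ImportError:
        pass  # CUDA/CPU execute eagerly; nothing to flush
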
benchmarks/rlhf/main.py (7 additions & 0 deletions)

@@ -2,6 +2,7 @@

import shutil

import accelerate
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
@@ -15,10 +16,16 @@
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE

import torchcompat.core as compat


class PPOv2TrainerIntrumented(PPOv2Trainer):
def __init__(self, config: PPOv2Config, *args, **kwargs):
config.report_to = []

        # FIXME: is there a better way to monkeypatch this?
# Use the compatibility accelerator class
accelerate.Accelerator = compat.accelerate.Accelerator
super().__init__(config, *args, **kwargs)

def batch_size_fn(batch):
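The rebinding is assumed to work because PPOv2Trainer.__init__ looks the class up as accelerate.Accelerator at construction time, which is why the assignment happens before super().__init__(...); patching after construction would be too late, and any module that had already done `from accelerate import Accelerator` keeps its own unpatched reference.
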
scripts/article/run_hpu.sh (10 additions & 3 deletions)

@@ -11,7 +11,7 @@ export MILABENCH_WORDIR="$(pwd)/$MILABENCH_GPU_ARCH"
export MILABENCH_BASE="$MILABENCH_WORDIR/results"
export MILABENCH_VENV="$MILABENCH_WORDIR/env"
export BENCHMARK_VENV="$MILABENCH_WORDIR/results/venv/torch"

export PT_HPU_LAZY_MODE=0

if [ -z "${MILABENCH_SOURCE}" ]; then
export MILABENCH_CONFIG="$MILABENCH_WORDIR/milabench/config/standard.yaml"
@@ -84,6 +84,8 @@ install_prepare() {
#
# Generate/download datasets, download models etc...
#
sed -i 's/pic.numpy(force=True)/pic.numpy()/' $BENCHMARK_VENV/lib/python3.10/dist-packages/torchvision/transforms/functional.py
sed -i 's/range(hpu.device_count())/range(len(available_modules))/' $BENCHMARK_VENV/lib/site-packages/habana_frameworks/torch/hpu/_utils.py
milabench prepare $ARGS
}
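
These two sed patches edit installed packages in place: the first drops the force=True argument from torchvision's tensor-to-numpy conversion (presumably because the bundled Habana build rejects that keyword), and the second makes Habana's device enumeration iterate over the modules it actually detected instead of range(hpu.device_count()). A hedged sketch of the failure mode the first sed works around, with pic_to_numpy as a hypothetical stand-in for the torchvision helper:

import torch

def pic_to_numpy(pic: torch.Tensor):
    # torchvision calls pic.numpy(force=True); the patch assumes the HPU
    # backend rejects that keyword, so fall back to an explicit host copy.
    try:
        return pic.numpy(force=True)
    except TypeError:
        return pic.detach().cpu().numpy()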

@@ -95,12 +97,17 @@ else
fi


(
. $BENCHMARK_VENV/bin/activate
pip install lightning-habana
pip install habana-media-loader
    # git clone git@github.com:Delaunay/torchcompat.git
    # git clone git@github.com:Delaunay/voir.git
pip install -e $MILABENCH_WORDIR/torchcompat
pip install -e $MILABENCH_WORDIR/voir
pip install -e $MILABENCH_WORDIR/optimum-habana
# pip install habana_dataloader
)

if [ "$MILABENCH_PREPARE" -eq 0 ]; then
cd $MILABENCH_WORDIR
