Skip to content

Commit

Permalink
Fix (examples/llm): Fix infinite loop in LLM entrypoint with WikiText2 (
Browse files Browse the repository at this point in the history
  • Loading branch information
pablomlago authored Oct 8, 2024
1 parent 9048ecb commit 4d8b153
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
7 changes: 6 additions & 1 deletion src/brevitas_examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
[--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
[--export-prefix EXPORT_PREFIX]
[--checkpoint-name CHECKPOINT_NAME]
[--checkpoint-name CHECKPOINT_NAME] [--fuse-sequences]

options:
-h, --help show this help message and exit
Expand Down Expand Up @@ -131,5 +131,10 @@ options:
--checkpoint-name CHECKPOINT_NAME
Filename to save checkpoint. If `None`, no checkpoint
is saved (default: None)
--fuse-sequences Whether to merge the dataset sequences in case they
are shorter than the requested number of samples per
sequence. This is useful in case you would like to
quantize or evaluate on long sequences (default:
False).

```
20 changes: 17 additions & 3 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import argparse
import re
import sys
from warnings import warn

import numpy as np
from optimum.amd.brevitas.accelerate_utils import offload_model
Expand Down Expand Up @@ -109,6 +110,13 @@ def validate(args):
else:
assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX"

if not args.fuse_sequences:
        # 350 is approximately the 99th percentile of the sequence lengths in WikiText2 (train partition, using AutoTokenizer)
if args.seqlen >= 350:
warn(
"Data loading can take a long time or, potentially, enter an infinite loop. Consider setting --args.fuse_sequences "
"or decreasing the sequence length (seqlen)")


def main(args):
validate(args)
Expand Down Expand Up @@ -142,7 +150,6 @@ def main(args):
apply_awq(model, awq_results)

require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge else False
fuse_sequences = False

# Load the data for calibration and evaluation.
calibration_loader = get_dataset_for_model(
Expand All @@ -155,7 +162,7 @@ def main(args):
seed=args.seed,
require_fx=require_fx,
device=None,
fuse_sequences=fuse_sequences,
fuse_sequences=args.fuse_sequences,
)

validation_loader = get_dataset_for_model(
Expand All @@ -168,7 +175,7 @@ def main(args):
seed=args.seed,
require_fx=require_fx,
device=None,
fuse_sequences=fuse_sequences,
fuse_sequences=args.fuse_sequences,
)

device = next(iter(model.parameters())).device
Expand Down Expand Up @@ -474,6 +481,13 @@ def parse_args(args):
default=None,
help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)"
)
parser.add_argument(
"--fuse-sequences",
action="store_true",
default=False,
help=
"Whether to merge the dataset sequences in case they are shorter than the requested number of samples per sequence. This is useful in case you would like to quantize or evaluate on long sequences (default: %(default)s).",
)
return parser.parse_args(args)


Expand Down

0 comments on commit 4d8b153

Please sign in to comment.