From e1f6a5d8d1ff9999ab63d2a493c03c4e3b90b9f2 Mon Sep 17 00:00:00 2001
From: i-colbert
Date: Tue, 29 Oct 2024 19:57:43 +0000
Subject: [PATCH] Fix (data): updating wikitext2 data utility

---
 src/brevitas_examples/llm/llm_quant/data.py | 39 ++++++++++++-------
 .../llm/llm_quant/data_utils.py             |  3 +-
 2 files changed, 26 insertions(+), 16 deletions(-)

diff --git a/src/brevitas_examples/llm/llm_quant/data.py b/src/brevitas_examples/llm/llm_quant/data.py
index 2f208e715..a535feae9 100644
--- a/src/brevitas_examples/llm/llm_quant/data.py
+++ b/src/brevitas_examples/llm/llm_quant/data.py
@@ -17,10 +17,11 @@
 """
 
 import random
+from typing import Any
 
 from datasets import load_dataset
 import torch
-from transformers import AutoTokenizer
+from tqdm import tqdm
 
 
 def get_c4(nsamples, seed, seqlen, tokenizer, split='train', nvalsamples=0):
@@ -69,20 +70,20 @@ def get_c4(nsamples, seed, seqlen, tokenizer, split='train', nvalsamples=0):
     return valenc
 
 
-def get_wikitext2(nsamples, seed, seqlen, tokenizer, type='raw', split='train'):
-    from datasets import load_dataset
-    dataset_name = 'wikitext-2-v1'
-    if type == 'raw':
-        dataset_name = 'wikitext-2-raw-v1'
-    if split == 'train':
-        traindata = load_dataset('wikitext', dataset_name, split='train')
+def get_wikitext2(
+        tokenizer: Any,
+        seqlen: int,
+        nsamples: int,
+        split: str = 'train',
+        fuse_sequences: bool = True,
+        seed: int = 42):
+    random.seed(seed)
+    if split == 'train':
+        traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
         trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
-
-        import random
-        random.seed(seed)
         trainloader = []
 
-        for _ in range(nsamples):
+        for _ in tqdm(range(nsamples)):
             i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
             j = i + seqlen
             inp = trainenc.input_ids[:, i:j]
@@ -90,6 +91,14 @@ def get_wikitext2(nsamples, seed, seqlen, tokenizer, type='raw', split='train'):
             trainloader.append({'input_ids': inp, 'attention_mask': attention_mask})
         return trainloader
     elif split == 'validation':
-        testdata = load_dataset('wikitext', dataset_name, split='test')
-        testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
-        return testenc
+        data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+        data = tokenizer("\n\n".join(data['text']), return_tensors='pt')
+        nsamples = data['input_ids'].numel() // seqlen
+        testloader = []
+        for i in tqdm(range(nsamples)):
+            batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)]
+            attention_mask = torch.ones_like(batch)
+            testloader.append({'input_ids': batch, 'attention_mask': attention_mask})
+        return testloader
+    else:
+        raise ValueError(f"{split} is invalid")
diff --git a/src/brevitas_examples/llm/llm_quant/data_utils.py b/src/brevitas_examples/llm/llm_quant/data_utils.py
index 5375fcddf..1ff82c157 100644
--- a/src/brevitas_examples/llm/llm_quant/data_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/data_utils.py
@@ -30,11 +30,12 @@
 import numpy as np
 from optimum.amd.brevitas.data_utils import DatasetToDevice
 from optimum.amd.brevitas.data_utils import get_c4
-from optimum.amd.brevitas.data_utils import get_wikitext2
 from optimum.utils.normalized_config import NormalizedConfigManager
 import torch
 from transformers import AutoConfig
 
+from .data import get_wikitext2
+
 
 def get_dataset_for_model(
     model_name_or_path: str,
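
Reviewer note: a minimal usage sketch of the updated get_wikitext2 signature, not part of the patch. It assumes brevitas_examples is importable locally; the checkpoint name is an arbitrary example.

    # Sketch: calling the patched get_wikitext2 for calibration and evaluation.
    from transformers import AutoTokenizer

    from brevitas_examples.llm.llm_quant.data import get_wikitext2

    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')  # example model

    # Train split: nsamples random windows of seqlen tokens, for calibration.
    calib_loader = get_wikitext2(tokenizer, seqlen=2048, nsamples=128, split='train')

    # Validation split: the test set is chunked into seqlen-sized blocks, and
    # nsamples is recomputed internally from the dataset length.
    val_loader = get_wikitext2(tokenizer, seqlen=2048, nsamples=0, split='validation')

    print(calib_loader[0]['input_ids'].shape)  # torch.Size([1, 2048])

Both splits now return the same structure (a list of dicts with 'input_ids' and 'attention_mask'), which is what lets data_utils.py swap the optimum-amd import for the local one.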