Skip to content

Commit

Permalink
Fix (data): updating wikitext2 data utility (#1080)
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert authored Oct 30, 2024
1 parent 7af0dc1 commit ae3ec68
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
39 changes: 24 additions & 15 deletions src/brevitas_examples/llm/llm_quant/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
"""

import random
from typing import Any

from datasets import load_dataset
import torch
from transformers import AutoTokenizer
from tqdm import tqdm


def get_c4(nsamples, seed, seqlen, tokenizer, split='train', nvalsamples=0):
Expand Down Expand Up @@ -69,27 +70,35 @@ def get_c4(nsamples, seed, seqlen, tokenizer, split='train', nvalsamples=0):
return valenc


def get_wikitext2(
        tokenizer: Any,
        seqlen: int,
        nsamples: int,
        split: str = 'train',
        fuse_sequences: bool = True,
        seed: int = 42):
    """Load WikiText-2 (raw) and return a list of tokenized batches.

    For ``split='train'``, draws ``nsamples`` random crops of length
    ``seqlen`` from the fused training corpus (all articles joined with
    blank lines). For ``split='validation'``, returns contiguous,
    non-overlapping ``seqlen``-sized chunks covering the whole test set
    (``nsamples`` is ignored in that case).

    Args:
        tokenizer: HuggingFace-style tokenizer, invoked as
            ``tokenizer(text, return_tensors='pt')``.
        seqlen: Length, in tokens, of each returned sequence.
        nsamples: Number of random training crops to draw (train split only).
        split: Either ``'train'`` or ``'validation'``.
        fuse_sequences: Currently unused in this implementation; kept for
            signature compatibility. NOTE(review): presumably mirrors the
            optimum-amd data-utils interface — confirm whether non-fused
            sampling should be supported here.
        seed: Seed for the random crop positions (train split only).

    Returns:
        A list of ``{'input_ids': Tensor, 'attention_mask': Tensor}`` dicts;
        each tensor has shape ``(1, seqlen)``.

    Raises:
        ValueError: If ``split`` is not ``'train'``/``'validation'``, or if
            the tokenized training corpus is shorter than ``seqlen + 1``
            tokens (previously surfaced as an opaque ``random.randint``
            error).
    """
    random.seed(seed)

    if split == 'train':
        traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
        # Fuse all articles into one long token stream so crops may span
        # article boundaries.
        trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
        total_len = trainenc.input_ids.shape[1]
        if total_len < seqlen + 1:
            raise ValueError(
                f"Tokenized corpus ({total_len} tokens) is shorter than "
                f"seqlen + 1 ({seqlen + 1}); cannot draw random crops.")
        trainloader = []
        for _ in tqdm(range(nsamples)):
            i = random.randint(0, total_len - seqlen - 1)
            inp = trainenc.input_ids[:, i:i + seqlen]
            attention_mask = torch.ones_like(inp)
            trainloader.append({'input_ids': inp, 'attention_mask': attention_mask})
        return trainloader
    elif split == 'validation':
        data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
        data = tokenizer("\n\n".join(data['text']), return_tensors='pt')
        # Use a local, rather than rebinding the `nsamples` parameter: the
        # validation split always covers the full corpus in non-overlapping
        # chunks.
        nchunks = data['input_ids'].numel() // seqlen
        testloader = []
        for i in tqdm(range(nchunks)):
            batch = data['input_ids'][:, (i * seqlen):((i + 1) * seqlen)]
            attention_mask = torch.ones_like(batch)
            testloader.append({'input_ids': batch, 'attention_mask': attention_mask})
        return testloader
    else:
        raise ValueError(f"{split} is invalid")
3 changes: 2 additions & 1 deletion src/brevitas_examples/llm/llm_quant/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
import numpy as np
from optimum.amd.brevitas.data_utils import DatasetToDevice
from optimum.amd.brevitas.data_utils import get_c4
from optimum.amd.brevitas.data_utils import get_wikitext2
from optimum.utils.normalized_config import NormalizedConfigManager
import torch
from transformers import AutoConfig

from .data import get_wikitext2


def get_dataset_for_model(
model_name_or_path: str,
Expand Down

0 comments on commit ae3ec68

Please sign in to comment.