forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_gpt2.py
448 lines (396 loc) · 19.2 KB
/
train_gpt2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
"""
Reference code for GPT-2 training and inference.
Will save the model weights into files, to be read from C as initialization.
References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""
import os
import math
import struct
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch._inductor.config as config
class NewGELU(nn.Module):
    """Tanh-approximated GELU activation.

    Several GELU variants exist; this is the form used by OpenAI:
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))``.
    """

    def forward(self, input):
        # cubic polynomial that is fed to tanh
        cubic = input + 0.044715 * torch.pow(input, 3.0)
        gate = torch.tanh(math.sqrt(2.0 / math.pi) * cubic)
        return 0.5 * input * (1.0 + gate)
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only attends to itself and earlier positions.

    Reads ``n_embd``, ``n_head`` and ``block_size`` from ``config``;
    ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # one fused projection that yields query, key and value for all heads
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # projection applied to the re-assembled head outputs
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # lower-triangular mask; it is called "bias" only to follow the OpenAI/HF naming
        causal_mask = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer(
            "bias",
            causal_mask.view(1, 1, config.block_size, config.block_size),
        )

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape (B, T, C) and return a (B, T, C) tensor."""
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs): the head axis moves next to the batch axis
            return t.view(B, T, self.n_head, head_size).transpose(1, 2)

        q, k, v = (split_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        # scaled dot-product scores; masked positions are set to -inf before the softmax
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
        scores = scores.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs), then heads side by side again
        merged = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Feed-forward sub-layer: Linear to 4 * n_embd, NewGELU, Linear back to n_embd."""

    def __init__(self, config):
        super().__init__()
        hidden_width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_width)
        self.gelu = NewGELU()
        self.c_proj = nn.Linear(hidden_width, config.n_embd)

    def forward(self, x):
        """Return ``c_proj(gelu(c_fc(x)))``."""
        return self.c_proj(self.gelu(self.c_fc(x)))
class Block(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run the attention sub-layer, then the MLP sub-layer, on ``x``."""
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))
@dataclass
class GPTConfig:
    """Model hyperparameters; the defaults are the values used for the 'gpt2' (124M) checkpoint."""

    # maximum sequence length (size of the position-embedding table and of the causal mask)
    block_size: int = 1024
    # number of rows in the token-embedding table / width of the LM head output
    vocab_size: int = 50257
    # number of transformer blocks
    n_layer: int = 12
    # number of attention heads per block
    n_head: int = 12
    # embedding width
    n_embd: int = 768
class GPT(nn.Module):
    """GPT-2 model: token and position embeddings, a stack of ``Block``s, a final LayerNorm and an LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.n_embd
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, width),
            wpe=nn.Embedding(config.block_size, width),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(width),
        ))
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        # weight tying: the token embedding and the LM head share a single parameter
        # https://paperswithcode.com/method/weight-tying
        self.transformer.wte.weight = self.lm_head.weight

    def forward(self, idx, targets=None):
        """Run the model on token indices ``idx`` of shape (b, t).

        Returns ``(logits, loss)``. With ``targets`` given, logits cover every
        position and ``loss`` is the cross entropy against ``targets`` (entries
        equal to -1 are ignored). Without ``targets``, only the last position is
        passed through the LM head, so logits have shape (b, 1, vocab_size),
        and ``loss`` is None.
        """
        _, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        positions = torch.arange(0, t, dtype=torch.long, device=idx.device)  # shape (t)

        # token embeddings (b, t, n_embd) plus position embeddings (t, n_embd)
        x = self.transformer.wte(idx) + self.transformer.wpe(positions)
        for layer in self.transformer.h:
            x = layer(x)
        x = self.transformer.ln_f(x)

        if targets is None:
            # inference-time mini-optimization: only the very last position goes through
            # lm_head; indexing with the list [-1] preserves the time dimension
            return self.lm_head(x[:, [-1], :]), None

        logits = self.lm_head(x)
        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Build a GPT and fill it with the pretrained GPT-2 weights loaded through huggingface.

        ``model_type`` must be one of 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'.
        """
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd follow from model_type
        presets = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),   # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024),  # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280),  # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600),  # 1558M params
        }
        config_args = presets[model_type]
        config_args['vocab_size'] = 50257  # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024   # always 1024 for GPT model checkpoints

        # freshly initialized model whose tensors are overwritten in place below
        model = GPT(GPTConfig(**config_args))
        sd = model.state_dict()
        # the causal mask is a buffer, not a parameter, so it is excluded from the comparison
        sd_keys = [k for k in sd.keys() if not k.endswith('.attn.bias')]

        # huggingface/transformers model that provides the pretrained tensors
        sd_hf = GPT2LMHeadModel.from_pretrained(model_type).state_dict()
        hf_buffers = ('.attn.masked_bias', '.attn.bias')  # buffers only, nothing to copy
        sd_keys_hf = [k for k in sd_hf.keys() if not k.endswith(hf_buffers)]

        # the openai checkpoints use a "Conv1D" module while this model uses a vanilla Linear,
        # so these weights have to be transposed when they are imported
        transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"

        with torch.no_grad():
            for k in sd_keys_hf:
                source = sd_hf[k]
                if k.endswith(transposed):
                    assert source.shape[::-1] == sd[k].shape
                    sd[k].copy_(source.t())
                else:
                    assert source.shape == sd[k].shape
                    sd[k].copy_(source)
        return model

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Extend the conditioning sequence ``idx`` (LongTensor of shape (b, t)) by sampling.

        One token is sampled per step for ``max_new_tokens`` steps and fed back
        into the model. Logits are divided by ``temperature``; when ``top_k`` is
        given, only the ``top_k`` largest logits stay eligible. You will most
        likely want the model in ``model.eval()`` mode for this.
        """
        block_size = self.config.block_size
        for _ in range(max_new_tokens):
            # the model only ever sees the most recent block_size tokens
            context = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            logits, _ = self(context)
            # logits of the final step, scaled by the desired temperature
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # everything below the k-th largest logit becomes impossible to sample
                top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                cutoff = top_values[:, [-1]]
                logits = logits.masked_fill(logits < cutoff, -float('Inf'))
            probs = F.softmax(logits, dim=-1)
            sampled = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, sampled), dim=1)
        return idx
# a few utilities for saving params/grads/activations to files for loading in C
def write_fp32(tensor, file):
    """Write ``tensor`` to the binary file object ``file`` as raw float32 bytes."""
    values = tensor.detach().cpu().numpy()
    file.write(values.astype("float32").tobytes())
def write_tensors(model_tensors, L, file):
    """Write the tensors in ``model_tensors`` to ``file`` as float32 in a fixed order.

    ``model_tensors`` maps parameter names to tensors and ``L`` is the number of
    transformer layers. The order is: wte, wpe, then each per-layer tensor kind
    for all layers 0..L-1 before moving on to the next kind, and finally ln_f
    weight and bias.
    """
    write_fp32(model_tensors["transformer.wte.weight"], file)  # (V, C)
    write_fp32(model_tensors["transformer.wpe.weight"], file)  # (T, C)

    # per-layer tensor kinds in the order they are written; shape across all layers alongside
    per_layer_kinds = (
        "ln_1.weight",         # (L, C)
        "ln_1.bias",           # (L, C)
        "attn.c_attn.weight",  # (L, 3C, C)
        "attn.c_attn.bias",    # (L, 3C)
        "attn.c_proj.weight",  # (L, C, C)
        "attn.c_proj.bias",    # (L, C)
        "ln_2.weight",         # (L, C)
        "ln_2.bias",           # (L, C)
        "mlp.c_fc.weight",     # (L, 4C, C)
        "mlp.c_fc.bias",       # (L, 4C)
        "mlp.c_proj.weight",   # (L, C, 4C)
        "mlp.c_proj.bias",     # (L, C)
    )
    for kind in per_layer_kinds:
        for layer in range(L):
            write_fp32(model_tensors[f"transformer.h.{layer}.{kind}"], file)

    write_fp32(model_tensors["transformer.ln_f.weight"], file)  # (C, )
    write_fp32(model_tensors["transformer.ln_f.bias"], file)  # (C, )
def write_model(model, filename):
    """Write everything needed to instantiate ``model`` into the binary file ``filename``.

    Layout: a header of 256 int32 values (1024 bytes: magic, version, the five
    GPTConfig ints, zero padding) followed by the parameters as float32.
    """
    cfg = model.config
    header_fields = (
        20240326,  # magic
        1,  # checkpoint version = 1
        cfg.block_size,
        cfg.vocab_size,
        cfg.n_layer,
        cfg.n_head,
        cfg.n_embd,
    )
    header = torch.zeros(256, dtype=torch.int32)
    for slot, value in enumerate(header_fields):
        header[slot] = value

    # parameters are moved to the CPU before anything is written
    params = {}
    for name, param in model.named_parameters():
        params[name] = param.cpu()

    with open(filename, "wb") as file:
        file.write(header.numpy().tobytes())
        write_tensors(params, cfg.n_layer, file)
    print(f"wrote {filename}")
def write_state(model, x, y, logits, loss, filename):
    """Write a debug snapshot of one forward/backward pass to the binary file ``filename``.

    The state is meant for checking the correctness of the computation in C.
    Layout: a header of 256 int32 values (magic, version, B, T, zero padding),
    the inputs ``x`` and targets ``y`` as int32, ``logits`` and ``loss`` as
    float32, then the gradient of every parameter as float32.
    """
    header_fields = (
        20240327,  # magic
        1,  # run state version = 1
        x.size(0),  # batch size of the batch, B
        x.size(1),  # temporal extent of the batch, T
    )
    header = torch.zeros(256, dtype=torch.int32)
    for slot, value in enumerate(header_fields):
        header[slot] = value

    grads = {}
    for name, param in model.named_parameters():
        grads[name] = param.grad.cpu()

    with open(filename, "wb") as file:
        file.write(header.numpy().tobytes())
        # input x, then targets y, each (B, T)
        for token_batch in (x, y):
            file.write(token_batch.cpu().numpy().astype("int32").tobytes())
        # logits (result of the model forward pass)
        write_fp32(logits.cpu(), file)
        # loss (single float, result of the cross entropy loss)
        write_fp32(loss.cpu(), file)
        # gradients, in the same order as the parameters
        write_tensors(grads, model.config.n_layer, file)
    print(f"wrote {filename}")
def write_tokenizer(enc, filename):
    """Write the token table of ``enc`` to the binary file ``filename``.

    Layout: a header of 256 int32 values (magic, version, number of tokens,
    zero padding), then for each token id from 0 to ``enc.max_token_value`` a
    one-byte length followed by the bytes from ``enc.decode_bytes([id])``.
    """
    token_count = enc.max_token_value + 1
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240328  # magic
    header[1] = 1  # tokenizer version = 1
    header[2] = token_count  # number of tokens
    with open(filename, "wb") as file:
        file.write(header.numpy().tobytes())
        for token_id in range(token_count):
            token_bytes = enc.decode_bytes([token_id])
            length = len(token_bytes)
            # the length prefix is a single unsigned byte
            assert length < 256, f"Token length exceeds 255: {length}"
            file.write(struct.pack("<B", length) + token_bytes)
    print(f"wrote {filename}")
if __name__ == "__main__":
    import time
    import argparse
    import tiktoken
    print(f"Running pytorch {torch.version.__version__}")

    # default settings will overfit a tiny batch of data
    # and save model weights and debug state to disk on the first iteration
    # if you'd like to e.g. time the forward pass only, call this script as:
    # python train_gpt2.py --inference_only 1 --write_tensors 0 --sequence_length 1024
    parser = argparse.ArgumentParser()
    parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk")
    parser.add_argument("--inference_only", type=int, default=0, help="only run inference")
    parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
    parser.add_argument("--tensorcores", type=int, default=0, help="use tensorcores")
    parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
    parser.add_argument("--batch_size", type=int, default=4, help="batch size")
    parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
    args = parser.parse_args()
    B, T = args.batch_size, args.sequence_length
    assert 1 <= T <= 1024

    # select a reasonable device to run on
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")

    # seed the random number generators
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    # init the tokenizer
    enc = tiktoken.get_encoding("gpt2")

    def encode(s):
        return enc.encode(s, allowed_special={"<|endoftext|>"})

    decode = enc.decode
    write_tokenizer(enc, "gpt2_tokenizer.bin")

    if args.tensorcores:
        torch.set_float32_matmul_precision('high')

    # load the GPT-2 model weights
    model = GPT.from_pretrained("gpt2")
    model.train()
    model.to(device)
    # Keep a handle on the uncompiled module. The wrapper returned by torch.compile
    # reports parameter names with an "_orig_mod." prefix, which write_tensors cannot
    # look up; both objects share the same parameters and gradients.
    raw_model = model
    if args.compile:
        config.coordinate_descent_tuning = True # suggested by @Chillee
        print("compiling the model...")
        model = torch.compile(model)

    # load the tokens
    # prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
    # we're using val instead of train split just because it is smaller/faster
    shake_tokens_bin = "data/tiny_shakespeare_val.bin"
    story_tokens_bin = "data/TinyStories_val.bin"
    assert os.path.isfile(shake_tokens_bin) or os.path.isfile(story_tokens_bin), "you must run prepro on some dataset"
    tokens_bin = shake_tokens_bin if os.path.isfile(shake_tokens_bin) else story_tokens_bin
    assert os.path.isfile(tokens_bin)
    print(f"loading cached tokens in {tokens_bin}")
    with open(tokens_bin, "rb") as f:
        tokens = np.frombuffer(f.read(), dtype=np.int32)

    # np -> tensor, long, on device
    tokens = torch.tensor(tokens)
    tokens = tokens.to(torch.long)
    tokens = tokens.to(device)

    # lightweight dataloader
    def get_batch():
        """Yield (x, y) batches of shape (B, T) forever; y is x shifted by one token."""
        assert B*T+1 <= len(tokens), "not enough tokens"
        # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping
        i = 0
        while True:
            x = tokens[i:i+B*T].view(B, T)
            y = tokens[i+1:i+B*T+1].view(B, T)
            yield x, y
            i += B*T
            # a batch starting at i needs tokens[i : i+B*T+1]; wrap only when that
            # no longer fits (strict '>' so a batch that fits exactly is not skipped)
            if i + B*T + 1 > len(tokens):
                i = 0 # in prod we'd want to randomize the start point a bit

    # forward backward for a few iterations
    data_iter = iter(get_batch())
    x, y = next(data_iter) # we'll overfit this batch below

    # do one forward pass to generate ground truth for our C tests
    if not args.inference_only and args.write_tensors:
        logits, loss = model(x, y)
        loss.backward()
        # export through the uncompiled module so parameter names carry no prefix
        write_model(raw_model, "gpt2_124M.bin")
        write_state(raw_model, x, y, logits, loss, "gpt2_124M_debug_state.bin")

    use_fused = device == "cuda" # only works on CUDA (?)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=use_fused)
    timings = []
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
    for i in range(args.num_iterations):
        t0 = time.time()
        logits, loss = model(x, y)
        if not args.inference_only:
            # also clears the gradients left over from the export pass above
            optimizer.zero_grad()
            del logits
            loss.backward()
            optimizer.step()
        # wait for queued device work so the measured time covers the whole step
        if device == "mps":
            torch.mps.synchronize()
        elif device == "cuda":
            torch.cuda.synchronize()
        t1 = time.time()
        timings.append(t1-t0)
        print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")

    # Average over at most the last 20 iterations. Iteration 0 is never counted
    # because it carries one-time warmup cost; with a single iteration there is
    # nothing left to average, so nothing is printed.
    measured = timings[1:][-20:]
    if measured:
        print(f"final {len(measured)} iters avg: {np.mean(measured)*1000:.3f}ms")
    # peak allocator statistics are only meaningful when running on CUDA
    if device == "cuda":
        print(f"Peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

    # before we end, let's also do one round of inference
    # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
    start = "<|endoftext|>"
    start_ids = encode(start)
    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

    # run generation for 16 time steps (tokens)
    max_new_tokens = 16
    temperature = 1.0
    top_k = 40
    model.eval()
    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    print(decode(y[0].tolist()))
    print('---------------')