Commit: add views
hellerve committed Aug 27, 2024
1 parent 7b84d73 commit 29938e2
Showing 1 changed file with 101 additions and 39 deletions.
140 changes: 101 additions & 39 deletions python/nanogpt/__init__.py
@@ -1,7 +1,14 @@
+from dataclasses import dataclass
+
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+
+from gtoolkit_bridge import gtView
+import torch.nn as nn
+import torch
+

 # hyperparameters
 batch_size = 16 # how many independent sequences will we process in parallel?
 block_size = 32 # what is the maximum context length for predictions?
@@ -33,9 +40,9 @@ def estimate_loss():
     for split in ['train', 'val']:
         losses = torch.zeros(eval_iters)
         for k in range(eval_iters):
-            X, Y = get_batch(split)
-            logits, loss = model(X, Y)
-            losses[k] = loss.item()
+            X, Y = get_batch(split)
+            logits, loss = model(X, Y)
+            losses[k] = loss.item()
         out[split] = losses.mean()
     model.train()
     return out
@@ -86,10 +93,10 @@ class FeedFoward(nn.Module):
     def __init__(self, n_embd):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(n_embd, 4 * n_embd),
-            nn.ReLU(),
-            nn.Linear(4 * n_embd, n_embd),
-            nn.Dropout(dropout),
+            nn.Linear(n_embd, 4 * n_embd),
+            nn.ReLU(),
+            nn.Linear(4 * n_embd, n_embd),
+            nn.Dropout(dropout),
         )

     def forward(self, x):
@@ -136,38 +143,57 @@ def forward(self, idx, targets=None):
         logits = self.lm_head(x) # (B,T,vocab_size)

         if targets is None:
-            loss = None
+            loss = None
         else:
-            B, T, C = logits.shape
-            logits = logits.view(B*T, C)
-            targets = targets.view(B*T)
-            loss = F.cross_entropy(logits, targets)
+            B, T, C = logits.shape
+            logits = logits.view(B*T, C)
+            targets = targets.view(B*T)
+            loss = F.cross_entropy(logits, targets)

         return logits, loss

     def generate(self, idx, max_new_tokens):
         # idx is (B, T) array of indices in the current context
         for _ in range(max_new_tokens):
-            # crop idx to the last block_size tokens
-            idx_cond = idx[:, -block_size:]
-            # get the predictions
-            logits, loss = self(idx_cond)
-            # focus only on the last time step
-            logits = logits[:, -1, :] # becomes (B, C)
-            # apply softmax to get probabilities
-            probs = F.softmax(logits, dim=-1) # (B, C)
-            # sample from the distribution
-            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
-            # append sampled index to the running sequence
-            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
+            # crop idx to the last block_size tokens
+            idx_cond = idx[:, -block_size:]
+            # get the predictions
+            logits, loss = self(idx_cond)
+            # focus only on the last time step
+            logits = logits[:, -1, :] # becomes (B, C)
+            # apply softmax to get probabilities
+            probs = F.softmax(logits, dim=-1) # (B, C)
+            # sample from the distribution
+            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
+            # append sampled index to the running sequence
+            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
         return idx

-context = None
-model = None
+# bundles the trained model and its sampling context so the views below can hang off one object
+@dataclass
+class TrainingContext:
+    model: nn.Module
+    context: torch.Tensor
+
+    @gtView
+    def gt_view_children(self, builder):
+        fwd = builder.forward()
+        fwd.title('Architecture')
+        fwd.priority(10)
+        fwd.object(self.model)
+        fwd.view('gt_view_children')
+        return fwd
+
+    @gtView
+    def gt_view_matrix(self, builder):
+        fwd = builder.forward()
+        fwd.title('Context')
+        fwd.priority(10)
+        fwd.object(self.context)
+        fwd.view('gt_view_matrix')
+        return fwd


 def train(data, vocab_size):
     global model
     global context
     m = BigramLanguageModel(vocab_size)
     model = m.to(device)

@@ -193,20 +219,56 @@ def train(data, vocab_size):
     # generate from the model
     context = torch.zeros((1, 1), dtype=torch.long, device=device)

-def generate_tokens(max):
-    return model.generate(context, max_new_tokens=max)[0].tolist()
+    return TrainingContext(model=model, context=context)


-from gtoolkit_bridge import gtView
+def generate_tokens(training_context, max):
+    return training_context.model.generate(training_context.context, max_new_tokens=max)[0].tolist()


 @gtView
 def nn_gt_view_children(self, builder):
-    tree = builder.columned_tree()
-    tree.title('Children')
-    tree.priority(10)
-    tree.items(lambda: self.named_children())
-    tree.children(lambda item: item[1].named_children())
-    tree.column('Name', lambda each: each[0])
-    return tree
+    tree = builder.columnedTree()
+    tree.title('Children')
+    tree.priority(10)
+    tree.items(lambda: self.named_children())
+    tree.children(lambda item: item[1].named_children())
+    tree.column('Name', lambda each: each[0])
+    tree.column('Parameters', lambda each: [e[0] for e in each[1].named_parameters()])
+    tree.set_accessor(lambda each: each[1])
+    return tree

 setattr(nn.Module, 'gt_view_children', nn_gt_view_children)

+@gtView
+def nn_gt_view_parameters(self, builder):
+    lst = builder.columnedList()
+    lst.title('Parameters')
+    lst.priority(15)
+    lst.items(lambda: self.named_parameters())
+    lst.column('Name', lambda each: each[0])
+    lst.column('Value', lambda each: each[1])
+    lst.set_accessor(lambda idx: list(self.named_parameters())[idx][1])
+    return lst
+
+setattr(nn.Module, 'gt_view_parameters', nn_gt_view_parameters)
+
+
+@gtView
+def tensor_gt_view_matrix(self, builder):
+    # scalar tensors have nothing to list
+    if self.ndim == 0:
+        return builder.empty()
+
+    lst = builder.list()
+    lst.title('Data')
+    lst.priority(5)
+    lst.items(lambda: self)
+    return lst
+
+setattr(torch.Tensor, 'gt_view_matrix', tensor_gt_view_matrix)
+
+
+def tensor_flattened_list(self):
+    return [e.item() for e in torch.flatten(self)]
+
+setattr(torch.Tensor, 'flattened_list', tensor_flattened_list)
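
With this change, `train` no longer hands results back through module globals alone: it returns a `TrainingContext`, and `generate_tokens` pulls the model and sampling context out of that object. A minimal driver sketch under assumptions not shown in the diff (a plain-text corpus in `input.txt`, the usual nanoGPT character encode/decode tables, and the module being importable as `nanogpt`):

    # sketch: drive the new API end to end
    import torch
    from nanogpt import train, generate_tokens

    text = open('input.txt').read()            # assumption: a plain-text training corpus
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    data = torch.tensor([stoi[c] for c in text], dtype=torch.long)
    ctx = train(data, vocab_size=len(chars))   # returns a TrainingContext

    tokens = generate_tokens(ctx, 100)         # sample 100 new tokens as a list of ints
    print(''.join(itos[t] for t in tokens))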

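The `setattr` calls at the bottom are the whole extension mechanism: attaching a `@gtView`-decorated function to a class gives every instance of that class a new inspector view in Glamorous Toolkit. The same pattern would cover further views; a sketch with a hypothetical `gt_view_shape` view on `torch.Tensor` (the view name and its columns are invented here for illustration, only the builder calls mirror the ones in this commit):

    # sketch: attach one more view using the same setattr pattern
    # `gt_view_shape` is a hypothetical name, not part of this commit
    @gtView
    def tensor_gt_view_shape(self, builder):
        lst = builder.columnedList()
        lst.title('Shape')
        lst.priority(20)
        lst.items(lambda: list(enumerate(self.shape)))
        lst.column('Dim', lambda each: each[0])
        lst.column('Size', lambda each: each[1])
        return lst

    setattr(torch.Tensor, 'gt_view_shape', tensor_gt_view_shape)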