diff --git a/python/nanogpt/__init__.py b/python/nanogpt/__init__.py
index f32b9c30..a17aff94 100644
--- a/python/nanogpt/__init__.py
+++ b/python/nanogpt/__init__.py
@@ -1,7 +1,14 @@
+from dataclasses import dataclass
+
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from gtoolkit_bridge import gtView
+import torch.nn as nn
+import torch
+
+
 
 # hyperparameters
 batch_size = 16 # how many independent sequences will we process in parallel?
 block_size = 32 # what is the maximum context length for predictions?
@@ -33,9 +40,9 @@ def estimate_loss():
     for split in ['train', 'val']:
         losses = torch.zeros(eval_iters)
         for k in range(eval_iters):
-            X, Y = get_batch(split)
-            logits, loss = model(X, Y)
-            losses[k] = loss.item()
+            X, Y = get_batch(split)
+            logits, loss = model(X, Y)
+            losses[k] = loss.item()
         out[split] = losses.mean()
     model.train()
     return out
@@ -86,10 +93,10 @@ class FeedFoward(nn.Module):
     def __init__(self, n_embd):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(n_embd, 4 * n_embd),
-            nn.ReLU(),
-            nn.Linear(4 * n_embd, n_embd),
-            nn.Dropout(dropout),
+            nn.Linear(n_embd, 4 * n_embd),
+            nn.ReLU(),
+            nn.Linear(4 * n_embd, n_embd),
+            nn.Dropout(dropout),
         )
 
     def forward(self, x):
@@ -136,38 +143,57 @@ def forward(self, idx, targets=None):
         logits = self.lm_head(x) # (B,T,vocab_size)
 
         if targets is None:
-            loss = None
+            loss = None
         else:
-            B, T, C = logits.shape
-            logits = logits.view(B*T, C)
-            targets = targets.view(B*T)
-            loss = F.cross_entropy(logits, targets)
+            B, T, C = logits.shape
+            logits = logits.view(B*T, C)
+            targets = targets.view(B*T)
+            loss = F.cross_entropy(logits, targets)
 
         return logits, loss
 
     def generate(self, idx, max_new_tokens):
         # idx is (B, T) array of indices in the current context
         for _ in range(max_new_tokens):
-            # crop idx to the last block_size tokens
-            idx_cond = idx[:, -block_size:]
-            # get the predictions
-            logits, loss = self(idx_cond)
-            # focus only on the last time step
-            logits = logits[:, -1, :] # becomes (B, C)
-            # apply softmax to get probabilities
-            probs = F.softmax(logits, dim=-1) # (B, C)
-            # sample from the distribution
-            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
-            # append sampled index to the running sequence
-            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
+            # crop idx to the last block_size tokens
+            idx_cond = idx[:, -block_size:]
+            # get the predictions
+            logits, loss = self(idx_cond)
+            # focus only on the last time step
+            logits = logits[:, -1, :] # becomes (B, C)
+            # apply softmax to get probabilities
+            probs = F.softmax(logits, dim=-1) # (B, C)
+            # sample from the distribution
+            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
+            # append sampled index to the running sequence
+            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
         return idx
 
 
-context = None
-model = None
-
+@dataclass
+class TrainingContext:
+    model: nn.Module
+    context: torch.Tensor
+
+    @gtView
+    def gt_view_children(self, builder):
+        fwd = builder.forward()
+        fwd.title('Architecture')
+        fwd.priority(10)
+        fwd.object(self.model)
+        fwd.view('gt_view_children')
+        return fwd
+
+    @gtView
+    def gt_view_matrix(self, builder):
+        fwd = builder.forward()
+        fwd.title('Context')
+        fwd.priority(10)
+        fwd.object(self.context)
+        fwd.view('gt_view_matrix')
+        return fwd
+
+
 def train(data, vocab_size):
-    global model
-    global context
     m = BigramLanguageModel(vocab_size)
     model = m.to(device)
@@ -193,20 +219,56 @@ def train(data, vocab_size):
 
     # generate from the model
     context = torch.zeros((1, 1), dtype=torch.long, device=device)
+    return TrainingContext(model=model, context=context)
 
-def generate_tokens(max):
-    return model.generate(context, max_new_tokens=max)[0].tolist()
+def generate_tokens(training_context, max):
+    return training_context.model.generate(training_context.context, max_new_tokens=max)[0].tolist()
+
 
-from gtoolkit_bridge import gtView
 
 @gtView
 def nn_gt_view_children(self, builder):
-    tree = builder.columned_tree()
-    tree.title('Children')
-    tree.priority(10)
-    tree.items(lambda: self.named_children())
-    tree.children(lambda item: item[1].named_children())
-    tree.column('Name', lambda each: each[0])
-    return tree
+    tree = builder.columnedTree()
+    tree.title('Children')
+    tree.priority(10)
+    tree.items(lambda: self.named_children())
+    tree.children(lambda item: item[1].named_children())
+    tree.column('Name', lambda each: each[0])
+    tree.column('Parameters', lambda each: [e[0] for e in each[1].named_parameters()])
+    tree.set_accessor(lambda each: each[1])
+    return tree
 
 setattr(nn.Module, 'gt_view_children', nn_gt_view_children)
+
+@gtView
+def nn_gt_view_parameters(self, builder):
+    lst = builder.columnedList()
+    lst.title('Parameters')
+    lst.priority(15)
+    lst.items(lambda: self.named_parameters())
+    lst.column('Name', lambda each: each[0])
+    lst.column('Value', lambda each: each[1])
+    lst.set_accessor(lambda idx: list(self.named_parameters())[idx][1])
+    return lst
+
+setattr(nn.Module, 'gt_view_parameters', nn_gt_view_parameters)
+
+
+@gtView
+def tensor_gt_view_matrix(self, builder):
+    if self.ndim == 0:
+        return builder.empty()
+
+    lst = builder.list()
+    lst.title('Data')
+    lst.priority(5)
+    lst.items(lambda: self)
+    return lst
+
+setattr(torch.Tensor, 'gt_view_matrix', tensor_gt_view_matrix)
+
+
+def tensor_flattened_list(self):
+    return [e.item() for e in torch.flatten(self)]
+
+setattr(torch.Tensor, 'flattened_list', tensor_flattened_list)
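With this change, train() is expected to return a TrainingContext instead of setting module-level globals, and generate_tokens() now takes that context as its first argument. Below is a minimal driver sketch under that assumption; the input.txt corpus and the character-level encoding are the usual nanogpt setup and are hypothetical here, not part of this patch:

    # hypothetical driver script; assumes train(data, vocab_size) accepts a
    # 1-D LongTensor of token ids, as in the standard nanogpt character-level setup
    import torch
    from nanogpt import train, generate_tokens

    text = open('input.txt').read()
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    data = torch.tensor([stoi[c] for c in text], dtype=torch.long)

    training_context = train(data, len(chars))           # returns a TrainingContext
    token_ids = generate_tokens(training_context, 500)   # list of sampled token ids
    print(''.join(itos[i] for i in token_ids))

Because nn_gt_view_children, nn_gt_view_parameters and tensor_gt_view_matrix are attached to nn.Module and torch.Tensor via setattr, any module or tensor reached from the returned TrainingContext should expose these views in the Glamorous Toolkit inspector.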