-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
247 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
*.safetensors filter=lfs diff=lfs merge=lfs -text |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "c5f64ca6-e967-4fed-bcfd-20f060b91ac9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 23, | ||
"id": "4bf697d6-22e7-457d-93ab-454f6be7aa8e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%autoreload 2\n", | ||
"\n", | ||
"import torch\n", | ||
"from torch import nn, optim\n", | ||
"from torchdrive.tasks.diff_traj import XYMLPEncoder\n", | ||
"\n", | ||
"\n", | ||
"device = torch.device(\"cuda\")\n", | ||
"MAX_DIST = 128\n", | ||
"DIM = 1024\n", | ||
"BS = 16000\n", | ||
"m = XYMLPEncoder(dim=DIM, max_dist=MAX_DIST).to(device)\n", | ||
"\n", | ||
"optimizer = optim.AdamW(m.parameters(), lr=1e-4)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"id": "5c22065e-2df3-4e72-9420-14a1cabfc5eb", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"0.5" | ||
] | ||
}, | ||
"execution_count": 27, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# theoretical\n", | ||
"bucket_m = MAX_DIST*2/(DIM/2)\n", | ||
"bucket_m" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 38, | ||
"id": "28275756-7185-4d6a-b129-e75a15909794", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"0 - magnitude 16.623794555664062, scale 8.333362579345703\n", | ||
"0 - normal: loss 0.011047407984733582, mae 0.3838900923728943\n", | ||
"0 - noise: loss 9.734607696533203, mae 110.0469741821289\n", | ||
"0 batch 9.745655059814453\n", | ||
"\n", | ||
"100 - magnitude 16.00202178955078, scale 8.027247428894043\n", | ||
"100 - normal: loss 0.01364201307296753, mae 0.3835380971431732\n", | ||
"100 - noise: loss 9.665282249450684, mae 109.15914154052734\n", | ||
"100 batch 9.678924560546875\n", | ||
"\n", | ||
"200 - magnitude 15.547680854797363, scale 7.71797513961792\n", | ||
"200 - normal: loss 0.013021253049373627, mae 0.3811604678630829\n", | ||
"200 - noise: loss 9.588726043701172, mae 108.6501693725586\n", | ||
"200 batch 9.601747512817383\n", | ||
"\n", | ||
"300 - magnitude 15.0445556640625, scale 7.536011219024658\n", | ||
"300 - normal: loss 0.013400120660662651, mae 0.3808329105377197\n", | ||
"300 - noise: loss 9.600756645202637, mae 109.33332061767578\n", | ||
"300 batch 9.614156723022461\n", | ||
"\n", | ||
"400 - magnitude 14.528416633605957, scale 7.243579387664795\n", | ||
"400 - normal: loss 0.013808717019855976, mae 0.38157016038894653\n", | ||
"400 - noise: loss 9.557799339294434, mae 110.33661651611328\n", | ||
"400 batch 9.57160758972168\n", | ||
"\n", | ||
"500 - magnitude 13.987008094787598, scale 6.963778495788574\n", | ||
"500 - normal: loss 0.014187517575919628, mae 0.38324692845344543\n", | ||
"500 - noise: loss 9.548577308654785, mae 108.839111328125\n", | ||
"500 batch 9.562765121459961\n", | ||
"\n", | ||
"600 - magnitude 13.436155319213867, scale 6.736222743988037\n", | ||
"600 - normal: loss 0.01485936064273119, mae 0.3829646408557892\n", | ||
"600 - noise: loss 9.564591407775879, mae 108.64825439453125\n", | ||
"600 batch 9.579450607299805\n", | ||
"\n", | ||
"700 - magnitude 12.872819900512695, scale 6.411044120788574\n", | ||
"700 - normal: loss 0.015887171030044556, mae 0.38363704085350037\n", | ||
"700 - noise: loss 9.495373725891113, mae 109.53296661376953\n", | ||
"700 batch 9.511260986328125\n", | ||
"\n", | ||
"800 - magnitude 12.251294136047363, scale 6.110291004180908\n", | ||
"800 - normal: loss 0.01734350249171257, mae 0.3837227523326874\n", | ||
"800 - noise: loss 9.465378761291504, mae 108.51081848144531\n", | ||
"800 batch 9.482722282409668\n", | ||
"\n", | ||
"900 - magnitude 11.661234855651855, scale 5.795229911804199\n", | ||
"900 - normal: loss 0.01856495998799801, mae 0.3846305012702942\n", | ||
"900 - noise: loss 9.43662166595459, mae 109.53632354736328\n", | ||
"900 batch 9.45518684387207\n", | ||
"\n", | ||
"1000 - magnitude 11.037848472595215, scale 5.4746012687683105\n", | ||
"1000 - normal: loss 0.020276248455047607, mae 0.3821403682231903\n", | ||
"1000 - noise: loss 9.37554931640625, mae 109.10408020019531\n", | ||
"1000 batch 9.395825386047363\n", | ||
"\n", | ||
"1100 - magnitude 10.422453880310059, scale 5.2114338874816895\n", | ||
"1100 - normal: loss 0.022657360881567, mae 0.3827858567237854\n", | ||
"1100 - noise: loss 9.3831205368042, mae 108.37946319580078\n", | ||
"1100 batch 9.405777931213379\n", | ||
"\n", | ||
"1200 - magnitude 9.819730758666992, scale 4.923969268798828\n", | ||
"1200 - normal: loss 0.024695200845599174, mae 0.3811036944389343\n", | ||
"1200 - noise: loss 9.379384994506836, mae 108.5765380859375\n", | ||
"1200 batch 9.404080390930176\n", | ||
"\n", | ||
"1300 - magnitude 9.261908531188965, scale 4.6442646980285645\n", | ||
"1300 - normal: loss 0.026335854083299637, mae 0.38449397683143616\n", | ||
"1300 - noise: loss 9.365718841552734, mae 109.20559692382812\n", | ||
"1300 batch 9.392054557800293\n", | ||
"\n", | ||
"1400 - magnitude 8.725689888000488, scale 4.39845609664917\n", | ||
"1400 - normal: loss 0.028450658544898033, mae 0.3838055729866028\n", | ||
"1400 - noise: loss 9.33014965057373, mae 109.24152374267578\n", | ||
"1400 batch 9.358600616455078\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"ename": "KeyboardInterrupt", | ||
"evalue": "", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[38], line 32\u001b[0m\n\u001b[1;32m 30\u001b[0m losses \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, enc \u001b[38;5;129;01min\u001b[39;00m variations\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m---> 32\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloss\u001b[49m\u001b[43m(\u001b[49m\u001b[43menc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 33\u001b[0m mae \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mvector_norm(m\u001b[38;5;241m.\u001b[39mdecode(enc)\u001b[38;5;241m-\u001b[39mbatch, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m should_log:\n", | ||
"File \u001b[0;32m~/torchdrive/torchdrive/tasks/diff_traj.py:340\u001b[0m, in \u001b[0;36mXYMLPEncoder.loss\u001b[0;34m(self, predicted, target)\u001b[0m\n\u001b[1;32m 338\u001b[0m target \u001b[38;5;241m=\u001b[39m target\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 339\u001b[0m emb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoder(predicted)\n\u001b[0;32m--> 340\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloss\u001b[49m\u001b[43m(\u001b[49m\u001b[43memb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n", | ||
"File \u001b[0;32m~/torchdrive/torchdrive/models/path.py:92\u001b[0m, in \u001b[0;36mXYEncoder.loss\u001b[0;34m(self, predicted, target)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mloss\u001b[39m(\u001b[38;5;28mself\u001b[39m, predicted: torch\u001b[38;5;241m.\u001b[39mTensor, target: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m 91\u001b[0m x, y \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msplit_xy_one_hot(predicted)\n\u001b[0;32m---> 92\u001b[0m xl, yl \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode_labels\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mcross_entropy(x, xl, reduction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m+\u001b[39m F\u001b[38;5;241m.\u001b[39mcross_entropy(\n\u001b[1;32m 94\u001b[0m y, yl, reduction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 95\u001b[0m )\n", | ||
"File \u001b[0;32m~/torchdrive/torchdrive/models/path.py:45\u001b[0m, in \u001b[0;36mXYEncoder.encode_labels\u001b[0;34m(self, xy)\u001b[0m\n\u001b[1;32m 36\u001b[0m y \u001b[38;5;241m=\u001b[39m xy[:, \u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 38\u001b[0m x \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 39\u001b[0m (((x \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_dist)) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m0.5\u001b[39m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_buckets)\n\u001b[1;32m 40\u001b[0m \u001b[38;5;241m.\u001b[39mlong()\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m.\u001b[39mclamp(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_buckets \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 42\u001b[0m )\n\u001b[1;32m 43\u001b[0m y \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 44\u001b[0m \u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43my\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_dist\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.5\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_buckets\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m---> 45\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlong\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;241m.\u001b[39mclamp(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_buckets \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 47\u001b[0m )\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x, y\n", | ||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | ||
] | ||
} | ||
], | ||
"source": [ | ||
"BATCHES = 2000\n", | ||
"LOG_EVERY = 100\n", | ||
"\n", | ||
"for i in range(BATCHES):\n", | ||
" should_log = i % LOG_EVERY == 0\n", | ||
" optimizer.zero_grad()\n", | ||
" \n", | ||
" batch = (torch.rand(BS, 1, 2, device=device) - 0.5) * (2 * MAX_DIST)\n", | ||
"\n", | ||
" encoded = m(batch)\n", | ||
"\n", | ||
" encoded_mag = torch.linalg.vector_norm(encoded, dim=-1).mean()\n", | ||
" # scale by 0-1\n", | ||
" noise_scale = encoded_mag * torch.rand(BS, device=device)\n", | ||
" noise_scale = noise_scale.unsqueeze(1).unsqueeze(1)\n", | ||
" \n", | ||
" noise = torch.randn_like(encoded)\n", | ||
" noise = noise * noise_scale\n", | ||
"\n", | ||
" \n", | ||
" if should_log:\n", | ||
" print(f\"{i} - magnitude {encoded_mag}, scale {noise_scale.mean().item()}\")\n", | ||
"\n", | ||
" noise_encoded = encoded + noise\n", | ||
" \n", | ||
" variations = {\n", | ||
" \"normal\": encoded,\n", | ||
" \"noise\": noise_encoded,\n", | ||
" }\n", | ||
" losses = {}\n", | ||
" for k, enc in variations.items():\n", | ||
" loss = m.loss(enc, batch).mean()\n", | ||
" mae = torch.linalg.vector_norm(m.decode(enc)-batch, dim=-1).mean()\n", | ||
" if should_log:\n", | ||
" print(f\"{i} - {k}: loss {loss.item()}, mae {mae.item()}\")\n", | ||
"\n", | ||
" losses[k] = loss\n", | ||
"\n", | ||
" total_loss = sum(losses.values())\n", | ||
" total_loss.backward()\n", | ||
"\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" if should_log:\n", | ||
" print(i, \"batch\", total_loss.item())\n", | ||
" print()\n", | ||
" \n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 41, | ||
"id": "5a3cdaa7-2a73-4b7c-b388-35989e0ff56a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from safetensors.torch import save_file\n", | ||
"\n", | ||
"save_file(m.state_dict(), \"xy_mlp_vae.safetensors\")" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "torchdrive", | ||
"language": "python", | ||
"name": "torchdrive" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |