Skip to content

Commit

Permalink
xy_mlp_vae: pretrained checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Jul 7, 2024
1 parent db40733 commit eee60aa
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.safetensors filter=lfs diff=lfs merge=lfs -text
Binary file added data/xy_mlp_vae.safetensors
Binary file not shown.
246 changes: 246 additions & 0 deletions notebooks/traj_ae.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c5f64ca6-e967-4fed-bcfd-20f060b91ac9",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "4bf697d6-22e7-457d-93ab-454f6be7aa8e",
"metadata": {},
"outputs": [],
"source": [
"%autoreload 2\n",
"\n",
"import torch\n",
"from torch import nn, optim\n",
"from torchdrive.tasks.diff_traj import XYMLPEncoder\n",
"\n",
"\n",
"device = torch.device(\"cuda\")\n",
"MAX_DIST = 128\n",
"DIM = 1024\n",
"BS = 16000\n",
"m = XYMLPEncoder(dim=DIM, max_dist=MAX_DIST).to(device)\n",
"\n",
"optimizer = optim.AdamW(m.parameters(), lr=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "5c22065e-2df3-4e72-9420-14a1cabfc5eb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# theoretical\n",
"bucket_m = MAX_DIST*2/(DIM/2)\n",
"bucket_m"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "28275756-7185-4d6a-b129-e75a15909794",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 - magnitude 16.623794555664062, scale 8.333362579345703\n",
"0 - normal: loss 0.011047407984733582, mae 0.3838900923728943\n",
"0 - noise: loss 9.734607696533203, mae 110.0469741821289\n",
"0 batch 9.745655059814453\n",
"\n",
"100 - magnitude 16.00202178955078, scale 8.027247428894043\n",
"100 - normal: loss 0.01364201307296753, mae 0.3835380971431732\n",
"100 - noise: loss 9.665282249450684, mae 109.15914154052734\n",
"100 batch 9.678924560546875\n",
"\n",
"200 - magnitude 15.547680854797363, scale 7.71797513961792\n",
"200 - normal: loss 0.013021253049373627, mae 0.3811604678630829\n",
"200 - noise: loss 9.588726043701172, mae 108.6501693725586\n",
"200 batch 9.601747512817383\n",
"\n",
"300 - magnitude 15.0445556640625, scale 7.536011219024658\n",
"300 - normal: loss 0.013400120660662651, mae 0.3808329105377197\n",
"300 - noise: loss 9.600756645202637, mae 109.33332061767578\n",
"300 batch 9.614156723022461\n",
"\n",
"400 - magnitude 14.528416633605957, scale 7.243579387664795\n",
"400 - normal: loss 0.013808717019855976, mae 0.38157016038894653\n",
"400 - noise: loss 9.557799339294434, mae 110.33661651611328\n",
"400 batch 9.57160758972168\n",
"\n",
"500 - magnitude 13.987008094787598, scale 6.963778495788574\n",
"500 - normal: loss 0.014187517575919628, mae 0.38324692845344543\n",
"500 - noise: loss 9.548577308654785, mae 108.839111328125\n",
"500 batch 9.562765121459961\n",
"\n",
"600 - magnitude 13.436155319213867, scale 6.736222743988037\n",
"600 - normal: loss 0.01485936064273119, mae 0.3829646408557892\n",
"600 - noise: loss 9.564591407775879, mae 108.64825439453125\n",
"600 batch 9.579450607299805\n",
"\n",
"700 - magnitude 12.872819900512695, scale 6.411044120788574\n",
"700 - normal: loss 0.015887171030044556, mae 0.38363704085350037\n",
"700 - noise: loss 9.495373725891113, mae 109.53296661376953\n",
"700 batch 9.511260986328125\n",
"\n",
"800 - magnitude 12.251294136047363, scale 6.110291004180908\n",
"800 - normal: loss 0.01734350249171257, mae 0.3837227523326874\n",
"800 - noise: loss 9.465378761291504, mae 108.51081848144531\n",
"800 batch 9.482722282409668\n",
"\n",
"900 - magnitude 11.661234855651855, scale 5.795229911804199\n",
"900 - normal: loss 0.01856495998799801, mae 0.3846305012702942\n",
"900 - noise: loss 9.43662166595459, mae 109.53632354736328\n",
"900 batch 9.45518684387207\n",
"\n",
"1000 - magnitude 11.037848472595215, scale 5.4746012687683105\n",
"1000 - normal: loss 0.020276248455047607, mae 0.3821403682231903\n",
"1000 - noise: loss 9.37554931640625, mae 109.10408020019531\n",
"1000 batch 9.395825386047363\n",
"\n",
"1100 - magnitude 10.422453880310059, scale 5.2114338874816895\n",
"1100 - normal: loss 0.022657360881567, mae 0.3827858567237854\n",
"1100 - noise: loss 9.3831205368042, mae 108.37946319580078\n",
"1100 batch 9.405777931213379\n",
"\n",
"1200 - magnitude 9.819730758666992, scale 4.923969268798828\n",
"1200 - normal: loss 0.024695200845599174, mae 0.3811036944389343\n",
"1200 - noise: loss 9.379384994506836, mae 108.5765380859375\n",
"1200 batch 9.404080390930176\n",
"\n",
"1300 - magnitude 9.261908531188965, scale 4.6442646980285645\n",
"1300 - normal: loss 0.026335854083299637, mae 0.38449397683143616\n",
"1300 - noise: loss 9.365718841552734, mae 109.20559692382812\n",
"1300 batch 9.392054557800293\n",
"\n",
"1400 - magnitude 8.725689888000488, scale 4.39845609664917\n",
"1400 - normal: loss 0.028450658544898033, mae 0.3838055729866028\n",
"1400 - noise: loss 9.33014965057373, mae 109.24152374267578\n",
"1400 batch 9.358600616455078\n",
"\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[38], line 32\u001b[0m\n\u001b[1;32m 30\u001b[0m losses \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, enc \u001b[38;5;129;01min\u001b[39;00m variations\u001b[38;5;241m.\u001b[39mitems():\n\u001b[0;32m---> 32\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloss\u001b[49m\u001b[43m(\u001b[49m\u001b[43menc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 33\u001b[0m mae \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mlinalg\u001b[38;5;241m.\u001b[39mvector_norm(m\u001b[38;5;241m.\u001b[39mdecode(enc)\u001b[38;5;241m-\u001b[39mbatch, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mmean()\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m should_log:\n",
"File \u001b[0;32m~/torchdrive/torchdrive/tasks/diff_traj.py:340\u001b[0m, in \u001b[0;36mXYMLPEncoder.loss\u001b[0;34m(self, predicted, target)\u001b[0m\n\u001b[1;32m 338\u001b[0m target \u001b[38;5;241m=\u001b[39m target\u001b[38;5;241m.\u001b[39mpermute(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m, \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 339\u001b[0m emb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdecoder(predicted)\n\u001b[0;32m--> 340\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloss\u001b[49m\u001b[43m(\u001b[49m\u001b[43memb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/torchdrive/torchdrive/models/path.py:92\u001b[0m, in \u001b[0;36mXYEncoder.loss\u001b[0;34m(self, predicted, target)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mloss\u001b[39m(\u001b[38;5;28mself\u001b[39m, predicted: torch\u001b[38;5;241m.\u001b[39mTensor, target: torch\u001b[38;5;241m.\u001b[39mTensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m torch\u001b[38;5;241m.\u001b[39mTensor:\n\u001b[1;32m 91\u001b[0m x, y \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msplit_xy_one_hot(predicted)\n\u001b[0;32m---> 92\u001b[0m xl, yl \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode_labels\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mcross_entropy(x, xl, reduction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m+\u001b[39m F\u001b[38;5;241m.\u001b[39mcross_entropy(\n\u001b[1;32m 94\u001b[0m y, yl, reduction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 95\u001b[0m )\n",
"File \u001b[0;32m~/torchdrive/torchdrive/models/path.py:45\u001b[0m, in \u001b[0;36mXYEncoder.encode_labels\u001b[0;34m(self, xy)\u001b[0m\n\u001b[1;32m 36\u001b[0m y \u001b[38;5;241m=\u001b[39m xy[:, \u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 38\u001b[0m x \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 39\u001b[0m (((x \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_dist)) \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m0.5\u001b[39m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_buckets)\n\u001b[1;32m 40\u001b[0m \u001b[38;5;241m.\u001b[39mlong()\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m.\u001b[39mclamp(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_buckets \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 42\u001b[0m )\n\u001b[1;32m 43\u001b[0m y \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 44\u001b[0m \u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[43my\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_dist\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.5\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_buckets\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m---> 45\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlong\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;241m.\u001b[39mclamp(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_buckets \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 47\u001b[0m )\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x, y\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"BATCHES = 2000\n",
"LOG_EVERY = 100\n",
"\n",
"for i in range(BATCHES):\n",
" should_log = i % LOG_EVERY == 0\n",
" optimizer.zero_grad()\n",
" \n",
" batch = (torch.rand(BS, 1, 2, device=device) - 0.5) * (2 * MAX_DIST)\n",
"\n",
" encoded = m(batch)\n",
"\n",
" encoded_mag = torch.linalg.vector_norm(encoded, dim=-1).mean()\n",
" # scale by 0-1\n",
" noise_scale = encoded_mag * torch.rand(BS, device=device)\n",
" noise_scale = noise_scale.unsqueeze(1).unsqueeze(1)\n",
" \n",
" noise = torch.randn_like(encoded)\n",
" noise = noise * noise_scale\n",
"\n",
" \n",
" if should_log:\n",
" print(f\"{i} - magnitude {encoded_mag}, scale {noise_scale.mean().item()}\")\n",
"\n",
" noise_encoded = encoded + noise\n",
" \n",
" variations = {\n",
" \"normal\": encoded,\n",
" \"noise\": noise_encoded,\n",
" }\n",
" losses = {}\n",
" for k, enc in variations.items():\n",
" loss = m.loss(enc, batch).mean()\n",
" mae = torch.linalg.vector_norm(m.decode(enc)-batch, dim=-1).mean()\n",
" if should_log:\n",
" print(f\"{i} - {k}: loss {loss.item()}, mae {mae.item()}\")\n",
"\n",
" losses[k] = loss\n",
"\n",
" total_loss = sum(losses.values())\n",
" total_loss.backward()\n",
"\n",
" optimizer.step()\n",
"\n",
" if should_log:\n",
" print(i, \"batch\", total_loss.item())\n",
" print()\n",
" \n",
" "
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "5a3cdaa7-2a73-4b7c-b388-35989e0ff56a",
"metadata": {},
"outputs": [],
"source": [
"from safetensors.torch import save_file\n",
"\n",
"save_file(m.state_dict(), \"xy_mlp_vae.safetensors\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torchdrive",
"language": "python",
"name": "torchdrive"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit eee60aa

Please sign in to comment.