diff --git a/notebooks/CTC_with_padding.ipynb b/notebooks/CTC_with_padding.ipynb new file mode 100644 index 00000000..9727f85d --- /dev/null +++ b/notebooks/CTC_with_padding.ipynb @@ -0,0 +1,391 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "CTC.ipynb", + "provenance": [], + "collapsed_sections": [], + "include_colab_link": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "1A6gJZcOVq-h", + "colab_type": "code", + "outputId": "0467398c-3ac6-437a-8a42-cff49d155d2f", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 53 + } + }, + "source": [ + "!pip install -qU git+https://github.com/harvardnlp/pytorch-struct\n", + "!pip install -qU git+https://github.com/harvardnlp/genbmm\n", + "!pip install -q matplotlib" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + " Building wheel for torch-struct (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for genbmm (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "r070BEdwVzHs", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import torch\n", + "import torch_struct\n", + "import matplotlib.pyplot as plt" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "zxRC_exrbTR4", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Character Vocab, P is a padding token\n", + "vocab = [\"a\", \"b\", \"c\", \"d\", \"e\", \"_\", \"P\"]\n", + "v_dict = { a:i for i, a in enumerate(vocab)}\n", + "L = len(vocab)\n", + "\n", + "# Char sequence\n", + "letters = [[\"a\", \"_\", \"b\", \"_\", \"c\", \"_\", \"d\", \"_\", \"e\"], \n", + " [\"a\", \"_\", \"b\", \"_\", \"c\", \"P\", \"P\", \"P\", \"P\"]]\n", + "#letters = [[\"a\", \"_\", \"b\", \"_\", \"c\"]]\n", + "\n", + "t = len(letters[0])\n", + "\n", + "# Padding\n", + "\n", + "# Frame sequence\n", + "frames = [[\"a\", \"a\", \"a\", \"_\", \"b\", \"b\", \"c\", \"c\", \"c\", \"c\", \"_\", \"_\", \"d\", \"e\"], \n", + " [\"a\", \"a\", \"a\", \"a\", \"a\", \"_\", \"b\", \"b\", \"c\", \"c\", \"c\", \"c\", \"P\", \"P\"]] \n", + "#frames = [[\"a\", \"a\", \"a\", \"a\", \"a\", \"_\", \"b\", \"b\", \"c\", \"c\", \"c\", \"c\"]] \n", + "\n", + "# Constants\n", + "T, B = len(frames[0]), len(frames)\n", + "D1, MATCH, D2 = 0, 1, 2\n", + "\n", + "batch_lengths = [(t, T), (t-4, T-2)]\n", + "#batch_lengths = [(t, T)]\n", + "\n", + "def show(m, ex):\n", + " plt.yticks(torch.arange(len(letters[ex])), letters[ex])\n", + " plt.xticks(torch.arange(T), [str(frames[ex][x.item()]) for x in torch.arange(T)])\n", + " plt.imshow(m[ex].cpu().detach())" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "rDtLyNKPa1N0", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Gold alignment. \n", + "gold = torch.zeros(B, t).long()\n", + "for b in range(B):\n", + " for i, l in enumerate(letters[0]):\n", + " gold[b, i] = v_dict[l]\n", + "gold = gold[:, None, :].expand(B, T, t)\n", + "\n", + "# Inputs (boost true frames a bit)\n", + "logits = torch.zeros(B, T, L)\n", + "for b in range(B):\n", + " for i in range(T):\n", + " logits[b, i, v_dict[frames[b][i]]] += 1 " + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "il4pW6M9YOKP", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Construct the alignment problem from CTC\n", + "\n", + "# Log-Potentials\n", + "log_potentials = torch.zeros(B, T, t, 3).fill_(-1e5)\n", + "\n", + "# Match gold to logits. \n", + "match = torch.gather(logits, 2, gold)\n", + "\n", + "# CTC Rules:\n", + "for b, (lb, la) in zip(range(B), batch_lengths):\n", + " # la and lb are the sizes of the two (without padding)\n", + "\n", + " # Never allowed to fully skip regular characters (little t)\n", + " log_potentials[b, :la, :lb:2, D2] = -1e5\n", + "\n", + " # Free to skip _ characters (little t)\n", + " log_potentials[b, :la, 1:lb:2, D2] = 0\n", + "\n", + " # First match with character is the logit. \n", + " log_potentials[b, :la, :lb, MATCH] = match[b, :la, :lb]\n", + "\n", + " # Additional match with character is the logit.\n", + " log_potentials[b, :la, :lb, D1] = match[b, :la, :lb]\n", + "\n", + " # Match padding in an L shape\n", + " log_potentials[b, la:, lb-1, D1] = 0\n", + " log_potentials[b, -1, lb:, D2] = 0\n", + "\n", + "\n", + "log_potentials = log_potentials.transpose(1, 2).cuda()" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "Hz5I8cnLfpHE", + "colab_type": "code", + "outputId": "cd673207-eb68-4964-be96-3886800ade75", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 263 + } + }, + "source": [ + "# Show input scores\n", + "show(match.transpose(1,2).exp(), 1)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAKH0lEQVR4nO3dz4udh3XH4e9pRlgSxiGyvREo9cKh\nLjFUi1nEbRYhDgj/AdlkEZcuBreBrLrsNpBVlw0MKBCSrOxCuougDg1KILQTIrs2Il41DWgT/4C2\nSMhGnC7m2hkmE89Vcu/MmZnnAaMR9+pwbA8fXr3z3vet7g4Ac/3JcS8AwMcTaoDhhBpgOKEGGE6o\nAYYTaoDhNlY98IlLn+inrpxb9ViGeOv1i8e9ApxK/5v33u7uJw96beWhfurKufz7jSurHssQ1y5f\nPe4V4FT6137lV7/vNac+AIYTaoDhhBpgOKEGGE6oAYYTaoDhhBpgOKEGGG6pUFfVD6rq51X1ZlVt\nrXspAH5r2U8m/k13v1tVF5L8R1X9c3e/s87FANi17KmPr1fVa0l+luRKks/sfbGqtqpqp6p2fvPO\ng1XvCHCmHRrqqvpCki8lea67/yLJL5Kc3/ue7t7u7s3u3nzy8U+sZVGAs2qZI+pPJnmvu+9W1TNJ\nPrfmnQDYY5lQ/zDJRlXdTvLN7J7+AOCIHPrDxO6+n+SFI9gFgAO4jhpgOKEGGE6oAYYTaoDhhBpg\nOKEGGG7lTyF/6/WLnlR9it24c2ut833vwO9yRA0wnFADDCfUAMMJNcBwQg0wnFADDCfUAMMJNcBw\nQg0wnFADDCfUAMOt5F4fVbWVZCtJzufiKkYCsLCSI+ru3u7uze7ePJdHVjESgAWnPgCGE2qA4YQa\nYDihBhhOqAGGO/TyvKp6PMmrB7z0fHe/s/qVANjr0FAvYuxBdgDHxKkPgOGEGmA4oQYYbiX3+uDs\nuHZ5vT+uuHHn1lrnr9O6/9twdjmiBhhOqAGGE2qA4YQaYDihBhhOqAGGE2qA4YQaYLhDQ11VT1XV\nG0exDAC/yxE1wHDLhnqjqr5fVber6pWqurjWrQD4yLKh/rMk/9Tdf57kf5L83d4Xq2qrqnaqaueD\n3F/1jgBn2rKh/nV3/3Tx9feSfH7vi9293d2b3b15Lo+sdEGAs27ZUPchvwdgTZYN9aer6rnF119J\n8pM17QPAPsuG+pdJvlZVt5N8Ksm31rcSAHst83Db/0ryzPpXAeAgrqMGGE6oAYYTaoDhhBpgOKEG\nGE6oAYY79PI82OvGnVtrnX/t8tW1zoeTyBE1wHBCDTCcUAMMJ9QAwwk1wHBCDTCcUAMMJ9QAwwk1\nwHBCDTCcUAMMt5J7fVTVVpKtJDmfi6sYCcDCSo6ou3u7uze7e/NcHlnFSAAWnPoAGE6oAYYTaoDh\nhBpguEOv+qiqx5O8esBLz3f3O6tfCYC9Dg31IsaejwRwTJz6ABhOqAGGE2qA4YQaYLiV3OuDs+Pa\nZT9X/n1u3Ll13Cucaaf5e9MRNcBwQg0wnFADDCfUAMMJNcBwQg0wnFADDCfUAMMJNcBwQg0w3FKh\nrqqvVtXrVfVaVX133UsB8FvLPOHls0n+IclfdvfbVXXpgPdsJdlKkvO5uPIlAc6yZY6ov5jk5e5+\nO0m6+939b+ju7e7e7O7Nc3lk1TsCnGnOUQMMt0yof5Tky4uH3OagUx8ArM8yD7d9s6q+keTHVfUg\nyS+S/PW6FwNg11IPDuju7yT5zpp3AeAAzlEDDCfUAMMJNcBwQg0wnFADDCfUAMMtdXkefOjGnVtr\nnX/t8tW1zl+nk7w7szmiBhhOqAGGE2qA4YQaYDihBhhOqAGGE2qA4YQaYLhlHm77IMl/Lt57O8mL\n3X133YsBsGuZI+p73X21u59N8n6Sl9a8EwB7POypj5tJnl7HIgAcbOlQV9VGkheyexpk/2tbVbVT\nVTsf5P4q9wM485a5KdOFqvrwTjw3k1zf/4bu3k6ynSSP1aVe3XoALBPqe93ttmAAx8TleQDDCTXA\ncIeGursfPYpFADiYI2qA4YQaYDihBhhOqAGGE2qA4YQaYLhlPpkIH7l22YdU4ag5ogYYTqgBhhNq\ngOGEGmA4oQYYTqgBhhNqgOGEGmC4Qz/wUlUPsvtA240kt5O82N13170YALuWOaK+191Xu/vZJO8n\neWnNOwGwx8Oe+riZ5Ol1LALAwZYOdVVtJHkhu6dB9r+2VVU7VbXzQe6vcj+AM2+ZmzJdqKpbi69v\nJrm+/w3dvZ1kO0keq0u9uvUAWCbU97rbLdMAjonL8wCGE2qA4Q4NdXc/ehSLAHAwR9QAwwk1wHBC\nDTCcUAMMJ9QAwwk1wHDLfDIRPnLjzq3D3/RHuHbZh2BhP0fUAMMJNcBwQg0wnFADDCfUAMMJNcBw\nQg0wnFADDHfoB16q6kF2H2i7keR2khe7++66FwNg1zJH1Pe6+2p3P5vk/SQvrXknAPZ42FMfN5M8\nvY5FADjY0qGuqo0kL2T3NMj+17aqaqeqdj7I/VXuB3DmLXNTpgtV9eGdeG4mub7/Dd29nWQ7SR6r\nS7269QBYJtT3utstzQCOicvzAIYTaoDhDg11dz96FIsAcDBH1ADDCTXAcEINMJxQAwwn1ADDCTXA\ncMt8MhE+cu2yD6nCUXNEDTCcUAMMJ9QAwwk1wHBCDTCcUAMMJ9QAwx16HXVVPcjucxI3ktxO8mJ3\n3133YgDsWuaI+l53X+3uZ5O8n+SlNe8EwB4Pe+rjZpKn17EIAAdbOtRVtZHkheyeBgHgiCxzr48L\nVXVr8fXNJNf3v6GqtpJsJcn5XFzddgAsFep73f2xd+Lp7u0k20nyWF3qVSwGwC6X5wEMJ9QAwx0a\n6u5+9CgWAeBgjqgBhhNqgOGEGmA4oQYYTqgBhhNqgOGEGmC46l7tJ76r6jdJfvUQf+SJJG+vdAnz\nzTf/NMw/ybv/IfP/tLufPOiFlYf6YVXVTndvmm+++eYf1eyTNt+pD4DhhBpguAmh3jbffPPNP+LZ\nJ2r+sZ+jBuDjTTiiBuBjCDUnUlU9VVVvnNT5HJ+qelBVt6rqjap6uapW+vzAdcwXauCsudfdV7v7\n2STvJ3lp+vxjDXVV/aCqfl5Vby4ekGv+KZp/BDaq6vtVdbuqXln1kdE651fVV6vq9ap6raq+u6q5\n5j+0m0meHj+/u4/tnySXFr9eSPJGksfNPz3z1/y981SSTvJXi99/O8nfn4T5ST6b5K0kT+z9/7DC\n3c3/+Pn/t/h1I8m/JPnb6fOP+9TH16vqtSQ/S3IlyWfMP1Xz1+3X3f3TxdffS/L5EzL/i0le7u63\nk6S7313RXPOXc6GqbiXZSfLfSa5Pn7/xR6/0B6qqLyT5UpLnuvtuVf1bkvPmn475R2T/taWrvtZ0\n3fM5Hve6++pJmn+cR9SfTPLeIhLPJPmc+adq/lH4dFU9t/j6K0l+ckLm/yjJl6vq8SSpqksrmmv+\nKXWcof5hdn9YczvJN7P712/zT8/8o/DLJF9b/Dt8Ksm3TsL87n4zyTeS/Hhx6ukfVzHX/NPLJxMB\nhju2c9ScXou/sr56wEvPd/c7R70PnHSOqAGGO+7L8wA4hFADDCfUAMMJNcBwQg0wnFADDPf/KfFi\nycGaaRoAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "psn-rlomc3CY", + "colab_type": "code", + "outputId": "c3c01c84-dd3f-4d87-f7a4-3a1292080b13", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 263 + } + }, + "source": [ + "show(log_potentials[:, :, :, 1].exp(), 1)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAKIUlEQVR4nO3dz6vld33H8de7cwMZk0YySaCb2Cwi\nTTGQu7igsS7ECCV/gBsXpnRxSSO4KmTTrWA2BTcRLqQg6iop2J0FI5XRIu0NztgMg13VupGaH9A2\nc0nC8O7iHurteJp7xpxz73vufTwgzBnOyZt3ksmT73zu9863ujsAzPV7p70AAB9MqAGGE2qA4YQa\nYDihBhhOqAGG21r3wAcvXehHHr5r3WMZ4tqvHjrtFWCpu/7jndNe4UP5r7z9Rncv/R9s7aF+5OG7\n8k9///C6xzLEEy88d9orwFJ/8PV/PO0VPpTv9yu/+P/ec/QBMJxQAwwn1ADDCTXAcEINMJxQAwwn\n1ADDCTXAcCuFuqq+W1WvVdW1qtrd9FIA/Maq35n45939VlVdTPLPVfW33f3mJhcD4NCqRx9fqaqr\nSX6S5OEkHz/6ZlXtVtV+Ve3/+s2b694R4Fw7NtRV9dkkn0/yZHc/keSnSe4++pnu3uvune7eeeiB\nCxtZFOC8WuWK+qNJ3u7uG1X1WJJPbXgnAI5YJdTfS7JVVdeTfC2Hxx8AnJBjv5jY3e8mefoEdgFg\nCfdRAwwn1ADDCTXAcEINMJxQAwwn1ADDrf0p5Nd+9ZAnVZ9hV59/caPz/dqB3+aKGmA4oQYYTqgB\nhhNqgOGEGmA4oQYYTqgBhhNqgOGEGmA4oQYYTqgBhltLqKtqt6r2q2r/5sE76xgJwMJaQt3de929\n0907Fy7es46RACw4+gAYTqgBhhNqgOGEGmA4oQYY7thHcVXVA0leXfLWU9395vpXAuCoY0O9iPH2\nCewCwBKOPgCGE2qA4YQaYLhjz6jhqCdeeG6j868+/+JG52/Spv/dcH65ogYYTqgBhhNqgOGEGmA4\noQYYTqgBhhNqgOGEGmC4Y0NdVY9U1esnsQwAv80VNcBwq4Z6q6q+U1XXq+qVqvrIRrcC4H+tGuo/\nSvJid/9xkv9M8n/+UIOq2q2q/arav3nwzrp3BDjXVg31L7v7x4vX307ymaNvdvded+90986Fi/es\ndUGA827VUPcxPwdgQ1YN9ceq6snF6y8m+dGG9gHgFquG+udJvlxV15Pcn+Qbm1sJgKNWebjtvyV5\nbPOrALCM+6gBhhNqgOGEGmA4oQYYTqgBhhNqgOGOvT0Pjrr6/Isbnf/EC88d/yE4Z1xRAwwn1ADD\nCTXAcEINMJxQAwwn1ADDCTXAcEINMJxQAwwn1ADDCTXAcGsJdVXtVtV+Ve3fPHhnHSMBWFhLqLt7\nr7t3unvnwsV71jESgAVHHwDDCTXAcEINMJxQAwx37BNequqBJK8ueeup7n5z/SsBcNSxoV7EePsE\ndgFgCUcfAMMJNcBwQg0wnFADDHfsFxPhqCdeeO60Vxjr6vMvnvYK59qffv3s3vPgihpgOKEGGE6o\nAYYTaoDhhBpgOKEGGE6oAYYTaoDhhBpgOKEGGG6lUFfVl6rqZ1V1taq+temlAPiNVZ7w8okkf5Xk\n0939RlVdWvKZ3SS7SXLX79+/9iUBzrNVrqg/l+Tl7n4jSbr7rVs/0N173b3T3TsXLt6z7h0BzjVn\n1ADDrRLqHyT5wuIht1l29AHA5qzycNtrVfXVJD+sqptJfprkzza9GACHVnpwQHd/M8k3N7wLAEs4\nowYYTqgBhhNqgOGEGmA4oQYYTqgBhqvuXuvA++pSf7KeWutMgLPu+/3Ka929s+w9V9QAwwk1wHBC\nDTCcUAMMJ9QAwwk1wHBCDTCcUAMMt8rDbW8m+ZfFZ68neaa7b2x6MQAOrXJFfdDd2939eJL3kjy7\n4Z0AOOJ2jz4uJ3l0E4sAsNzKoa6qrSRP5/AY5Nb3dqtqv6r238+769wP4Nxb5ZmJF6vqyuL15SQv\n3fqB7t5Lspcc/qFM61sPgFVCfdDd2xvfBICl3J4HMJxQAwx3bKi7+96TWASA5VxRAwwn1ADDCTXA\ncEINMJxQAwwn1ADDCTXAcEINMJxQAwwn1ADDCTXAcEINMJxQAwwn1ADDCTXAcMc+iquqbubwgbZb\nSa4neaa7b2x6MQAOrXJFfdDd2939eJL3kjy74Z0AOOJ2jz4uJ3l0E4sAsNzKoa6qrSRP5/AY5Nb3\ndqtqv6r238+769wP4Nw79ow6ycWqurJ4fTnJS7d+oLv3kuwlyX11qde3HgCrhPqgu7c3vgkAS7k9\nD2A4oQYY7thQd/e9J7EIAMu5ogYYTqgBhhNqgOGEGmA4oQYYTqgBhhNqgOGEGmA4oQYYTqgBhhNq\ngOGEGmA4oQYYTqgBhhNqgOGOfRRXVd3M4QNtt5JcT/JMd9/Y9GIAHFrlivqgu7e7+/Ek7yV5dsM7\nAXDE7R59XE7y6CYWAWC5lUNdVVtJns7hMcit7+1W1X5V7b+fd9e5H8C5d+wZdZKLVXVl8fpykpdu\n/UB37yXZS5L76lKvbz0AVgn1QXdvb3wTAJZyex7AcEINMNyxoe7ue09iEQCWc0UNMJxQAwwn1ADD\nCTXAcEINMJxQAwwn1ADDCTXAcEINMJxQAwwn1ADDCTXAcEINMJxQAwx37BNequpmDp+TuJXkepJn\nuvvGphcD4NAqV9QH3b3d3Y8neS/JsxveCYAjbvfo43KSRzexCADLrRzqqtpK8nQOj0EAOCGrPIX8\nYlVdWby+nOSlWz9QVbtJdpPk7nxkfdsBsFKoD7p7+4M+0N17SfaS5L661OtYDIBDbs8DGE6oAYY7\nNtTdfe9JLALAcq6oAYYTaoDhhBpgOKEGGE6oAYYTaoDhhBpguOpe73d8V9Wvk/ziNv6WB5O8sdYl\nzDff/LMw/07e/XeZ/4fd/dCyN9Ye6ttVVfvdvWO++eabf1Kz77T5jj4AhhNqgOEmhHrPfPPNN/+E\nZ99R80/9jBqADzbhihqADyDU3JGq6pGqev1Onc/pqaqbVXWlql6vqperaq3PD9zEfKEGzpuD7t7u\n7seTvJfk2enzTzXUVfXdqnqtqq4tHpBr/hmafwK2quo7VXW9ql5Z95XRJudX1Zeq6mdVdbWqvrWu\nuebftstJHh0/v7tP7a8klxY/XkzyepIHzD878zf8a+eRJJ3kTxY//5skf3knzE/yiST/muTBo/8d\n1ri7+R88/78XP24l+bskfzF9/mkffXylqq4m+UmSh5N83PwzNX/TftndP168/naSz9wh8z+X5OXu\nfiNJuvutNc01fzUXq+pKkv0k/57kpenztz70Sr+jqvpsks8nebK7b1TVPyS52/yzMf+E3Hpv6brv\nNd30fE7HQXdv30nzT/OK+qNJ3l5E4rEknzL/TM0/CR+rqicXr7+Y5Ed3yPwfJPlCVT2QJFV1aU1z\nzT+jTjPU38vhF2uuJ/laDn/7bf7ZmX8Sfp7ky4t/hvuTfONOmN/d15J8NckPF0dPf72OueafXb4z\nEWC4Uzuj5uxa/Jb11SVvPdXdb570PnCnc0UNMNxp354HwDGEGmA4oQYYTqgBhhNqgOGEGmC4/wEv\nD3ULKXWGYgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "xqDiA1mqfu6Q", + "colab_type": "code", + "outputId": "e2c16154-1003-40c4-b3b0-2be9b79d16ee", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 263 + } + }, + "source": [ + "# Find best alignment\n", + "dist = torch_struct.AlignmentCRF(log_potentials)\n", + "show(dist.argmax, 0)\n" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAJvUlEQVR4nO3dQailZR3H8d+/JmgEkxyTDLIhEBMX\nBc4iK0GcNi0Cg6KwsGhxF4WubCG6cBO0amnkIghzIRmKLYpiosRAmtEcUwelTbQTp6CF0CKeFveo\n1/HcuceZ973nf+58PjB4r/fMn2fOzHw9vue571NjjADQ1/vWvQAAzk+oAZoTaoDmhBqgOaEGaE6o\nAZo7NMfQuqpGjs4xedtNz843ez9s+PKBmYwxatm/rzn2UdexGjk1+di3LP+lbI4NXz4wk91C7dIH\nQHNCDdCcUAM0J9QAzQk1QHNCDdCcUAM0J9QAza0U6qp6oqqeraqXqmpr7kUB8LZVv4X8u2OMf1XV\n4SQnq+pXY4yzcy4MgG2rhvruqvrK4uOPJ7kuyTtCvXilvf1q+9qplgfAnpc+qurWJF9McvMY49NJ\n/prkg+c+bozx0Bjj2BjjWD4y+ToBLlmrXKO+Ism/xxhvVNWnknx25jUBsMMqof5tkkNVdSbJj5I8\nM++SANhpz2vUY4z/JvnSPqwFgCXsowZoTqgBmhNqgOaEGqA5oQZoTqgBmpvnFPKq6YceJHM/O445\nh43kFHKADSXUAM0JNUBzQg3QnFADNCfUAM0JNUBzQg3QnFADNCfUAM0JNUBzex7Ftaqq2kqyNdU8\nALa5KdM6uCkTsISbMgFsKKEGaE6oAZoTaoDmhBqguT2351XVkSQnlnzp+Bjj7PRLAmAn2/PWwfY8\nYAnb8wA2lFADNCfUAM1Ndq8P3oO5ryHPfA18+VW06bjEDu/kFTVAc0IN0JxQAzQn1ADNCTVAc0IN\n0JxQAzQn1ADN7RnqqjpaVS/ux2IAeDevqAGaWzXUh6rqkao6U1WPVdVls64KgLesGurrkzw4xrgh\nyX+SfO/cB1TVVlWdqqpTUy4Q4FK358EBVXU0yVNjjGsXn9+W5O4xxu3n+TkODlgnN2WCjXSxBwec\n+1dfiAH2yaqhvraqbl58fEeSp2daDwDnWDXUryT5flWdSfLhJD+Zb0kA7ORw24PINWrYSA63BdhQ\nQg3QnFADNCfUAM0JNUBzQg3Q3KF1L4AZzL59bu7dlzbowU5eUQM0J9QAzQk1QHNCDdCcUAM0J9QA\nzQk1QHNCDdCcUAM0J9QAzQk1QHOT3eujqraSbE01D4BtzkzkArgpE8zBmYkAG0qoAZoTaoDmhBqg\nuT13fVTVkSQnlnzp+Bjj7PRLAmAnuz64AHZ9wBzs+gDYUEIN0JxQAzQn1ADNTXavDy4lM7/ZN+N7\nlcvfquFNnp6evKIGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZpbOdRVdWdVvVBV\np6vq4TkXBcDbVrrXR1XdmOT+JJ8bY7xeVVcuecxWkq2J1wdwyVvphJequivJR8cY96001AkvXAw3\nZVobT896OeEFYEOtGuo/JPna4qDbLLv0AcA8Vj7ctqq+neQHSf6X5K9jjO+c57EufXDhXPpYG0/P\neu126cMp5PQj1Gvj6Vkv16gBNpRQAzQn1ADNCTVAc0IN0JxQAzS30r0+YF/NuEfM9rM9zLyxdu7t\nkQf199craoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaE6oAZoTaoDmhBqgOaEGaG6ye31U1VaSranm\nAbDNmYnA29yUaa2cmQiwoYQaoDmhBmhOqAGa23PXR1UdSXJiyZeOjzHOTr8kAHay6wN4m10fa2XX\nB8CGEmqA5oQaoDmhBmhusnt9AAfApr/ZN/c2hjW9W+kVNUBzQg3QnFADNCfUAM0JNUBzQg3QnFAD\nNCfUAM2951BX1QNVdc8ciwHg3byiBmhupVBX1X1V9WpVPZ3k+pnXBMAOq5zwclOSbyT5zOLxzyV5\ndsnjtpJsTb1AgEvdKjdluiXJ42OMN5Kkqp5c9qAxxkNJHlo8xgkvABNxjRqguVVC/VSS26vqcFVd\nnuTLM68JgB32vPQxxniuqh5NcjrJa0lOzr4qAN7iFHLg4NjwgwOcQg6woYQaoDmhBmhOqAGaE2qA\n5oQaoLlVvoUcYDPMvjN45v15u/CKGqA5oQZoTqgBmhNqgOaEGqA5oQZoTqgBmhNqgOaEGqA5oQZo\nTqgBmpvsXh9VtZVka6p5AGxzZiJwgGz2TZmcmQiwoYQaoDmhBmhOqAGa23PXR1UdSXJiyZeOjzHO\nTr8kAHay6wM4QOz6AGANhBqgOaEGaE6oAZqb7F4fAOs375t9s75XeWz3L3lFDdCcUAM0J9QAzQk1\nQHNCDdCcUAM0J9QAzQk1QHNCDdDcSqGuqm9V1V+q6vmq+mlVvX/uhQGwbc9QV9UNSb6e5PNjjM8k\n+V+Sb869MAC2rXKvj+NJbkpysqqS5HCS1859UFVtJdmadHUA7H3CS1XdleRjY4x7Vx7qhBfgIJr5\npkzj1IWf8HIiyVer6uokqaorq+oTU64PgN3tGeoxxstJ7k/yu6p6Icnvk1wz98IA2OZwW4BVNb70\nAcAaCTVAc0IN0JxQAzQn1ADNCTVAc0IN0Nwq9/q4EK8n+cd7ePxVi58zhzlnm2+++ZfS/KW7nCea\nnez6Hd+zfMPLe1VVp8YYxzZttvnmm2/+fsx26QOgOaEGaK5LqB/a0Nnmm2+++bPPbnGNGoDddXlF\nDcAuhPoSVVVHq+rFTZ0PU6mqB6rqnnWv43yEGqC5tYa6qp6oqmer6qXF4bjm769DVfVIVZ2pqseq\n6rJNml9Vd1bVC1V1uqoennK2+QdbVd1XVa9W1dNJrp9h/req6i9V9XxV/bSq3n9RA8cYa/uR5MrF\nPw8neTHJEfP37bk/mu3zKj6/+PxnSe7ZoPk3Jnk1yVU7fy/M35/5m/wjyU1J/pbksiQfSvL3if9s\n3pDk10k+sPj8wSR3XszMdV/6uLuqTid5JsnHk1xn/r765xjjz4uPf5HkCxs0/7YkvxxjvJ4kY4x/\nTTjb/IPtliSPjzHeGGP8J8mTE88/nu3/GJysqucXn3/yYgbOda+PPVXVrUm+mOTmMcYbVfXHJB80\nf1+duzdz6r2ac8+HjirJz8cY9041cJ2vqK9I8u9F5D6V5LPm77trq+rmxcd3JHl6g+b/IcnXqupI\nklTVlRPONv9geyrJ7VV1uKouT/LlieefSPLVqro62X7uq2rXGy6tYp2h/m2232w6k+RH2b58YP7+\neiXJ9xe/hg8n+cmmzB9jvJTkh0n+tLj89OOpZpt/sI0xnkvyaJLTSX6T5OTE819Ocn+S31XVC0l+\nn+Sai5npOxMBmlvbNWrOb/G/rCeWfOn4GOPsfq8H3jT3n01/9t/NK2qA5ta9PQ+APQg1QHNCDdCc\nUAM0J9QAzQk1QHP/BwKWOQXFkW6QAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "RdOaPe6rfU5k", + "colab_type": "code", + "outputId": "2db5d26f-7bff-4d48-c08a-cf701fd31451", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 263 + } + }, + "source": [ + "show(dist.argmax, 1)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAJw0lEQVR4nO3dMYhl53nH4f8bprCMwFhSII2VLdZE\nwQtZyBRW4sJYbrZM4caFFVIMm5i4UbqkNbhKGYPAaWxXdsDu3MjErA2GzGDJkRicKo5byYZgdsmC\neFPMRTte3517x75n5p3Z5wGxI87Rq0874sfZb849p7o7AMz1B5e9AADOJtQAwwk1wHBCDTCcUAMM\nJ9QAw+3temC9UJ0bu576yJ8fLTebzfz2w2Le7e4/XHdg56HOjSSHO5/6gcNabjab+e2Hxfz8SQds\nfQAMJ9QAwwk1wHBCDTCcUAMMJ9QAwwk1wHBCDTDcVqGuqu9U1VFVvVNVB0svCoBHtv1k4t909y+r\n6pkk/1FV/9bd7y25MABObBvqL1XVX62+/liSjyf5INSrq+yTK+0Xd7k8ADZufVTVp5N8NsnL3f1n\nSX6S5EOnz+nu17t7v7v3s/aRIgD8rrbZo/5Ikl919/2qeinJJxdeEwCnbBPq7yXZq6rjJF9J8uNl\nlwTAaRv3qLv7/5LcuYC1ALCG+6gBhhNqgOGEGmA4oQYYTqgBhhNqgOF2/xbyoyz6qmpvwb5kvfB8\n32D4La6oAYYTaoDhhBpgOKEGGE6oAYYTaoDhhBpgOKEGGE6oAYYTaoDhhBpguJ0866OqDpIc7GIW\nAL+punf7lJ2qWvqxPVwmD2WCpRx19/66A7Y+AIYTaoDhhBpgOKEGGE6oAYbbeHteVT2f5I01h17p\n7vd2vyQATtsY6lWMb1/AWgBYw9YHwHBCDTCcUAMMt5NnffAUWfoj3gt+RL0XXrtPv7MUV9QAwwk1\nwHBCDTCcUAMMJ9QAwwk1wHBCDTCcUAMMtzHUVXWjqt6+iMUA8NtcUQMMt22o96rqm1V1XFXfrqoP\nL7oqAD6wbaj/JMm/dPefJvnfJH93+mBVHVTVYVUd7nqBAE+7bUP9i+7+0errbyT51OmD3f16d+93\n9/5OVwfA1qF+/JlmCz7jDIDTtg31i1X18urrzyf54ULrAeAx24b6Z0m+WFXHST6a5KvLLQmA07Z5\nue1/J3lp+aUAsI77qAGGE2qA4YQaYDihBhhOqAGGE2qA4TbengcXqpYcvfQHahdcPE81V9QAwwk1\nwHBCDTCcUAMMJ9QAwwk1wHBCDTCcUAMMJ9QAwwk1wHBCDTDcTp71UVUHSQ52MQuA31Tdu31QTVUt\n/eQb+B15KBOjHXX3/roDtj4AhhNqgOGEGmA4oQYYbuNdH1X1fJI31hx6pbvf2/2SADhtY6hXMb59\nAWsBYA1bHwDDCTXAcEINMJxQAwy3k2d9wNWw8Ee8F/6EevuE+rV21rfXFTXAcEINMJxQAwwn1ADD\nCTXAcEINMJxQAwwn1ADDCTXAcEINMNxWoa6qL1TVT6vqrar6+tKLAuCRbd7w8okk/5TkL7r73ap6\nbs05B0kOFlgfwFOvus9+kkxV/X2SP+ruf9xqYNXCj6aBoTyUid9DJUfdvb/umD1qgOG2CfX3k3xu\n9ZLbrNv6AGA527zc9p2q+nKSH1TV+0l+kuSvl14YACc27lGfe6A9ap5W9qj5PdijBrjChBpgOKEG\nGE6oAYYTaoDhhBpguI33UQNbWvj2OXfnXbal7zx+8nfYFTXAcEINMJxQAwwn1ADDCTXAcEINMJxQ\nAwwn1ADDbfNy2/eT/Ofq3OMkr3b3/aUXBsCJba6oH3T37e6+leRhkrsLrwmAU8679XEvyc0lFgLA\neluHuqr2ktzJyTbI48cOquqwqg53uTgAtnhn4qk96uTkivq17n54xvnemQhcQ4s/lOmJ70zc5ul5\nD7r79o5XBMCW3J4HMJxQAwy3MdTd/exFLASA9VxRAwwn1ADDCTXAcEINMJxQAwwn1ADDCTXAcEIN\nMJxQAwwn1ADDCTXAcEINMJxQAwwn1ADDCTXAcBtfxXXqnYl7SY6TvNrd95deGAAntrmiftDdt7v7\nVpKHSe4uvCYATjnv1se9JDeXWAgA620d6qraS3InJ9sgjx87qKrDqjrc5eIASKq7zz7h0R51cnJF\n/Vp3Pzzj/LMHAlxJS6etjrp7f92RjT9MzGqPescrAmBLbs8DGE6oAYbbGOrufvYiFgLAeq6oAYYT\naoDhhBpgOKEGGE6oAYYTaoDhhBpgOKEGGE6oAYYTaoDhhBpgOKEGGE6oAYYTaoDhhBpguI2v4jr1\nzsS9JMdJXu3u+0svDIAT21xRP+ju2919K8nDJHcXXhMAp5x36+NekptLLASA9bYOdVXtJbmTk22Q\nx48dVNVhVR3ucnEAJNXdZ5/waI86Obmifq27H55x/tkDAa6kpdNWR929v+7Ixh8mZrVHveMVAbAl\nt+cBDCfUAMNtDHV3P3sRCwFgPVfUAMMJNcBwQg0wnFADDCfUAMMJNcBwQg0wnFADDCfUAMMJNcBw\nQg0wnFADDCfUAMMJNcBwG9/wcupVXHtJjpO82t33l14YACe2uaJ+0N23u/tWkodJ7i68JgBOOe/W\nx70kN5dYCADrbR3qqtpLcieP3kgOwAXY5i3kz1TVm6uv7yX52uMnVNVBkoNdLgyAE9XdZ59Q9evz\nvDexqs4eCHAlLZ22Ouru/XVH3J4HMJxQAwy3MdTn2fYAYPdcUQMMJ9QAwwk1wHBCDTCcUAMMJ9QA\nwwk1wHDbPOvjvN5N8vNznP/C6p9Zivnmm3815w9bey08P3/8xH/zpmd9LK2qDp/0+XbzzTf/6Z1/\nlde+6/m2PgCGE2qA4SaE+nXzzTff/AuefaXmX/oeNQBnm3BFDcAZhJorqapuVNXbV3U+l6eq3q+q\nN6vq7ar6VlV9ePp8oQaeNg+6+3Z330ryMMnd6fMvNdRV9Z2qOqqqd1YvyDX/Gs2/AHtV9c2qOq6q\nb+/6ymjJ+VX1har6aVW9VVVf39Vc88/tXpKb4+d396X9leS51a/PJHk7yfPmX5/5C/+/cyMnbxv9\ny9Xf/2uSf7gK85N8Isl/JXnh9Pdhh2s3/+z5v179upfku0n+dvr8y976+FJVvZXkx0k+luTj5l+r\n+Uv7RXf/aPX1N5J86orM/0ySb3X3u0nS3b/c0Vzzt/NMVb2Z5DDJ/yT52vT5SzzrYytV9ekkn03y\ncnffr6p/T/Ih86/H/Avy+L2lu77XdOn5XI4H3X37Ks2/zCvqjyT51SoSLyX5pPnXav5FeLGqXl59\n/fkkP7wi87+f5HNV9XySVNVzO5pr/jV1maH+Xk5+WHOc5Cs5+eO3+ddn/kX4WZIvrv4bPprkq1dh\nfne/k+TLSX6w2nr6513MNf/68slEgOEubY+a62v1R9Y31hx6pbvfu+j1wFXnihpguMu+PQ+ADYQa\nYDihBhhOqAGGE2qA4YQaYLj/B4z+V3jbsn5ZAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "wdMZ89-ehpxq", + "colab_type": "code", + "outputId": "2d836e19-2e0f-405d-a79e-58bc548622f9", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 263 + } + }, + "source": [ + "# Find marginals (see uncertainty from randomness)\n", + "show(dist.marginals, 0)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAALKklEQVR4nO3dX6jed30H8PfHJLMpRmfSBgT7Zxtd\nW9yfjubCTh1ivNlFwYHicFLGYGcwwaveiB24i4JXu9iF4rkYlOpF0WFxsEklbkodYtqadq1di3Mb\njinatKxl3crafnZxTmuWPu15uvy+J9+TvF4Q8pznefLmk+Scd375nt/v+6vuDgDzesP5HgCA16ao\nASanqAEmp6gBJqeoASanqAEmt39EaF1WnatHJG/5lfvHZSfJM2Pj8++D858fnO+EThiju2vV8zXi\nPOo6Vp2Ti8e+7J8G/z/gb8fG508H5/94cP7/DM6Hi9WrFbWlD4DJKWqAySlqgMkpaoDJKWqAySlq\ngMkpaoDJKWqAya1V1FV1d1XdX1WPVNXG6KEA+Jl1LyH/g+5+sqoOJjlZVX/Z3adHDgbAlnWL+uNV\n9Tvbj69Ick2S/1PU20faW0fbVy41HgA7Ln1U1XuTvD/JTd3960m+m+SSs9/X3Zvdfay7j+XyxecE\nuGits0b9liRPdfezVXVdkncOngmAM6xT1F9Nsr+qHk3y6STfHjsSAGfacY26u59L8tu7MAsAKziP\nGmByihpgcooaYHKKGmByihpgcooaYHJj7kJe+zv5+cVzX3Lo0OFh2UnyC4d+aWj+b7194C3ak9zz\nwH8Mzf/n558fmj8yffnPdliOu5AD7FGKGmByihpgcooaYHKKGmByihpgcooaYHKKGmByihpgcooa\nYHKKGmByO96Ka11VtZFkY+sj/Q+wlMWKurs3k2wmL23KBMASHPoCTE5RA0xOUQNMTlEDTE5RA0xu\nx7M+qupIkhMrXjre3aeXHwmAM+1Y1NtlfMMuzALACpY+ACanqAEmp6gBJlfdy1/tXVVjLyE/MDQ9\n+988Nv/Sw2Pzr/jjsfk3b47N/+vvj8v+wfPjspPk2cGbJ7w4Np7zrLtr1fOOqAEmp6gBJqeoASan\nqAEmp6gBJqeoASanqAEmp6gBJrdjUVfV1VX18G4MA8ArOaIGmNy6Rb2/qr5QVY9W1Zeq6tKhUwHw\nsnWL+tokn+nu65M8neQVu0lU1UZV3VdV9y05IMDFbt2i/mF3f2v78eeTvPvsN3T3Zncf6+5ji00H\nwNpFffaeYIP3CAPgJesW9ZVVddP2448kuXfQPACcZd2ifizJx6rq0SRvTfLZcSMBcKZ1bm77L0mu\nGz8KAKs4jxpgcooaYHKKGmByihpgcooaYHKKGmBy1b38RYZVtcevXNzxrMVzcmBfjc3PNUPzj17+\na0PzD+37+2HZT53+6bDsJDn93HND8/+7Xxyav8e/cPe87l5ZDo6oASanqAEmp6gBJqeoASanqAEm\np6gBJqeoASanqAEmp6gBJqeoASanqAEmt9imFlW1kWRjqTwAtixW1N29mWQzuRA2ZQKYh6UPgMkp\naoDJKWqAySlqgMnt+M3EqjqS5MSKl4539+nlRwLgTDsW9XYZ37ALswCwgqUPgMkpaoDJKWqAySlq\ngMlV9/JXe7uE/DxbbGOAV4k/Mjb/6M3jsj/8o3HZSfLLj43N/7cfj82/4z/H5v9ocDO8MDZ+uO6u\nVc87ogaYnKIGmJyiBpicogaYnKIGmJyiBpicogaYnKIGmJyiBpicogaY3NpFXVW3VNVDVfVgVd05\ncigAfmatXSGq6h1Jbkvym939RFUdXvGejSQbC88HcNFbd/ue9yX5Ync/kSTd/eTZb+juzSSbiU2Z\nAJZkjRpgcusW9deTfGj7RrdZtfQBwBhrLX109yNVdXuSb1TVC0m+m+T3Rw4GwJa1t5jv7juS3DFw\nFgBWsEYNMDlFDTA5RQ0wOUUNMDlFDTA5RQ0wuepe/mpvl5Bf6Gpw/huHJf/cgbcMy06SQ2+8bGj+\n4YOXD83/1aOnhua/+INnhubf/V8vDM0frbtXfnE5ogaYnKIGmJyiBpicogaYnKIGmJyiBpicogaY\nnKIGmJyiBpicogaYnKIGmNzat+LaSVVtJNlYKg+ALYsVdXdvJtlMbMoEsCRLHwCTU9QAk1PUAJNT\n1ACT2/GbiVV1JMmJFS8d7+7Ty48EwJl2LOrtMr5hF2YBYAVLHwCTU9QAk1PUAJNT1ACTq+7lr/Z2\nCTnTWmzThNX2vWls/iWXj80/sG9s/huOjs2/7Kqx+Y/fOTa/u2vV846oASanqAEmp6gBJqeoASan\nqAEmp6gBJqeoASanqAEm97qLuqo+VVW3jhgGgFdyRA0wubWKuqo+WVWPV9W9Sa4dPBMAZ1jnDi83\nJvndbN08YH+SB5Lcv+J9G0k2lh4Q4GK3zhY170ny5e5+Nkmq6iur3tTdm0k2t99jUyaAhVijBpjc\nOkX9zSQfqKqDVXUoyc2DZwLgDOvc3PaBqroryYNJfpLk5PCpAHjZWtuod/ftSW4fPAsAK1ijBpic\nogaYnKIGmJyiBpicogaYnKIGmFx1L3+1t0vIuXitdcbr/9uBfW8amn/k0Nj5316XDM1//oU/Gpp/\n6uk/GZrf3bXqeUfUAJNT1ACTU9QAk1PUAJNT1ACTU9QAk1PUAJNT1ACTU9QAk1PUAJNT1ACTW+zC\n/qraSLKxVB4AWxYr6u7eTLKZ2JQJYEmWPgAmp6gBJqeoASanqAEmt+M3E6vqSJITK1463t2nlx8J\ngDPtWNTbZXzDLswCwAqWPgAmp6gBJqeoASanqAEmV93LX+3tEnLYo2pw/CVj8y+9dmz+b/z5uOxT\nf5g884+98m/AETXA5BQ1wOQUNcDkFDXA5BQ1wOQUNcDkFDXA5BQ1wOQUNcDk1irqqvpoVX2nqk5V\n1eeqat/owQDYsmNRV9X1ST6c5F3dfUOSF5L83ujBANiy440DkhxPcmOSk1WVJAeT/OTsN1XVRpKN\nRacDYK2iriR3dPcnXutN3b2ZZDOxKRPAktZZoz6R5INVdTRJqupwVV01diwAXrJjUXf395LcluSe\nqnooydeSvG30YABsWWfpI919V5K7Bs8CwArOowaYnKIGmJyiBpicogaYnKIGmJyiBpicogaYXHUv\nf7V3Vf00yb++jl9yWZInFh9kfLZ8+fLlL5V9VXdfvuqFIUX9elXVfd19bK9ly5cvX/5uZFv6AJic\nogaY3CxFvblHs+XLly9/ePYUa9QAvLpZjqgBeBWK+iJVVVdX1cN7NR+WUlWfqqpbz/ccr0VRA0zu\nvBZ1Vd1dVfdX1SPbN8eVv7v2V9UXqurRqvpSVV26l/Kr6paqeqiqHqyqO5fMln9hq6pPVtXjVXVv\nkmsH5H+0qr5TVaeq6nNVte+cArv7vP1Icnj754NJHk5yRP6u/dlfnaSTvGv7479Icuseyn9HkseT\nXHbm34X83cnfyz+S3JjkH5JcmuTNSb6/8Ofm9Un+KsmB7Y8/k+SWc8k830sfH6+qB5N8O8kVSa6R\nv6t+2N3f2n78+STv3kP570vyxe5+Ikm6+8kFs+Vf2N6T5Mvd/Wx3P53kKwvnH8/WPwYnq+rU9se/\neC6Ba90zcYSqem+S9ye5qbufraq/S3KJ/F119rmZS5+rOTofZlRJ7ujuTywVeD6PqN+S5Kntkrsu\nyTvl77orq+qm7ccfSXLvHsr/epIPVdWRJKmqwwtmy7+wfTPJB6rqYFUdSnLzwvknknywqo4mW3/2\nVXXVuQSez6L+ara+2fRokk9na/lA/u56LMnHtn8Pb03y2b2S392PJLk9yTe2l5/+bKls+Re27n4g\nyV1JHkzyN0lOLpz/vSS3Jbmnqh5K8rUkbzuXTFcmAkzuvK1R89q2/8t6YsVLx7v79G7PAy8Z/bnp\nc/+VHFEDTO58n54HwA4UNcDkFDXA5BQ1wOQUNcDkFDXA5P4Xb7dm40sU2EcAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "b2rayj96fpnF", + "colab_type": "code", + "outputId": "51966a09-b8ba-441d-edcc-d6cb69df6e9f", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 263 + } + }, + "source": [ + "# No uncertainty for the padding.\n", + "show(dist.marginals, 1)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "display_data", + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAKm0lEQVR4nO3dT4jnd3kH8PeTjNGEQEw2pbTUNEKk\nKS50KXsw1YMYLwuF0oNQPJhSypBW8JIeexU8lB4rLFgq6ikW7E1KI5UYENwliSYsVkqrubR1E6sN\nu2aT7dPDTNxx+8vMb/X3nXlm8npB2N/y/e57n2TDez/zme+f6u4AMNdtRz0AAPtT1ADDKWqA4RQ1\nwHCKGmA4RQ0w3NamA+v+6jy46dQbfvfictlJ8say8Xl94fxXF87/j4Xzl/zv70JUhrvc3b+y6sDG\nizoPJvnWxlN/5pmFvwb48bLxeWnh/GcWzv+rhfP/c8Hspf+ShF/S99/qgK0PgOEUNcBwihpgOEUN\nMJyiBhhOUQMMp6gBhlPUAMOtVdRV9ZWqulhVL1bV9tJDAXDDuncm/kl3v1JVdyb5VlX9fXe/vORg\nAOxYt6g/VVV/uPv5PUnel+RnRb27yt5ZaT+wyfEAOHDro6o+nOSjSR7p7t9J8mySd+09p7vPd/fZ\n7j6blY8UAeAXtc4e9T1JftTdV6rq4SQfWHgmAPZYp6i/mmSrqi4l+UySby47EgB7HbhH3d2vJTl3\nCLMAsILrqAGGU9QAwylqgOEUNcBwihpgOEUNMNzm30J+Mcltm4990923L5edJO+8/c5F8++5492L\n5v/63e9dNP8Pfu25RfP/6Ts/WSz73954Y7HsZPm3nPfC+cxlRQ0wnKIGGE5RAwynqAGGU9QAwylq\ngOEUNcBwihpgOEUNMJyiBhhOUQMMt5EHZ1TVdpLtTWQB8PM2UtTdfT7J+SSpKs+OAdggWx8Awylq\ngOEUNcBwihpgOEUNMNyBV31U1akkT6049Gh3v7z5kQDY68Ci3i3jM4cwCwAr2PoAGE5RAwynqAGG\nq+7N3vG9+C3ktWj64n913X7Hsvl33LVs/l33Lpv/G3+6XPbv/91y2Unyj/+6bP733lg2/38WfvjD\n9WXjT4KL3X121QEraoDhFDXAcIoaYDhFDTCcogYYTlEDDKeoAYZT1ADDHVjUVfVgVb1wGMMA8P9Z\nUQMMt25Rb1XVl6rqUlV9uaoWvlEZgDetW9S/leRvuvu3k/wkyZ/vPVhV21V1oaoubHpAgLe7dYv6\npe5+ZvfzF5N8aO/B7j7f3Wff6oEiAPzi1i3qm5+rtfBztgB407pF/UBVPbL7+eNJvrHQPADcZN2i\n/m6ST1bVpST3JvnsciMBsNc6L7f99yQPLz8KAKu4jhpgOEUNMJyiBhhOUQMMp6gBhlPUAMNV92Zv\nMqwqdy3uqxZNv60OvOLyl3LH7cv+8b6j37tY9qlTpxfLTpJ3v/PZRfNfffmHi+Zffu2ni+b/9/Xr\ni+afABff6jEcVtQAwylqgOEUNcBwihpgOEUNMJyiBhhOUQMMp6gBhlPUAMMpaoDhFDXAcBt5MERV\nbSfZ3kQWAD9vI0Xd3eeTnE88lAlg02x9AAynqAGGU9QAwylqgOEO/GZiVZ1K8tSKQ49298ubHwmA\nvQ4s6t0yPnMIswCwgq0PgOEUNcBwihpgOEUNMNxGbiHnVix7h/3/9uuL5v/0jUXj89o7vrdcdi+X\nnSSvf3DR+PzRj5fNP/eDZfNP/3DZ/F/9r2Xzl1b7HLOiBhhOUQMMp6gBhlPUAMMpaoDhFDXAcIoa\nYDhFDTCcogYYTlEDDLdWUVfVJ6rq21X1fFV9YemhALhhnTe8vD/JXyb5ve6+XFX3rThnO8n2AvMB\nvO2t81CmjyR5srsvJ0l3v3LzCd19Psn5JKmqZZ86BPA2Y48aYLh1ivprST62+5LbrNr6AGA567zc\n9sWq+nSSr1fV9STPJvnjpQcDYMdaLw7o7s8n+fzCswCwgj1qgOEUNcBwihpgOEUNMJyiBhhOUQMM\nV92bvePbLeTAybR0tdXF7j676ogVNcBwihpgOEUNMJyiBhhOUQMMp6gBhlPUAMMpaoDh1nm57fUk\n39k991KSx7r7ytKDAbBjnRX11e4+092nk1xL8vjCMwGwx61ufTyd5KElBgFgtbWLuqq2kpzLzjbI\nzce2q+pCVV3Y5HAArPFQpj171MnOivqJ7r62z/keygScQEf3UKZ1Xm57tbvPbHgiANbk8jyA4RQ1\nwHAHFnV3330YgwCwmhU1wHCKGmA4RQ0wnKIGGE5RAwynqAGGU9QAwylqgOEUNcBwihpgOEUNMJyi\nBhhOUQMMp6gBhlPUAMMd+CquPe9M3EpyKclj3X1l6cEA2LHOivpqd5/p7tNJriV5fOGZANjjVrc+\nnk7y0BKDALDa2kVdVVtJzmVnG+TmY9tVdaGqLmxyOACS6u79T7ixR53srKif6O5r+5y/fyDAsbR0\ntdXF7j676siB30zM7h71hicCYE0uzwMYTlEDDHdgUXf33YcxCACrWVEDDKeoAYZT1ADDKWqA4RQ1\nwHCKGmA4RQ0wnKIGGE5RAwynqAGGU9QAwylqgOEUNcBwihpgOEUNMNyBr+La887ErSSXkjzW3VeW\nHgyAHeusqK9295nuPp3kWpLHF54JgD1udevj6SQPLTEIAKutXdRVtZXkXHa2QW4+tl1VF6rqwiaH\nAyCp7t7/hBt71MnOivqJ7r62z/n7BwIcS0tXW13s7rOrjhz4zcTs7lFveCIA1uTyPIDhFDXAcAcW\ndXfffRiDALCaFTXAcIoaYDhFDTCcogYYTlEDDKeoAYZT1ADDKWqA4RQ1wHCKGmA4RQ0wnKIGGE5R\nAwynqAGGO/ANL3texbWV5FKSx7r7ytKDAbBjnRX11e4+092nk1xL8vjCMwGwx61ufTyd5KElBgFg\ntbWLuqq2kpzLjTeSA3AI1nkL+Z1V9dzu56eTfO7mE6pqO8n2JgcDYEd19/4nVL16K+9NrKr9AwGO\npaWrrS5299lVR1yeBzCcogYY7sCivpVtDwA2z4oaYDhFDTCcogYYTlEDDKeoAYZT1ADDKWqA4dZ5\n1setupzk+7dw/v27v2Yp8uXLP575w2avhfPzm2/5Ox/0rI+lVdWFt7q/Xb58+W/f/OM8+6bzbX0A\nDKeoAYabUNTn5cuXL/+Qs49V/pHvUQOwvwkragD2oag5lqrqwap64bjmc3Sq6npVPVdVL1TVk1V1\n1/R8RQ283Vzt7jPdfTrJtSSPT88/0qKuqq9U1cWqenH3BbnyT1D+Idiqqi9V1aWq+vKmV0ZL5lfV\nJ6rq21X1fFV9YVO58m/Z00keGp/f3Uf2T5L7dn+8M8kLSU7JPzn5C/+/82B23jb6wd2f/22SvzgO\n+Unen+Rfkty/989hg7PL3z//1d0ft5L8Q5I/m55/1Fsfn6qq55N8M8l7krxP/onKX9pL3f3M7ucv\nJvnQMcn/SJInu/tyknT3KxvKlb+eO6vquSQXkvwgyeem5y/xrI+1VNWHk3w0ySPdfaWq/jnJu+Sf\njPxDcvO1pZu+1nTpfI7G1e4+c5zyj3JFfU+SH+2WxMNJPiD/ROUfhgeq6pHdzx9P8o1jkv+1JB+r\nqlNJUlX3bShX/gl1lEX91ex8s+ZSks9k58tv+Scn/zB8N8knd/8d7k3y2eOQ390vJvl0kq/vbj39\n9SZy5Z9c7kwEGO7I9qg5uXa/ZH1qxaFHu/vlw54HjjsraoDhjvryPAAOoKgBhlPUAMMpaoDhFDXA\ncIoaYLj/A4hxk4z6BaJJAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "sWLCBtu1hr5S", + "colab_type": "code", + "outputId": "bd0fc5ec-3da5-43cf-b42b-0494ea5983ea", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + } + }, + "source": [ + "# Print the log sum exp of the data (for training)\n", + "print(dist.partition)" + ], + "execution_count": 0, + "outputs": [ + { + "output_type": "stream", + "text": [ + "tensor([19.6766, 15.4172], device='cuda:0', grad_fn=)\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file