diff --git a/notebooks/CTC_with_padding.ipynb b/notebooks/CTC_with_padding.ipynb
new file mode 100644
index 00000000..9727f85d
--- /dev/null
+++ b/notebooks/CTC_with_padding.ipynb
@@ -0,0 +1,391 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "name": "CTC.ipynb",
+ "provenance": [],
+ "collapsed_sections": [],
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "1A6gJZcOVq-h",
+ "colab_type": "code",
+ "outputId": "0467398c-3ac6-437a-8a42-cff49d155d2f",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 53
+ }
+ },
+ "source": [
+ "# Use %pip (not !pip) so the packages are installed into the running kernel's environment.\n",
+ "%pip install -qU git+https://github.com/harvardnlp/pytorch-struct\n",
+ "%pip install -qU git+https://github.com/harvardnlp/genbmm\n",
+ "%pip install -q matplotlib"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "text": [
+ " Building wheel for torch-struct (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Building wheel for genbmm (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
+ ],
+ "name": "stdout"
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "r070BEdwVzHs",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "import torch\n",
+ "import torch_struct\n",
+ "import matplotlib.pyplot as plt"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "zxRC_exrbTR4",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Character vocab: \"_\" is the skippable separator symbol, \"P\" is the padding token.\n",
+ "vocab = [\"a\", \"b\", \"c\", \"d\", \"e\", \"_\", \"P\"]\n",
+ "v_dict = {a: i for i, a in enumerate(vocab)}\n",
+ "L = len(vocab)\n",
+ "\n",
+ "# Char sequences (letters interleaved with \"_\"); shorter examples are right-padded with \"P\".\n",
+ "letters = [[\"a\", \"_\", \"b\", \"_\", \"c\", \"_\", \"d\", \"_\", \"e\"],\n",
+ "           [\"a\", \"_\", \"b\", \"_\", \"c\", \"P\", \"P\", \"P\", \"P\"]]\n",
+ "#letters = [[\"a\", \"_\", \"b\", \"_\", \"c\"]]\n",
+ "\n",
+ "t = len(letters[0])\n",
+ "\n",
+ "# Frame sequences, right-padded with \"P\" in the same way.\n",
+ "frames = [[\"a\", \"a\", \"a\", \"_\", \"b\", \"b\", \"c\", \"c\", \"c\", \"c\", \"_\", \"_\", \"d\", \"e\"],\n",
+ "          [\"a\", \"a\", \"a\", \"a\", \"a\", \"_\", \"b\", \"b\", \"c\", \"c\", \"c\", \"c\", \"P\", \"P\"]]\n",
+ "#frames = [[\"a\", \"a\", \"a\", \"a\", \"a\", \"_\", \"b\", \"b\", \"c\", \"c\", \"c\", \"c\"]]\n",
+ "\n",
+ "# Constants: padded frame length, batch size, and the indices of the three\n",
+ "# move types on the last axis of the log-potentials.\n",
+ "T, B = len(frames[0]), len(frames)\n",
+ "D1, MATCH, D2 = 0, 1, 2\n",
+ "\n",
+ "# Every example must be padded to the common lengths t (letters) and T (frames).\n",
+ "assert all(len(ls) == t for ls in letters), \"pad every letter sequence to the same length\"\n",
+ "assert all(len(fs) == T for fs in frames), \"pad every frame sequence to the same length\"\n",
+ "\n",
+ "# Unpadded (letter, frame) lengths per example, derived by counting non-padding\n",
+ "# tokens rather than hard-coding offsets; assumes padding is trailing only.\n",
+ "# Yields [(9, 14), (5, 12)] for the data above.\n",
+ "batch_lengths = [(sum(c != \"P\" for c in ls), sum(f != \"P\" for f in fs))\n",
+ "                 for ls, fs in zip(letters, frames)]\n",
+ "\n",
+ "def show(m, ex):\n",
+ "    \"\"\"Plot m[ex] as an image, labelling the y-axis with the letters of\n",
+ "    example `ex` and the x-axis with its frames.\n",
+ "\n",
+ "    m:  batched tensor indexed as m[example]; m[ex] is detached and moved to CPU before plotting.\n",
+ "    ex: index of the batch example to display.\n",
+ "    \"\"\"\n",
+ "    plt.yticks(list(range(len(letters[ex]))), letters[ex])\n",
+ "    plt.xticks(list(range(T)), frames[ex])\n",
+ "    plt.imshow(m[ex].cpu().detach())"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "rDtLyNKPa1N0",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Gold alignment: vocab index of every (padded) letter, one row per batch example.\n",
+ "# Each example is indexed with its own letter sequence, letters[b].\n",
+ "gold = torch.zeros(B, t).long()\n",
+ "for b in range(B):\n",
+ "    for i, l in enumerate(letters[b]):\n",
+ "        gold[b, i] = v_dict[l]\n",
+ "# Repeat the letter ids along a new frame axis -> shape (B, T, t).\n",
+ "gold = gold[:, None, :].expand(B, T, t)\n",
+ "\n",
+ "# Inputs (boost true frames a bit): +1 on the logit of the symbol observed at each frame.\n",
+ "logits = torch.zeros(B, T, L)\n",
+ "for b in range(B):\n",
+ "    for i in range(T):\n",
+ "        logits[b, i, v_dict[frames[b][i]]] += 1"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "il4pW6M9YOKP",
+ "colab_type": "code",
+ "colab": {}
+ },
+ "source": [
+ "# Construct the alignment problem from CTC\n",
+ "\n",
+ "# Log-potentials indexed as (batch, frame, letter, move). Every move starts out\n",
+ "# effectively forbidden via a large negative score.\n",
+ "log_potentials = torch.full((B, T, t, 3), -1e5)\n",
+ "\n",
+ "# Match gold to logits: match[b, i, j] is the logit of letter j at frame i.\n",
+ "match = torch.gather(logits, 2, gold)\n",
+ "\n",
+ "# CTC Rules:\n",
+ "for b, (lb, la) in enumerate(batch_lengths):\n",
+ "    # lb and la are the unpadded letter and frame lengths of example b\n",
+ "\n",
+ "    # Never allowed to fully skip regular characters (even letter positions).\n",
+ "    # This restates the initial -1e5 fill; kept so the rule is explicit.\n",
+ "    log_potentials[b, :la, :lb:2, D2] = -1e5\n",
+ "\n",
+ "    # Free to skip _ characters (odd letter positions)\n",
+ "    log_potentials[b, :la, 1:lb:2, D2] = 0\n",
+ "\n",
+ "    # First match with character is the logit.\n",
+ "    log_potentials[b, :la, :lb, MATCH] = match[b, :la, :lb]\n",
+ "\n",
+ "    # Additional match with character is the logit.\n",
+ "    log_potentials[b, :la, :lb, D1] = match[b, :la, :lb]\n",
+ "\n",
+ "    # Match padding in an L shape: along the padded frames at the last real\n",
+ "    # letter, then along the padded letters at the final frame.\n",
+ "    log_potentials[b, la:, lb-1, D1] = 0\n",
+ "    log_potentials[b, -1, lb:, D2] = 0\n",
+ "\n",
+ "\n",
+ "# Swap the frame and letter axes -> (B, t, T, 3), then move to the GPU when one\n",
+ "# is available; fall back to CPU instead of raising on GPU-less machines.\n",
+ "log_potentials = log_potentials.transpose(1, 2).to(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+ ],
+ "execution_count": 0,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "metadata": {
+ "id": "Hz5I8cnLfpHE",
+ "colab_type": "code",
+ "outputId": "cd673207-eb68-4964-be96-3886800ade75",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 263
+ }
+ },
+ "source": [
+ "# Plot exp(match) for batch example 1, with letters on the y-axis and frames on the x-axis.\n",
+ "show(match.permute(0, 2, 1).exp(), ex=1)"
+ ],
+ "execution_count": 0,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD2CAYAAAD/C81vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjAsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy8GearUAAAKH0lEQVR4nO3dz4udh3XH4e9pRlgSxiGyvREo9cKh\nLjFUi1nEbRYhDgj/AdlkEZcuBreBrLrsNpBVlw0MKBCSrOxCuougDg1KILQTIrs2Il41DWgT/4C2\nSMhGnC7m2hkmE89Vcu/MmZnnAaMR9+pwbA8fXr3z3vet7g4Ac/3JcS8AwMcTaoDhhBpgOKEGGE6o\nAYYTaoDhNlY98IlLn+inrpxb9ViGeOv1i8e9ApxK/5v33u7uJw96beWhfurKufz7jSurHssQ1y5f\nPe4V4FT6137lV7/vNac+AIYTaoDhhBpgOKEGGE6oAYYTaoDhhBpgOKEGGG6pUFfVD6rq51X1ZlVt\nrXspAH5r2U8m/k13v1tVF5L8R1X9c3e/s87FANi17KmPr1fVa0l+luRKks/sfbGqtqpqp6p2fvPO\ng1XvCHCmHRrqqvpCki8lea67/yLJL5Kc3/ue7t7u7s3u3nzy8U+sZVGAs2qZI+pPJnmvu+9W1TNJ\nPrfmnQDYY5lQ/zDJRlXdTvLN7J7+AOCIHPrDxO6+n+SFI9gFgAO4jhpgOKEGGE6oAYYTaoDhhBpg\nOKEGGG7lTyF/6/WLnlR9it24c2ut833vwO9yRA0wnFADDCfUAMMJNcBwQg0wnFADDCfUAMMJNcBw\nQg0wnFADDCfUAMOt5F4fVbWVZCtJzufiKkYCsLCSI+ru3u7uze7ePJdHVjESgAWnPgCGE2qA4YQa\nYDihBhhOqAGGO/TyvKp6PMmrB7z0fHe/s/qVANjr0FAvYuxBdgDHxKkPgOGEGmA4oQYYbiX3+uDs\nuHZ5vT+uuHHn1lrnr9O6/9twdjmiBhhOqAGGE2qA4YQaYDihBhhOqAGGE2qA4YQaYLhDQ11VT1XV\nG0exDAC/yxE1wHDLhnqjqr5fVber6pWqurjWrQD4yLKh/rMk/9Tdf57kf5L83d4Xq2qrqnaqaueD\n3F/1jgBn2rKh/nV3/3Tx9feSfH7vi9293d2b3b15Lo+sdEGAs27ZUPchvwdgTZYN9aer6rnF119J\n8pM17QPAPsuG+pdJvlZVt5N8Ksm31rcSAHst83Db/0ryzPpXAeAgrqMGGE6oAYYTaoDhhBpgOKEG\nGE6oAYY79PI82OvGnVtrnX/t8tW1zoeTyBE1wHBCDTCcUAMMJ9QAwwk1wHBCDTCcUAMMJ9QAwwk1\nwHBCDTCcUAMMt5J7fVTVVpKtJDmfi6sYCcDCSo6ou3u7uze7e/NcHlnFSAAWnPoAGE6oAYYTaoDh\nhBpguEOv+qiqx5O8esBLz3f3O6tfCYC9Dg31IsaejwRwTJz6ABhOqAGGE2qA4YQaYLiV3OuDs+Pa\nZT9X/n1u3Ll13Cucaaf5e9MRNcBwQg0wnFADDCfUAMMJNcBwQg0wnFADDCfUAMMJNcBwQg0w3FKh\nrqqvVtXrVfVaVX133UsB8FvLPOHls0n+IclfdvfbVXXpgPdsJdlKkvO5uPIlAc6yZY6ov5jk5e5+\nO0m6+939b+ju7e7e7O7Nc3lk1TsCnGnOUQMMt0yof5Tky4uH3OagUx8ArM8yD7d9s6q+keTHVfUg\nyS+S/PW6FwNg11IPDuju7yT5zpp3AeAAzlEDDCfUAMMJNcBwQg0wnFADDCfUAMMtdXkefOjGnVtr\nnX/t8tW1zl+nk7w7szmiBhhOqAGGE2qA4YQaYDihBhhOqAGGE2qA4YQaYLhlHm77IMl/Lt57O8mL\n3X133YsBsGuZI+p73X21u59N8n6Sl9a8Ew
B7POypj5tJnl7HIgAcbOlQV9VGkheyexpk/2tbVbVT\nVTsf5P4q9wM485a5KdOFqvrwTjw3k1zf/4bu3k6ynSSP1aVe3XoALBPqe93ttmAAx8TleQDDCTXA\ncIeGursfPYpFADiYI2qA4YQaYDihBhhOqAGGE2qA4YQaYLhlPpkIH7l22YdU4ag5ogYYTqgBhhNq\ngOGEGmA4oQYYTqgBhhNqgOGEGmC4Qz/wUlUPsvtA240kt5O82N13170YALuWOaK+191Xu/vZJO8n\neWnNOwGwx8Oe+riZ5Ol1LALAwZYOdVVtJHkhu6dB9r+2VVU7VbXzQe6vcj+AM2+ZmzJdqKpbi69v\nJrm+/w3dvZ1kO0keq0u9uvUAWCbU97rbLdMAjonL8wCGE2qA4Q4NdXc/ehSLAHAwR9QAwwk1wHBC\nDTCcUAMMJ9QAwwk1wHDLfDIRPnLjzq3D3/RHuHbZh2BhP0fUAMMJNcBwQg0wnFADDCfUAMMJNcBw\nQg0wnFADDHfoB16q6kF2H2i7keR2khe7++66FwNg1zJH1Pe6+2p3P5vk/SQvrXknAPZ42FMfN5M8\nvY5FADjY0qGuqo0kL2T3NMj+17aqaqeqdj7I/VXuB3DmLXNTpgtV9eGdeG4mub7/Dd29nWQ7SR6r\nS7269QBYJtT3utstzQCOicvzAIYTaoDhDg11dz96FIsAcDBH1ADDCTXAcEINMJxQAwwn1ADDCTXA\ncMt8MhE+cu2yD6nCUXNEDTCcUAMMJ9QAwwk1wHBCDTCcUAMMJ9QAwx16HXVVPcjucxI3ktxO8mJ3\n3133YgDsWuaI+l53X+3uZ5O8n+SlNe8EwB4Pe+rjZpKn17EIAAdbOtRVtZHkheyeBgHgiCxzr48L\nVXVr8fXNJNf3v6GqtpJsJcn5XFzddgAsFep73f2xd+Lp7u0k20nyWF3qVSwGwC6X5wEMJ9QAwx0a\n6u5+9CgWAeBgjqgBhhNqgOGEGmA4oQYYTqgBhhNqgOGEGmC46l7tJ76r6jdJfvUQf+SJJG+vdAnz\nzTf/NMw/ybv/IfP/tLufPOiFlYf6YVXVTndvmm+++eYf1eyTNt+pD4DhhBpguAmh3jbffPPNP+LZ\nJ2r+sZ+jBuDjTTiiBuBjCDUnUlU9VVVvnNT5HJ+qelBVt6rqjap6uapW+vzAdcwXauCsudfdV7v7\n2STvJ3lp+vxjDXVV/aCqfl5Vby4ekGv+KZp/BDaq6vtVdbuqXln1kdE651fVV6vq9ap6raq+u6q5\n5j+0m0meHj+/u4/tnySXFr9eSPJGksfNPz3z1/y981SSTvJXi99/O8nfn4T5ST6b5K0kT+z9/7DC\n3c3/+Pn/t/h1I8m/JPnb6fOP+9TH16vqtSQ/S3IlyWfMP1Xz1+3X3f3TxdffS/L5EzL/i0le7u63\nk6S7313RXPOXc6GqbiXZSfLfSa5Pn7/xR6/0B6qqLyT5UpLnuvtuVf1bkvPmn475R2T/taWrvtZ0\n3fM5Hve6++pJmn+cR9SfTPLeIhLPJPmc+adq/lH4dFU9t/j6K0l+ckLm/yjJl6vq8SSpqksrmmv+\nKXWcof5hdn9YczvJN7P712/zT8/8o/DLJF9b/Dt8Ksm3TsL87n4zyTeS/Hhx6ukfVzHX/NPLJxMB\nhju2c9ScXou/sr56wEvPd/c7R70PnHSOqAGGO+7L8wA4hFADDCfUAMMJNcBwQg0wnFADDPf/KfFi\nycGaaRoAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ "