diff --git a/daft/table/table_io.py b/daft/table/table_io.py index da31a42a2c..f1ed3efcbd 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -414,8 +414,6 @@ def write_tabular( partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__", ) -> MicroPartition: - pass - [resolved_path], fs = _resolve_paths_and_filesystem(path, io_config=io_config) if isinstance(path, pathlib.Path): path_str = str(path) diff --git a/docs/source/user_guide/tutorials.rst b/docs/source/user_guide/tutorials.rst index 2852803209..08aea27569 100644 --- a/docs/source/user_guide/tutorials.rst +++ b/docs/source/user_guide/tutorials.rst @@ -31,6 +31,18 @@ Generate images from text prompts using a deep learning model (Mini DALL-E) and `Run this tutorial on Google Colab `__ +.. These can't be run because DeltaLake can't be accessed in anonymous mode from Google Colab +.. ML model batch inference/training on a Data Catalog +.. --------------------------------------------------- + +.. Run ML models or train them on data in your data catalog (e.g. Apache Iceberg, DeltaLake or Hudi) + +.. 1. `Local batch inference `__ +.. 1. `Distributed batch inference `__ +.. 1. `Single-node Pytorch model training `__ + + + .. Other ideas: .. Scaling up in the cloud with Ray **[Coming Soon]** .. Building a HTTP service **[Coming Soon]** diff --git a/tutorials/delta_lake/1-local-image-batch-inference.ipynb b/tutorials/delta_lake/1-local-image-batch-inference.ipynb new file mode 100644 index 0000000000..1737ea0405 --- /dev/null +++ b/tutorials/delta_lake/1-local-image-batch-inference.ipynb @@ -0,0 +1,559 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e2893f93-72d9-40ea-93e9-be2d3f3f66ee", + "metadata": {}, + "source": [ + "# Batch Inference on data in DeltaLake\n", + "\n", + "In this tutorial, we showcase how to perform ML model batch inference on data in a DeltaLake table.\n", + "\n", + "> **ML Model Batch Inference**\n", + "> \n", + "> When we have a trained machine learning model, the next step is often to apply this model to a large amount of data. This involves efficiently loading the model into memory (potentially GPU memory) and then running data through the model to produce outputs.\n", + "\n", + "To run this tutorial you will require AWS credentials to be correctly provisioned on your machine as all data is hosted in a requestor-pays bucket in AWS S3.\n", + "\n", + "Let's get started!" + ] + }, + { + "cell_type": "markdown", + "id": "ff29e3e0-0038-4365-9f0e-ebc2e835e2b8", + "metadata": {}, + "source": [ + "## Provisioning Cloud Credentials\n", + "\n", + "First, let's provision credentials to Daft! We can do so using the ``boto3`` library, and creating a Daft {class}`IOConfig ` object like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8bda28c8-94f1-492d-8e25-6f6be3327b07", + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "import daft\n", + "\n", + "session = boto3.session.Session()\n", + "creds = session.get_credentials()\n", + "io_config = daft.io.IOConfig(\n", + " s3=daft.io.S3Config(\n", + " access_key=creds.secret_key,\n", + " key_id=creds.access_key,\n", + " session_token=creds.token,\n", + " region_name=\"us-west-2\",\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "7fd72180-8b2f-4e3a-8ee5-043cf2396bd7", + "metadata": {}, + "source": [ + "## Retrieving Data\n", + "\n", + "Now we're ready to read data from our DeltaLake table!\n", + "\n", + "We've hosted a 10k row sample of the validation set of imagenet for you to try this out.\n", + "\n", + "Simply pass in the ``IOConfig`` that we previously created to the call in order to ensure that we can access the data." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b18c8c0d-5711-4a3a-b2b8-7361162d9d00", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "
folder
Utf8
filename
Utf8
source
Struct[database: Utf8]
size
Struct[depth: Utf8, height: Utf8, width: Utf8]
segmented
Utf8
object
List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]
\n", + "(No data to display: Dataframe not materialized)\n", + "
" + ], + "text/plain": [ + "╭────────┬──────────┬────────────────────────┬─────────────────────────────┬───────────┬──────────────────────────╮\n", + "│ folder ┆ filename ┆ source ┆ size ┆ segmented ┆ object │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ Utf8 ┆ Utf8 ┆ Struct[database: Utf8] ┆ Struct[depth: Utf8, height: ┆ Utf8 ┆ List[Struct[bndbox: │\n", + "│ ┆ ┆ ┆ Utf8, width: Utf8] ┆ ┆ Struct[xmax: Utf8, xmin: │\n", + "│ ┆ ┆ ┆ ┆ ┆ Utf8, ymax: Utf8, ymin: │\n", + "│ ┆ ┆ ┆ ┆ ┆ Utf8], difficult: Utf8, │\n", + "│ ┆ ┆ ┆ ┆ ┆ name: Utf8, pose: Utf8, │\n", + "│ ┆ ┆ ┆ ┆ ┆ truncated: Utf8]] │\n", + "╰────────┴──────────┴────────────────────────┴─────────────────────────────┴───────────┴──────────────────────────╯\n", + "\n", + "(No data to display: Dataframe not materialized)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = daft.read_delta_lake(\"s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/\", io_config=io_config)\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "9dd7cff7-8830-4372-b15d-2b9306e21e8a", + "metadata": {}, + "source": [ + "For this demo, we're running this on our local machine and thus will be limiting the total amount of data to 100." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "72c23712-c368-47f6-af75-901f849608f0", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a4b3bf42c51542ac9938c50d54b2e419", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "ScanWithTask-LocalLimit-LocalLimit-Project [Stage:1]: 0%| | 0/1 [00:00\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
folder
Utf8
filename
Utf8
object
List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]
val
ILSVRC2012_val_00000001
[{bndbox: {xmax: 441,
xmin: 111,
ymax: 193,
ymin: 108,
},
difficult: 0,
name: n01751748,
pose: Unspecified,
truncated: 0,
}]
val
ILSVRC2012_val_00000002
[{bndbox: {xmax: 499,
xmin: 45,
ymax: 162,
ymin: 49,
},
difficult: 0,
name: n09193705,
pose: Unspecified,
truncated: 0,
}, {bndbox: {xmax: 437,
xmin: 2,
ymax: 207,
ymin: 69,
},
difficult: 0,
name: n09193705,
pose: Unspecified,
truncated: 0,
}]
val
ILSVRC2012_val_00000003
[{bndbox: {xmax: 385,
xmin: 38,
ymax: 373,
ymin: 19,
},
difficult: 0,
name: n02105855,
pose: Unspecified,
truncated: 0,
}]
val
ILSVRC2012_val_00000004
[{bndbox: {xmax: 441,
xmin: 94,
ymax: 284,
ymin: 15,
},
difficult: 0,
name: n04263257,
pose: Unspecified,
truncated: 0,
}]
val
ILSVRC2012_val_00000005
[{bndbox: {xmax: 425,
xmin: 17,
ymax: 332,
ymin: 1,
},
difficult: 0,
name: n03125729,
pose: Unspecified,
truncated: 0,
}]
val
ILSVRC2012_val_00000006
[{bndbox: {xmax: 358,
xmin: 105,
ymax: 279,
ymin: 204,
},
difficult: 0,
name: n01735189,
pose: Unspecified,
truncated: 0,
}]
val
ILSVRC2012_val_00000007
[{bndbox: {xmax: 498,
xmin: 89,
ymax: 268,
ymin: 75,
},
difficult: 0,
name: n02346627,
pose: Unspecified,
truncated: 0,
}]
val
ILSVRC2012_val_00000008
[{bndbox: {xmax: 181,
xmin: 14,
ymax: 328,
ymin: 163,
},
difficult: 0,
name: n02776631,
pose: Unspecified,
truncated: 0,
}, {bndbox: {xmax: 331,
xmin: 176,
ymax: 223,
ymin: 81,
},
difficult: 0,
name: n02776631,
pose: Unspecified,
truncated: 0,
}, {bndbox: {xmax: 236,
xmin: 77,
ymax: 155,
ymin: 2,
},
difficult: 0,
name: n02776631,
pose: Unspecified,
truncated: 0,
}, {bndbox: {xmax: 355,
xmin: 163,
ymax: 374,
ymin: 219,
},
difficult: 0,
name: n02776631,
pose: Unspecified,
truncated: 0,
}]
\n", + "(Showing first 8 of 100 rows)\n", + "" + ], + "text/plain": [ + "╭────────┬─────────────────────────┬─────────────────────────────────────────────────────────────────────────────╮\n", + "│ folder ┆ filename ┆ object │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ Utf8 ┆ Utf8 ┆ List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], │\n", + "│ ┆ ┆ difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]] │\n", + "╞════════╪═════════════════════════╪═════════════════════════════════════════════════════════════════════════════╡\n", + "│ val ┆ ILSVRC2012_val_00000001 ┆ [{bndbox: {xmax: 441, │\n", + "│ ┆ ┆ xmin: 1… │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00000002 ┆ [{bndbox: {xmax: 499, │\n", + "│ ┆ ┆ xmin: 4… │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00000003 ┆ [{bndbox: {xmax: 385, │\n", + "│ ┆ ┆ xmin: 3… │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00000004 ┆ [{bndbox: {xmax: 441, │\n", + "│ ┆ ┆ xmin: 9… │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00000005 ┆ [{bndbox: {xmax: 425, │\n", + "│ ┆ ┆ xmin: 1… │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00000006 ┆ [{bndbox: {xmax: 358, │\n", + "│ ┆ ┆ xmin: 1… │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00000007 ┆ [{bndbox: {xmax: 498, │\n", + "│ ┆ ┆ xmin: 8… │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00000008 ┆ [{bndbox: {xmax: 181, │\n", + "│ ┆ ┆ xmin: 1… │\n", + "╰────────┴─────────────────────────┴─────────────────────────────────────────────────────────────────────────────╯\n", + "\n", + "(Showing first 8 of 100 rows)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = df.limit(100)\n", + "df = df.select(\"folder\", \"filename\", \"object\")\n", + "df.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "ac0d9686-4d93-42b5-9c49-6f96dd0d5522", + "metadata": {}, + "source": [ + "## Retrieving Images\n", + "\n", + "Let's now resolve the images to their URLs, and start downloading/decoding them into images in our dataframe!" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "06c7d63a-582e-4e67-ba8e-054b5704c40f", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "df = df.with_column(\n", + " \"image_url\",\n", + " \"s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/\" + df[\"filename\"] + \".jpeg\"\n", + ")\n", + "df = df.with_column(\"image\", df[\"image_url\"].url.download().image.decode())" + ] + }, + { + "cell_type": "markdown", + "id": "7ec1400f-fd32-4b96-8bf7-32b333bd7063", + "metadata": {}, + "source": [ + "We also want to do a little preprocessing on our images to get them all into the same size. We can do this with the {meth}`.image.resize ` method!" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e7f0ba8f-7309-4ed1-bc3f-f5f6fa319980", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
folder
Utf8
filename
Utf8
object
List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]
image_url
Utf8
image
Image[MIXED]
image_resized_small
Image[MIXED]
image_resized_large
Image[MIXED]
val
ILSVRC2012_val_00000001
[{bndbox: {xmax: 441,
xmin: 111,
ymax: 193,
ymin: 108,
},
difficult: 0,
name: n01751748,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000001.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
val
ILSVRC2012_val_00000002
[{bndbox: {xmax: 499,
xmin: 45,
ymax: 162,
ymin: 49,
},
difficult: 0,
name: n09193705,
pose: Unspecified,
truncated: 0,
}, {bndbox: {xmax: 437,
xmin: 2,
ymax: 207,
ymin: 69,
},
difficult: 0,
name: n09193705,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000002.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
val
ILSVRC2012_val_00000003
[{bndbox: {xmax: 385,
xmin: 38,
ymax: 373,
ymin: 19,
},
difficult: 0,
name: n02105855,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000003.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
val
ILSVRC2012_val_00000004
[{bndbox: {xmax: 441,
xmin: 94,
ymax: 284,
ymin: 15,
},
difficult: 0,
name: n04263257,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000004.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
\n", + "(Showing first 4 rows)\n", + "
" + ], + "text/plain": [ + "╭────────┬────────────────────┬────────────────────┬────────────┬──────────────┬───────────────────┬───────────────────╮\n", + "│ folder ┆ filename ┆ object ┆ … ┆ image ┆ image_resized_sma ┆ image_resized_lar │\n", + "│ --- ┆ --- ┆ --- ┆ ┆ --- ┆ ll ┆ ge │\n", + "│ Utf8 ┆ Utf8 ┆ List[Struct[bndbox ┆ (1 hidden) ┆ Image[MIXED] ┆ --- ┆ --- │\n", + "│ ┆ ┆ : Struct[xmax: ┆ ┆ ┆ Image[MIXED] ┆ Image[MIXED] │\n", + "│ ┆ ┆ Utf8, xmin: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ ymax: Utf8, ymin: ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ Utf8], difficult: ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ Utf8, name: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ pose: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ truncated: Utf8]] ┆ ┆ ┆ ┆ │\n", + "╞════════╪════════════════════╪════════════════════╪════════════╪══════════════╪═══════════════════╪═══════════════════╡\n", + "│ val ┆ ILSVRC2012_val_000 ┆ [{bndbox: {xmax: ┆ … ┆ │\n", + "│ ┆ 00001 ┆ 441, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 1… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_000 ┆ [{bndbox: {xmax: ┆ … ┆ │\n", + "│ ┆ 00002 ┆ 499, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 4… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_000 ┆ [{bndbox: {xmax: ┆ … ┆ │\n", + "│ ┆ 00003 ┆ 385, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 3… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_000 ┆ [{bndbox: {xmax: ┆ … ┆ │\n", + "│ ┆ 00004 ┆ 441, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 9… ┆ ┆ ┆ ┆ │\n", + "╰────────┴────────────────────┴────────────────────┴────────────┴──────────────┴───────────────────┴───────────────────╯\n", + "\n", + "(Showing first 4 rows)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df = df.with_column(\"image_resized_small\", df[\"image\"].image.resize(32, 32))\n", + "df = df.with_column(\"image_resized_large\", df[\"image\"].image.resize(256, 256))\n", + "df.show(4)" + ] + }, + { + "cell_type": "markdown", + "id": "e95ffcf9-a3b7-4132-9711-aa10a1a752f9", + "metadata": {}, + "source": [ + "## Running Batch Inference\n", + "\n", + "Great! We now have our images nicely preprocessed, and are ready to run batch inference on them.\n", + "\n", + "Let's run a simple ResNet image classifier on each image's \"high-resolution\" and \"low-resolution\" variant, to see how sensitive our model is to the resolution of the image!\n", + "\n", + "First off, we define a \"Stateful UDF\" that will initialize our model once in the ``__init__`` method, and then use the same model across multiple invocations on different partitions of data." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "262a01f8-16f7-44ac-92c0-a8b264088ba7", + "metadata": {}, + "outputs": [], + "source": [ + "import daft\n", + "import numpy as np\n", + "import torch\n", + "from torchvision.models import resnet50, ResNet50_Weights\n", + "\n", + "@daft.udf(return_dtype=daft.DataType.string())\n", + "class ClassifyImage:\n", + " def __init__(self):\n", + " weights = ResNet50_Weights.DEFAULT\n", + " self.model = resnet50(weights=weights)\n", + " self.model.eval()\n", + " self.preprocess = weights.transforms()\n", + " self.category_map = weights.meta[\"categories\"]\n", + "\n", + " def __call__(self, images: daft.Series, shape: list[int, int, int]):\n", + " if len(images) == 0:\n", + " return []\n", + "\n", + " # Convert the Daft Series into a list of Numpy arrays\n", + " data = images.cast(daft.DataType.tensor(daft.DataType.uint8(), tuple(shape))).to_pylist()\n", + "\n", + " # Convert the numpy arrays into a torch tensor\n", + " images_array = torch.tensor(np.array(data)).permute((0, 3, 1, 2))\n", + "\n", + " # Run the model, and map results back to a human-readable string\n", + " batch = self.preprocess(images_array)\n", + " prediction = self.model(batch).softmax(0)\n", + " class_ids = prediction.argmax(1)\n", + " scores = prediction[:, class_ids]\n", + " return [self.category_map[class_id] for class_id in class_ids]\n" + ] + }, + { + "cell_type": "markdown", + "id": "651c1cae-d0af-46b9-b368-f156e2c71ce6", + "metadata": {}, + "source": [ + "To run our model on the dataframe, simply call the ``ClassifyImage`` function we defined earlier on the columns!\n", + "\n", + "NOTE: If we wanted to ensure that our UDF will run with a GPU, we can specify:\n", + "\n", + "```\n", + "df.with_column(..., resource_request=daft.ResourceRequest(num_gpus=1))\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9cbdb759-a6a0-493e-83c2-908dfd2f8051", + "metadata": {}, + "outputs": [], + "source": [ + "df = df.with_column(\"predictions_lowres\", ClassifyImage(df[\"image_resized_small\"], [32, 32, 3]))\n", + "df = df.with_column(\"predictions_highres\", ClassifyImage(df[\"image_resized_large\"], [256, 256, 3]))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f2c250a2-9159-46a8-b8d9-e1db363a99fa", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
folder
Utf8
filename
Utf8
object
List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]
image_url
Utf8
image
Image[MIXED]
image_resized_small
Image[MIXED]
image_resized_large
Image[MIXED]
predictions_lowres
Utf8
predictions_highres
Utf8
val
ILSVRC2012_val_00000001
[{bndbox: {xmax: 441,
xmin: 111,
ymax: 193,
ymin: 108,
},
difficult: 0,
name: n01751748,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000001.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
rock python
rock python
val
ILSVRC2012_val_00000003
[{bndbox: {xmax: 385,
xmin: 38,
ymax: 373,
ymin: 19,
},
difficult: 0,
name: n02105855,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000003.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
Shetland sheepdog
Shetland sheepdog
val
ILSVRC2012_val_00000004
[{bndbox: {xmax: 441,
xmin: 94,
ymax: 284,
ymin: 15,
},
difficult: 0,
name: n04263257,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000004.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
eggnog
soup bowl
val
ILSVRC2012_val_00000005
[{bndbox: {xmax: 425,
xmin: 17,
ymax: 332,
ymin: 1,
},
difficult: 0,
name: n03125729,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000005.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
packet
cradle
\n", + "(Showing first 4 rows)\n", + "
" + ], + "text/plain": [ + "╭────────┬───────────────────┬───────────────────┬────────────┬──────────────────┬──────────────────┬──────────────────╮\n", + "│ folder ┆ filename ┆ object ┆ … ┆ image_resized_la ┆ predictions_lowr ┆ predictions_high │\n", + "│ --- ┆ --- ┆ --- ┆ ┆ rge ┆ es ┆ res │\n", + "│ Utf8 ┆ Utf8 ┆ List[Struct[bndbo ┆ (3 hidden) ┆ --- ┆ --- ┆ --- │\n", + "│ ┆ ┆ x: Struct[xmax: ┆ ┆ Image[MIXED] ┆ Utf8 ┆ Utf8 │\n", + "│ ┆ ┆ Utf8, xmin: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ ymax: Utf8, ymin: ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ Utf8], difficult: ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ Utf8, name: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ pose: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ truncated: Utf8]] ┆ ┆ ┆ ┆ │\n", + "╞════════╪═══════════════════╪═══════════════════╪════════════╪══════════════════╪══════════════════╪══════════════════╡\n", + "│ val ┆ ILSVRC2012_val_00 ┆ [{bndbox: {xmax: ┆ … ┆ ┆ rock python ┆ rock python │\n", + "│ ┆ 000001 ┆ 441, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 1… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00 ┆ [{bndbox: {xmax: ┆ … ┆ ┆ Shetland ┆ Shetland │\n", + "│ ┆ 000003 ┆ 385, ┆ ┆ ┆ sheepdog ┆ sheepdog │\n", + "│ ┆ ┆ xmin: 3… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00 ┆ [{bndbox: {xmax: ┆ … ┆ ┆ eggnog ┆ soup bowl │\n", + "│ ┆ 000004 ┆ 441, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 9… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00 ┆ [{bndbox: {xmax: ┆ … ┆ ┆ packet ┆ cradle │\n", + "│ ┆ 000005 ┆ 425, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 1… ┆ ┆ ┆ ┆ │\n", + "╰────────┴───────────────────┴───────────────────┴────────────┴──────────────────┴──────────────────┴──────────────────╯\n", + "\n", + "(Showing first 4 rows)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df.show(4)" + ] + }, + { + "cell_type": "markdown", + "id": "5924409b-a923-46d9-8801-c10234b59ec5", + "metadata": {}, + "source": [ + "Pretty cool! looks like decreasing the resolution of the image too much does have a strong effect on the model's performance, as expected.\n", + "\n", + "We can go ahead and show **just** the rows that have show this behavior. We will also need to filter for rows where the image does not have 3 channels because that will break our code.\n", + "\n", + "Note that the following cell will now take a much longer time to run as we need to run the model on all the rows instead of just the first 4!" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b13facc4-b1c9-48af-bb21-bfa182febeca", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
folder
Utf8
filename
Utf8
object
List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]
image_url
Utf8
image
Image[MIXED]
image_resized_small
Image[MIXED]
image_resized_large
Image[MIXED]
predictions_lowres
Utf8
predictions_highres
Utf8
val
ILSVRC2012_val_00000004
[{bndbox: {xmax: 441,
xmin: 94,
ymax: 284,
ymin: 15,
},
difficult: 0,
name: n04263257,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000004.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
eggnog
soup bowl
val
ILSVRC2012_val_00000005
[{bndbox: {xmax: 425,
xmin: 17,
ymax: 332,
ymin: 1,
},
difficult: 0,
name: n03125729,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000005.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
packet
cradle
val
ILSVRC2012_val_00000006
[{bndbox: {xmax: 358,
xmin: 105,
ymax: 279,
ymin: 204,
},
difficult: 0,
name: n01735189,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000006.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
golf ball
sidewinder
val
ILSVRC2012_val_00000007
[{bndbox: {xmax: 498,
xmin: 89,
ymax: 268,
ymin: 75,
},
difficult: 0,
name: n02346627,
pose: Unspecified,
truncated: 0,
}]
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000007.jpeg
\"<Image\" />
\"<Image\" />
\"<Image\" />
Madagascar cat
porcupine
\n", + "(Showing first 4 rows)\n", + "
" + ], + "text/plain": [ + "╭────────┬───────────────────┬───────────────────┬────────────┬──────────────────┬──────────────────┬──────────────────╮\n", + "│ folder ┆ filename ┆ object ┆ … ┆ image_resized_la ┆ predictions_lowr ┆ predictions_high │\n", + "│ --- ┆ --- ┆ --- ┆ ┆ rge ┆ es ┆ res │\n", + "│ Utf8 ┆ Utf8 ┆ List[Struct[bndbo ┆ (3 hidden) ┆ --- ┆ --- ┆ --- │\n", + "│ ┆ ┆ x: Struct[xmax: ┆ ┆ Image[MIXED] ┆ Utf8 ┆ Utf8 │\n", + "│ ┆ ┆ Utf8, xmin: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ ymax: Utf8, ymin: ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ Utf8], difficult: ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ Utf8, name: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ pose: Utf8, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ truncated: Utf8]] ┆ ┆ ┆ ┆ │\n", + "╞════════╪═══════════════════╪═══════════════════╪════════════╪══════════════════╪══════════════════╪══════════════════╡\n", + "│ val ┆ ILSVRC2012_val_00 ┆ [{bndbox: {xmax: ┆ … ┆ ┆ eggnog ┆ soup bowl │\n", + "│ ┆ 000004 ┆ 441, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 9… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00 ┆ [{bndbox: {xmax: ┆ … ┆ ┆ packet ┆ cradle │\n", + "│ ┆ 000005 ┆ 425, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 1… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00 ┆ [{bndbox: {xmax: ┆ … ┆ ┆ golf ball ┆ sidewinder │\n", + "│ ┆ 000006 ┆ 358, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 1… ┆ ┆ ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ val ┆ ILSVRC2012_val_00 ┆ [{bndbox: {xmax: ┆ … ┆ ┆ Madagascar cat ┆ porcupine │\n", + "│ ┆ 000007 ┆ 498, ┆ ┆ ┆ ┆ │\n", + "│ ┆ ┆ xmin: 8… ┆ ┆ ┆ ┆ │\n", + "╰────────┴───────────────────┴───────────────────┴────────────┴──────────────────┴──────────────────┴──────────────────╯\n", + "\n", + "(Showing first 4 rows)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Filter out images where the number of channels != 3\n", + "df = df.where(df[\"image\"].apply(lambda img: img.shape[2] == 3, return_dtype=daft.DataType.bool()))\n", + "\n", + "# Show only rows where the predictions on the low-res/high-res images don't match\n", + "df = df.where(df[\"predictions_lowres\"] != df[\"predictions_highres\"])\n", + "\n", + "df.show(4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bb2c154-a8a1-4d4a-97b6-7096f2a33df7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/delta_lake/2-distributed-batch-inference.ipynb b/tutorials/delta_lake/2-distributed-batch-inference.ipynb new file mode 100644 index 0000000000..203c15767f --- /dev/null +++ b/tutorials/delta_lake/2-distributed-batch-inference.ipynb @@ -0,0 +1,454 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e2893f93-72d9-40ea-93e9-be2d3f3f66ee", + "metadata": {}, + "source": [ + "# Distributed ML model batch inference on data in DeltaLake\n", + "\n", + "In this tutorial, we showcase how to perform ML model batch inference on data in a DeltaLake table.\n", + "\n", + "This is a continuation of the previous tutorial on **local** batch inference, which is a great way to get started and make sure that your code is working before graduating to larger scales in a distributed batch inference workload. Make sure to give that a read before looking at this tutorial!\n", + "\n", + "To run this tutorial you will require AWS credentials to be correctly provisioned on your machine as all data is hosted in a requestor-pays bucket in AWS S3.\n", + "\n", + "Let's get started!\n", + "\n", + "# Going Distributed\n", + "\n", + "The first step (and most important for this demo!) is to switch our Daft runner to the Ray runner, and point it at a Ray cluster. This is super simple:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8e600443-3931-44f2-b814-0056e42da612", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DaftContext(_daft_execution_config=, _daft_planning_config=, _runner_config=_RayRunnerConfig(address=None, max_task_backlog=None), _disallow_set_runner=True, _runner=None)" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import daft\n", + "\n", + "# If you have your own Ray cluster running, feel free to set this to that address!\n", + "# RAY_ADDRESS = \"ray://localhost:10001\"\n", + "RAY_ADDRESS = None\n", + "\n", + "daft.context.set_runner_ray(address=RAY_ADDRESS)" + ] + }, + { + "cell_type": "markdown", + "id": "1fdf0722-eff4-485d-84e8-f4c74e79caca", + "metadata": {}, + "source": [ + "Now, we run the same operations as before. The only difference is that instead of execution happening locally on the machine that's running this code, Daft will distribute the computation over your Ray cluster!" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "569c6297-dcfd-4013-9a04-e3fc8f0ca315", + "metadata": {}, + "outputs": [], + "source": [ + "# Feel free to tweak this variable to have the tutorial run on as many rows as you'd like!\n", + "NUM_ROWS = 1000" + ] + }, + { + "cell_type": "markdown", + "id": "08ff4bf0-7b5f-4884-80d3-95d7b9005a8b", + "metadata": {}, + "source": [ + "### Retrieving data\n", + "\n", + "We will be retrieving the data exactly the same way we did in the previous tutorial, with the same API and arguments." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "83a76976-aed6-49ea-8c8e-1572947d93ec", + "metadata": {}, + "outputs": [], + "source": [ + "# Provision Cloud Credentials\n", + "import boto3\n", + "import daft\n", + "\n", + "session = boto3.session.Session()\n", + "creds = session.get_credentials()\n", + "io_config = daft.io.IOConfig(\n", + " s3=daft.io.S3Config(\n", + " access_key=creds.secret_key,\n", + " key_id=creds.access_key,\n", + " session_token=creds.token,\n", + " region_name=\"us-west-2\",\n", + " )\n", + ")\n", + "\n", + "# Retrieve data\n", + "df = daft.read_delta_lake(\"s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/\", io_config=io_config)\n", + "\n", + "# Prune data\n", + "df = df.limit(NUM_ROWS)\n", + "df = df.where(df[\"object\"].list.lengths() == 1)" + ] + }, + { + "cell_type": "markdown", + "id": "8043c7e1-c350-449b-bd93-4a5ca93adc4d", + "metadata": {}, + "source": [ + "### Splitting the data into more partitions\n", + "\n", + "We now split the data into more partitions for additional parallelism when performing our data processing in a **distributed** fashion" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d553c284-5f4e-435b-a0f2-c35bf4fc09d2", + "metadata": {}, + "outputs": [], + "source": [ + "df = df.into_partitions(16)" + ] + }, + { + "cell_type": "markdown", + "id": "acc6f220-aaef-463c-ae51-014bddc14231", + "metadata": {}, + "source": [ + "### Retrieving the images and preprocessing\n", + "\n", + "Now we continue with exactly the same code as in the local case for retrieving and preprocessing our images" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "317e2778-4986-4993-ab4d-0426e5fee149", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Retrieve images and run preprocessing\n", + "df = df.with_column(\n", + " \"image_url\",\n", + " \"s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/\" + df[\"filename\"] + \".jpeg\"\n", + ")\n", + "df = df.with_column(\"image\", df[\"image_url\"].url.download().image.decode())\n", + "df = df.with_column(\"image_resized_small\", df[\"image\"].image.resize(32, 32))\n", + "df = df.with_column(\"image_resized_large\", df[\"image\"].image.resize(256, 256))" + ] + }, + { + "cell_type": "markdown", + "id": "7dc7cb6a-9c03-4386-8f31-7e5559b370b3", + "metadata": {}, + "source": [ + "### Running batch inference with a UDF\n", + "\n", + "Running the UDF is also exactly the same!" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "29b21c64-026e-43bd-aea0-48ae3a452b7b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-03-29 19:38:18,040\tINFO worker.py:1642 -- Started a local Ray instance.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c76dafc0f26b4f4782089c7381e4baa0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "ScanWithTask-LocalLimit-LocalLimit-Project-Filter [Stage:3]: 0%| | 0/1 [00:00\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
path
Utf8
my_results.parquet/8eb54f00-9537-4e28-ac85-e96a00a071d5-0.parquet
my_results.parquet/04ccf8fe-9777-4307-9e1f-916c8532ca1c-0.parquet
my_results.parquet/867fc77f-f730-4b53-8e9a-11ed5dc9b98f-0.parquet
my_results.parquet/e4645f7b-8a70-4ee8-8221-823777467a0a-0.parquet
my_results.parquet/dd41fced-6e6b-4ece-8e58-d0804311b4ff-0.parquet
my_results.parquet/c548e6f4-3c83-4f76-b7c5-821f81157720-0.parquet
my_results.parquet/28753019-9875-45a2-94b4-b7b9217492ca-0.parquet
my_results.parquet/f66ffaa6-cc2e-4328-8137-aa358244a8a3-0.parquet
\n", + "(Showing first 8 of 16 rows)\n", + "" + ], + "text/plain": [ + "╭────────────────────────────────╮\n", + "│ path │\n", + "│ --- │\n", + "│ Utf8 │\n", + "╞════════════════════════════════╡\n", + "│ my_results.parquet/8eb54f00-9… │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ my_results.parquet/04ccf8fe-9… │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ my_results.parquet/867fc77f-f… │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ my_results.parquet/e4645f7b-8… │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ my_results.parquet/dd41fced-6… │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ my_results.parquet/c548e6f4-3… │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ my_results.parquet/28753019-9… │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ my_results.parquet/f66ffaa6-c… │\n", + "╰────────────────────────────────╯\n", + "\n", + "(Showing first 8 of 16 rows)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Run batch inference over the entire dataset\n", + "import daft\n", + "import numpy as np\n", + "import torch\n", + "from torchvision.models import resnet50, ResNet50_Weights\n", + "\n", + "@daft.udf(return_dtype=daft.DataType.string())\n", + "class ClassifyImage:\n", + " def __init__(self):\n", + " weights = ResNet50_Weights.DEFAULT\n", + " self.model = resnet50(weights=weights)\n", + " self.model.eval()\n", + " self.preprocess = weights.transforms()\n", + " self.category_map = weights.meta[\"categories\"]\n", + "\n", + " def __call__(self, images: daft.Series, shape: list[int, int, int]):\n", + " if len(images) == 0:\n", + " return []\n", + "\n", + " # Convert the Daft Series into a list of Numpy arrays\n", + " data = images.cast(daft.DataType.tensor(daft.DataType.uint8(), tuple(shape))).to_pylist()\n", + "\n", + " # Convert the numpy arrays into a torch tensor\n", + " images_array = torch.tensor(np.array(data)).permute((0, 3, 1, 2))\n", + "\n", + " # Run the model, and map results back to a human-readable string\n", + " batch = self.preprocess(images_array)\n", + " prediction = self.model(batch).softmax(0)\n", + " class_ids = prediction.argmax(1)\n", + " scores = prediction[:, class_ids]\n", + " return [self.category_map[class_id] for class_id in class_ids]\n", + "\n", + "# Filter out rows where the channel != 3\n", + "df = df.where(df[\"image\"].apply(lambda img: img.shape[2] == 3, return_dtype=daft.DataType.bool()))\n", + "\n", + "df = df.with_column(\"predictions_lowres\", ClassifyImage(df[\"image_resized_small\"], [32, 32, 3]))\n", + "df = df.with_column(\"predictions_highres\", ClassifyImage(df[\"image_resized_large\"], [256, 256, 3]))\n", + "\n", + "# Prune the results and write data back out as Parquet\n", + "df = df.select(\n", + " \"filename\",\n", + " \"image_url\",\n", + " \"object\",\n", + " \"predictions_lowres\",\n", + " \"predictions_highres\",\n", + ")\n", + "df.write_parquet(\"my_results.parquet\")" + ] + }, + { + "cell_type": "markdown", + "id": "fa593498-655b-4e41-87bc-05fe70a5ab66", + "metadata": {}, + "source": [ + "# Now, take a look at your handiwork!\n", + "\n", + "Let's read the results of our distributed Daft job!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c82c38f5-dccd-484c-856a-92362e147412", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "26e5c0781a57481cbb0fb454a43296a0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "ScanWithTask [Stage:5]: 0%| | 0/1 [00:00\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
filename
Utf8
image_url
Utf8
object
List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]
predictions_lowres
Utf8
predictions_highres
Utf8
ILSVRC2012_val_00000244
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000244.jpeg
[{bndbox: {xmax: 251,
xmin: 187,
ymax: 376,
ymin: 18,
},
difficult: 0,
name: n04090263,
pose: Unspecified,
truncated: 0,
}]
langur
rifle
ILSVRC2012_val_00000245
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000245.jpeg
[{bndbox: {xmax: 297,
xmin: 208,
ymax: 197,
ymin: 134,
},
difficult: 0,
name: n01986214,
pose: Unspecified,
truncated: 0,
}]
banded gecko
hermit crab
ILSVRC2012_val_00000247
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000247.jpeg
[{bndbox: {xmax: 362,
xmin: 50,
ymax: 476,
ymin: 92,
},
difficult: 0,
name: n01534433,
pose: Unspecified,
truncated: 0,
}]
junco
junco
ILSVRC2012_val_00000248
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000248.jpeg
[{bndbox: {xmax: 407,
xmin: 0,
ymax: 374,
ymin: 0,
},
difficult: 0,
name: n02093256,
pose: Unspecified,
truncated: 0,
}]
Staffordshire bullterrier
Staffordshire bullterrier
ILSVRC2012_val_00000250
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000250.jpeg
[{bndbox: {xmax: 422,
xmin: 10,
ymax: 498,
ymin: 13,
},
difficult: 0,
name: n03016953,
pose: Unspecified,
truncated: 0,
}]
wardrobe
wardrobe
ILSVRC2012_val_00000251
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000251.jpeg
[{bndbox: {xmax: 431,
xmin: 303,
ymax: 217,
ymin: 71,
},
difficult: 0,
name: n12620546,
pose: Unspecified,
truncated: 0,
}]
hip
hip
ILSVRC2012_val_00000252
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000252.jpeg
[{bndbox: {xmax: 498,
xmin: 1,
ymax: 498,
ymin: 1,
},
difficult: 0,
name: n03937543,
pose: Unspecified,
truncated: 0,
}]
face powder
bottlecap
ILSVRC2012_val_00000253
s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/ILSVRC2012_val_00000253.jpeg
[{bndbox: {xmax: 306,
xmin: 121,
ymax: 396,
ymin: 177,
},
difficult: 0,
name: n02815834,
pose: Unspecified,
truncated: 0,
}]
espresso maker
beaker
\n", + "(Showing first 8 of 745 rows)\n", + "" + ], + "text/plain": [ + "╭───────────────────────┬───────────────────────┬───────────────────────┬───────────────────────┬──────────────────────╮\n", + "│ filename ┆ image_url ┆ object ┆ predictions_lowres ┆ predictions_highres │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ Utf8 ┆ Utf8 ┆ List[Struct[bndbox: ┆ Utf8 ┆ Utf8 │\n", + "│ ┆ ┆ Struct[xmax: Utf8, ┆ ┆ │\n", + "│ ┆ ┆ xmin: Utf8, ymax: ┆ ┆ │\n", + "│ ┆ ┆ Utf8, ymin: Utf8], ┆ ┆ │\n", + "│ ┆ ┆ difficult: Utf8, ┆ ┆ │\n", + "│ ┆ ┆ name: Utf8, pose: ┆ ┆ │\n", + "│ ┆ ┆ Utf8, truncated: ┆ ┆ │\n", + "│ ┆ ┆ Utf8]] ┆ ┆ │\n", + "╞═══════════════════════╪═══════════════════════╪═══════════════════════╪═══════════════════════╪══════════════════════╡\n", + "│ ILSVRC2012_val_000002 ┆ s3://daft-public-data ┆ [{bndbox: {xmax: 251, ┆ langur ┆ rifle │\n", + "│ 44 ┆ sets/ima… ┆ xmin: 1… ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ILSVRC2012_val_000002 ┆ s3://daft-public-data ┆ [{bndbox: {xmax: 297, ┆ banded gecko ┆ hermit crab │\n", + "│ 45 ┆ sets/ima… ┆ xmin: 2… ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ILSVRC2012_val_000002 ┆ s3://daft-public-data ┆ [{bndbox: {xmax: 362, ┆ junco ┆ junco │\n", + "│ 47 ┆ sets/ima… ┆ xmin: 5… ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ILSVRC2012_val_000002 ┆ s3://daft-public-data ┆ [{bndbox: {xmax: 407, ┆ Staffordshire ┆ Staffordshire │\n", + "│ 48 ┆ sets/ima… ┆ xmin: 0… ┆ bullterrier ┆ bullterrier │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ILSVRC2012_val_000002 ┆ s3://daft-public-data ┆ [{bndbox: {xmax: 422, ┆ wardrobe ┆ wardrobe │\n", + "│ 50 ┆ sets/ima… ┆ xmin: 1… ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ILSVRC2012_val_000002 ┆ s3://daft-public-data ┆ [{bndbox: {xmax: 431, ┆ hip ┆ hip │\n", + "│ 51 ┆ sets/ima… ┆ xmin: 3… ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ILSVRC2012_val_000002 ┆ s3://daft-public-data ┆ [{bndbox: {xmax: 498, ┆ face powder ┆ bottlecap │\n", + "│ 52 ┆ sets/ima… ┆ xmin: 1… ┆ ┆ │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ILSVRC2012_val_000002 ┆ s3://daft-public-data ┆ [{bndbox: {xmax: 306, ┆ espresso maker ┆ beaker │\n", + "│ 53 ┆ sets/ima… ┆ xmin: 1… ┆ ┆ │\n", + "╰───────────────────────┴───────────────────────┴───────────────────────┴───────────────────────┴──────────────────────╯\n", + "\n", + "(Showing first 8 of 745 rows)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "daft.read_parquet(\"my_results.parquet\").collect()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae3c32e5-a5c8-4db3-8f91-3beef30ca753", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/delta_lake/3-pytorch-ray-single-node-training.ipynb b/tutorials/delta_lake/3-pytorch-ray-single-node-training.ipynb new file mode 100644 index 0000000000..7a7ad5433d --- /dev/null +++ b/tutorials/delta_lake/3-pytorch-ray-single-node-training.ipynb @@ -0,0 +1,491 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fb722781-c5eb-489d-a1eb-37df19340c68", + "metadata": {}, + "source": [ + "# Pytorch model training on data in DeltaLake\n", + "\n", + "In this tutorial, we showcase how to perform ML model training on data in a DeltaLake table containing URLs pointing out to images.\n", + "\n", + "To run this tutorial you will require AWS credentials to be correctly provisioned on your machine as all data is hosted in a requestor-pays bucket in AWS S3.\n", + "\n", + "Let's get started!" + ] + }, + { + "cell_type": "markdown", + "id": "51109558-d43e-4591-bd93-8cb707fbc37e", + "metadata": {}, + "source": [ + "## Provisioning Cloud Credentials\n", + "\n", + "First, let's provision credentials to Daft! We can do so using the ``boto3`` library, and creating a Daft {class}`IOConfig ` object like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c30e3ddc-5c0f-44b1-ba98-3de4412d31aa", + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "import daft\n", + "\n", + "session = boto3.session.Session()\n", + "creds = session.get_credentials()\n", + "io_config = daft.io.IOConfig(\n", + " s3=daft.io.S3Config(\n", + " access_key=creds.secret_key,\n", + " key_id=creds.access_key,\n", + " session_token=creds.token,\n", + " region_name=\"us-west-2\",\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "94123b7f-e325-4af8-a355-06966f96d031", + "metadata": {}, + "source": [ + "## Configuring Ray\n", + "\n", + "Now we configure Daft to run on the Ray backend.\n", + "\n", + "For the purposes of this tutorial, we define a ``USE_REMOTE_CLUSTER`` variable which will either have the tutorial run locally in the notebook (on a smaller subset of data), or on a remote Ray cluster (on a full subset of the data)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b8047b98-55bd-4db0-81ea-fe2bebd3abed", + "metadata": {}, + "outputs": [], + "source": [ + "USE_REMOTE_CLUSTER = False\n", + "RAY_ADDRESS = \"ray://localhost:10001\" if USE_REMOTE_CLUSTER else None" + ] + }, + { + "cell_type": "markdown", + "id": "109e1bfb-2b2d-425c-9de2-8901ac1436fa", + "metadata": {}, + "source": [ + "Additionally, if running remotely we will want to ensure that the remote Ray cluster has access to all the dependencies required to run the code in this tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "78973b5e-ce68-4a5d-8f83-dc38393856af", + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "\n", + "# Set up connection to Ray cluster if USE_REMOTE_CLUSTER=True\n", + "if USE_REMOTE_CLUSTER:\n", + " ray.init(\n", + " address=RAY_ADDRESS,\n", + " runtime_env={\n", + " \"pip\": [\n", + " \"getdaft\",\n", + " \"torch\",\n", + " \"torchvision\",\n", + " ]\n", + " },\n", + " )\n", + " print(ray.available_resources())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "eb3d9ebd-821a-46ff-b047-5f5a6044fee2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DaftContext(_daft_execution_config=, _daft_planning_config=, _runner_config=_RayRunnerConfig(address=None, max_task_backlog=None), _disallow_set_runner=True, _runner=None)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "daft.context.set_runner_ray(address=RAY_ADDRESS)" + ] + }, + { + "cell_type": "markdown", + "id": "7570a74f-393b-4331-a0ce-aad996905544", + "metadata": {}, + "source": [ + "## Retrieving Data\n", + "\n", + "Now we're ready to get right to business!\n", + "\n", + "First, let's load some data from our DeltaLake table!!\n", + "\n", + "We've hosted a 10k row sample of the validation set of imagenet for you to try this out.\n", + "\n", + "Simply pass in the ``IOConfig`` that we previously created to the call in order to ensure that we can access the data." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5662943d-45e1-4a4d-ab0f-e94c6d1259ec", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "
folder
Utf8
filename
Utf8
source
Struct[database: Utf8]
size
Struct[depth: Utf8, height: Utf8, width: Utf8]
segmented
Utf8
object
List[Struct[bndbox: Struct[xmax: Utf8, xmin: Utf8, ymax: Utf8, ymin: Utf8], difficult: Utf8, name: Utf8, pose: Utf8, truncated: Utf8]]
\n", + "(No data to display: Dataframe not materialized)\n", + "
" + ], + "text/plain": [ + "╭────────┬──────────┬────────────────────────┬─────────────────────────────┬───────────┬──────────────────────────╮\n", + "│ folder ┆ filename ┆ source ┆ size ┆ segmented ┆ object │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ Utf8 ┆ Utf8 ┆ Struct[database: Utf8] ┆ Struct[depth: Utf8, height: ┆ Utf8 ┆ List[Struct[bndbox: │\n", + "│ ┆ ┆ ┆ Utf8, width: Utf8] ┆ ┆ Struct[xmax: Utf8, xmin: │\n", + "│ ┆ ┆ ┆ ┆ ┆ Utf8, ymax: Utf8, ymin: │\n", + "│ ┆ ┆ ┆ ┆ ┆ Utf8], difficult: Utf8, │\n", + "│ ┆ ┆ ┆ ┆ ┆ name: Utf8, pose: Utf8, │\n", + "│ ┆ ┆ ┆ ┆ ┆ truncated: Utf8]] │\n", + "╰────────┴──────────┴────────────────────────┴─────────────────────────────┴───────────┴──────────────────────────╯\n", + "\n", + "(No data to display: Dataframe not materialized)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import daft\n", + "\n", + "df = daft.read_delta_lake(\"s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/\", io_config=io_config)\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "37a35164-989f-4161-b024-9164038f5d34", + "metadata": {}, + "outputs": [], + "source": [ + "# If running just locally in an example notebook, limit to 128 rows so we don't blow up!\n", + "if not USE_REMOTE_CLUSTER:\n", + " df = df.limit(128)" + ] + }, + { + "cell_type": "markdown", + "id": "14ab53cd-b33c-4008-90af-e01926f7aca8", + "metadata": {}, + "source": [ + "## Preprocessing\n", + "\n", + "We are now going to run some data pre-processing, which involves downloading the imagery data and running some basic image kernels." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "287bc36c-8c46-4da9-b8e3-1520ea7a229d", + "metadata": {}, + "outputs": [], + "source": [ + "# Download and resize images\n", + "df = df.with_column(\n", + " \"image_url\",\n", + " \"s3://daft-public-datasets/imagenet/val-10k-sample-deltalake/images/\" + df[\"filename\"] + \".jpeg\"\n", + ")\n", + "df = df.with_column(\"image\", df[\"image_url\"].url.download().image.decode())\n", + "df = df.with_column(\"image\", df[\"image\"].image.resize(256, 256))\n", + "\n", + "# Convert the images to a Tensor datatype\n", + "df = df.with_column(\n", + " \"arr\",\n", + " (\n", + " df[\"image\"]\n", + " .cast(daft.DataType.tensor(daft.DataType.uint8(), shape=(256, 256, 3)))\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d54a5a39-e04e-43af-b48f-27c98f66c8fb", + "metadata": {}, + "outputs": [], + "source": [ + "# Map class names to human-readable names and numeric IDs\n", + "classes = [('n01440764', 'tench'), ('n01443537', 'goldfish'), ('n01484850', 'great_white_shark'), ('n01491361', 'tiger_shark'), ('n01494475', 'hammerhead'), ('n01496331', 'electric_ray'), ('n01498041', 'stingray'), ('n01514668', 'cock'), ('n01514859', 'hen'), ('n01518878', 'ostrich'), ('n01530575', 'brambling'), ('n01531178', 'goldfinch'), ('n01532829', 'house_finch'), ('n01534433', 'junco'), ('n01537544', 'indigo_bunting'), ('n01558993', 'robin'), ('n01560419', 'bulbul'), ('n01580077', 'jay'), ('n01582220', 'magpie'), ('n01592084', 'chickadee'), ('n01601694', 'water_ouzel'), ('n01608432', 'kite'), ('n01614925', 'bald_eagle'), ('n01616318', 'vulture'), ('n01622779', 'great_grey_owl'), ('n01629819', 'European_fire_salamander'), ('n01630670', 'common_newt'), ('n01631663', 'eft'), ('n01632458', 'spotted_salamander'), ('n01632777', 'axolotl'), ('n01641577', 'bullfrog'), ('n01644373', 'tree_frog'), ('n01644900', 'tailed_frog'), ('n01664065', 'loggerhead'), ('n01665541', 'leatherback_turtle'), ('n01667114', 'mud_turtle'), ('n01667778', 'terrapin'), ('n01669191', 'box_turtle'), ('n01675722', 'banded_gecko'), ('n01677366', 'common_iguana'), ('n01682714', 'American_chameleon'), ('n01685808', 'whiptail'), ('n01687978', 'agama'), ('n01688243', 'frilled_lizard'), ('n01689811', 'alligator_lizard'), ('n01692333', 'Gila_monster'), ('n01693334', 'green_lizard'), ('n01694178', 'African_chameleon'), ('n01695060', 'Komodo_dragon'), ('n01697457', 'African_crocodile'), ('n01698640', 'American_alligator'), ('n01704323', 'triceratops'), ('n01728572', 'thunder_snake'), ('n01728920', 'ringneck_snake'), ('n01729322', 'hognose_snake'), ('n01729977', 'green_snake'), ('n01734418', 'king_snake'), ('n01735189', 'garter_snake'), ('n01737021', 'water_snake'), ('n01739381', 'vine_snake'), ('n01740131', 'night_snake'), ('n01742172', 'boa_constrictor'), ('n01744401', 'rock_python'), ('n01748264', 'Indian_cobra'), ('n01749939', 'green_mamba'), ('n01751748', 'sea_snake'), ('n01753488', 'horned_viper'), ('n01755581', 'diamondback'), ('n01756291', 'sidewinder'), ('n01768244', 'trilobite'), ('n01770081', 'harvestman'), ('n01770393', 'scorpion'), ('n01773157', 'black_and_gold_garden_spider'), ('n01773549', 'barn_spider'), ('n01773797', 'garden_spider'), ('n01774384', 'black_widow'), ('n01774750', 'tarantula'), ('n01775062', 'wolf_spider'), ('n01776313', 'tick'), ('n01784675', 'centipede'), ('n01795545', 'black_grouse'), ('n01796340', 'ptarmigan'), ('n01797886', 'ruffed_grouse'), ('n01798484', 'prairie_chicken'), ('n01806143', 'peacock'), ('n01806567', 'quail'), ('n01807496', 'partridge'), ('n01817953', 'African_grey'), ('n01818515', 'macaw'), ('n01819313', 'sulphur-crested_cockatoo'), ('n01820546', 'lorikeet'), ('n01824575', 'coucal'), ('n01828970', 'bee_eater'), ('n01829413', 'hornbill'), ('n01833805', 'hummingbird'), ('n01843065', 'jacamar'), ('n01843383', 'toucan'), ('n01847000', 'drake'), ('n01855032', 'red-breasted_merganser'), ('n01855672', 'goose'), ('n01860187', 'black_swan'), ('n01871265', 'tusker'), ('n01872401', 'echidna'), ('n01873310', 'platypus'), ('n01877812', 'wallaby'), ('n01882714', 'koala'), ('n01883070', 'wombat'), ('n01910747', 'jellyfish'), ('n01914609', 'sea_anemone'), ('n01917289', 'brain_coral'), ('n01924916', 'flatworm'), ('n01930112', 'nematode'), ('n01943899', 'conch'), ('n01944390', 'snail'), ('n01945685', 'slug'), ('n01950731', 'sea_slug'), ('n01955084', 'chiton'), ('n01968897', 'chambered_nautilus'), ('n01978287', 'Dungeness_crab'), ('n01978455', 'rock_crab'), ('n01980166', 'fiddler_crab'), ('n01981276', 'king_crab'), ('n01983481', 'American_lobster'), ('n01984695', 'spiny_lobster'), ('n01985128', 'crayfish'), ('n01986214', 'hermit_crab'), ('n01990800', 'isopod'), ('n02002556', 'white_stork'), ('n02002724', 'black_stork'), ('n02006656', 'spoonbill'), ('n02007558', 'flamingo'), ('n02009229', 'little_blue_heron'), ('n02009912', 'American_egret'), ('n02011460', 'bittern'), ('n02012849', 'crane'), ('n02013706', 'limpkin'), ('n02017213', 'European_gallinule'), ('n02018207', 'American_coot'), ('n02018795', 'bustard'), ('n02025239', 'ruddy_turnstone'), ('n02027492', 'red-backed_sandpiper'), ('n02028035', 'redshank'), ('n02033041', 'dowitcher'), ('n02037110', 'oystercatcher'), ('n02051845', 'pelican'), ('n02056570', 'king_penguin'), ('n02058221', 'albatross'), ('n02066245', 'grey_whale'), ('n02071294', 'killer_whale'), ('n02074367', 'dugong'), ('n02077923', 'sea_lion'), ('n02085620', 'Chihuahua'), ('n02085782', 'Japanese_spaniel'), ('n02085936', 'Maltese_dog'), ('n02086079', 'Pekinese'), ('n02086240', 'Shih-Tzu'), ('n02086646', 'Blenheim_spaniel'), ('n02086910', 'papillon'), ('n02087046', 'toy_terrier'), ('n02087394', 'Rhodesian_ridgeback'), ('n02088094', 'Afghan_hound'), ('n02088238', 'basset'), ('n02088364', 'beagle'), ('n02088466', 'bloodhound'), ('n02088632', 'bluetick'), ('n02089078', 'black-and-tan_coonhound'), ('n02089867', 'Walker_hound'), ('n02089973', 'English_foxhound'), ('n02090379', 'redbone'), ('n02090622', 'borzoi'), ('n02090721', 'Irish_wolfhound'), ('n02091032', 'Italian_greyhound'), ('n02091134', 'whippet'), ('n02091244', 'Ibizan_hound'), ('n02091467', 'Norwegian_elkhound'), ('n02091635', 'otterhound'), ('n02091831', 'Saluki'), ('n02092002', 'Scottish_deerhound'), ('n02092339', 'Weimaraner'), ('n02093256', 'Staffordshire_bullterrier'), ('n02093428', 'American_Staffordshire_terrier'), ('n02093647', 'Bedlington_terrier'), ('n02093754', 'Border_terrier'), ('n02093859', 'Kerry_blue_terrier'), ('n02093991', 'Irish_terrier'), ('n02094114', 'Norfolk_terrier'), ('n02094258', 'Norwich_terrier'), ('n02094433', 'Yorkshire_terrier'), ('n02095314', 'wire-haired_fox_terrier'), ('n02095570', 'Lakeland_terrier'), ('n02095889', 'Sealyham_terrier'), ('n02096051', 'Airedale'), ('n02096177', 'cairn'), ('n02096294', 'Australian_terrier'), ('n02096437', 'Dandie_Dinmont'), ('n02096585', 'Boston_bull'), ('n02097047', 'miniature_schnauzer'), ('n02097130', 'giant_schnauzer'), ('n02097209', 'standard_schnauzer'), ('n02097298', 'Scotch_terrier'), ('n02097474', 'Tibetan_terrier'), ('n02097658', 'silky_terrier'), ('n02098105', 'soft-coated_wheaten_terrier'), ('n02098286', 'West_Highland_white_terrier'), ('n02098413', 'Lhasa'), ('n02099267', 'flat-coated_retriever'), ('n02099429', 'curly-coated_retriever'), ('n02099601', 'golden_retriever'), ('n02099712', 'Labrador_retriever'), ('n02099849', 'Chesapeake_Bay_retriever'), ('n02100236', 'German_short-haired_pointer'), ('n02100583', 'vizsla'), ('n02100735', 'English_setter'), ('n02100877', 'Irish_setter'), ('n02101006', 'Gordon_setter'), ('n02101388', 'Brittany_spaniel'), ('n02101556', 'clumber'), ('n02102040', 'English_springer'), ('n02102177', 'Welsh_springer_spaniel'), ('n02102318', 'cocker_spaniel'), ('n02102480', 'Sussex_spaniel'), ('n02102973', 'Irish_water_spaniel'), ('n02104029', 'kuvasz'), ('n02104365', 'schipperke'), ('n02105056', 'groenendael'), ('n02105162', 'malinois'), ('n02105251', 'briard'), ('n02105412', 'kelpie'), ('n02105505', 'komondor'), ('n02105641', 'Old_English_sheepdog'), ('n02105855', 'Shetland_sheepdog'), ('n02106030', 'collie'), ('n02106166', 'Border_collie'), ('n02106382', 'Bouvier_des_Flandres'), ('n02106550', 'Rottweiler'), ('n02106662', 'German_shepherd'), ('n02107142', 'Doberman'), ('n02107312', 'miniature_pinscher'), ('n02107574', 'Greater_Swiss_Mountain_dog'), ('n02107683', 'Bernese_mountain_dog'), ('n02107908', 'Appenzeller'), ('n02108000', 'EntleBucher'), ('n02108089', 'boxer'), ('n02108422', 'bull_mastiff'), ('n02108551', 'Tibetan_mastiff'), ('n02108915', 'French_bulldog'), ('n02109047', 'Great_Dane'), ('n02109525', 'Saint_Bernard'), ('n02109961', 'Eskimo_dog'), ('n02110063', 'malamute'), ('n02110185', 'Siberian_husky'), ('n02110341', 'dalmatian'), ('n02110627', 'affenpinscher'), ('n02110806', 'basenji'), ('n02110958', 'pug'), ('n02111129', 'Leonberg'), ('n02111277', 'Newfoundland'), ('n02111500', 'Great_Pyrenees'), ('n02111889', 'Samoyed'), ('n02112018', 'Pomeranian'), ('n02112137', 'chow'), ('n02112350', 'keeshond'), ('n02112706', 'Brabancon_griffon'), ('n02113023', 'Pembroke'), ('n02113186', 'Cardigan'), ('n02113624', 'toy_poodle'), ('n02113712', 'miniature_poodle'), ('n02113799', 'standard_poodle'), ('n02113978', 'Mexican_hairless'), ('n02114367', 'timber_wolf'), ('n02114548', 'white_wolf'), ('n02114712', 'red_wolf'), ('n02114855', 'coyote'), ('n02115641', 'dingo'), ('n02115913', 'dhole'), ('n02116738', 'African_hunting_dog'), ('n02117135', 'hyena'), ('n02119022', 'red_fox'), ('n02119789', 'kit_fox'), ('n02120079', 'Arctic_fox'), ('n02120505', 'grey_fox'), ('n02123045', 'tabby'), ('n02123159', 'tiger_cat'), ('n02123394', 'Persian_cat'), ('n02123597', 'Siamese_cat'), ('n02124075', 'Egyptian_cat'), ('n02125311', 'cougar'), ('n02127052', 'lynx'), ('n02128385', 'leopard'), ('n02128757', 'snow_leopard'), ('n02128925', 'jaguar'), ('n02129165', 'lion'), ('n02129604', 'tiger'), ('n02130308', 'cheetah'), ('n02132136', 'brown_bear'), ('n02133161', 'American_black_bear'), ('n02134084', 'ice_bear'), ('n02134418', 'sloth_bear'), ('n02137549', 'mongoose'), ('n02138441', 'meerkat'), ('n02165105', 'tiger_beetle'), ('n02165456', 'ladybug'), ('n02167151', 'ground_beetle'), ('n02168699', 'long-horned_beetle'), ('n02169497', 'leaf_beetle'), ('n02172182', 'dung_beetle'), ('n02174001', 'rhinoceros_beetle'), ('n02177972', 'weevil'), ('n02190166', 'fly'), ('n02206856', 'bee'), ('n02219486', 'ant'), ('n02226429', 'grasshopper'), ('n02229544', 'cricket'), ('n02231487', 'walking_stick'), ('n02233338', 'cockroach'), ('n02236044', 'mantis'), ('n02256656', 'cicada'), ('n02259212', 'leafhopper'), ('n02264363', 'lacewing'), ('n02268443', 'dragonfly'), ('n02268853', 'damselfly'), ('n02276258', 'admiral'), ('n02277742', 'ringlet'), ('n02279972', 'monarch'), ('n02280649', 'cabbage_butterfly'), ('n02281406', 'sulphur_butterfly'), ('n02281787', 'lycaenid'), ('n02317335', 'starfish'), ('n02319095', 'sea_urchin'), ('n02321529', 'sea_cucumber'), ('n02325366', 'wood_rabbit'), ('n02326432', 'hare'), ('n02328150', 'Angora'), ('n02342885', 'hamster'), ('n02346627', 'porcupine'), ('n02356798', 'fox_squirrel'), ('n02361337', 'marmot'), ('n02363005', 'beaver'), ('n02364673', 'guinea_pig'), ('n02389026', 'sorrel'), ('n02391049', 'zebra'), ('n02395406', 'hog'), ('n02396427', 'wild_boar'), ('n02397096', 'warthog'), ('n02398521', 'hippopotamus'), ('n02403003', 'ox'), ('n02408429', 'water_buffalo'), ('n02410509', 'bison'), ('n02412080', 'ram'), ('n02415577', 'bighorn'), ('n02417914', 'ibex'), ('n02422106', 'hartebeest'), ('n02422699', 'impala'), ('n02423022', 'gazelle'), ('n02437312', 'Arabian_camel'), ('n02437616', 'llama'), ('n02441942', 'weasel'), ('n02442845', 'mink'), ('n02443114', 'polecat'), ('n02443484', 'black-footed_ferret'), ('n02444819', 'otter'), ('n02445715', 'skunk'), ('n02447366', 'badger'), ('n02454379', 'armadillo'), ('n02457408', 'three-toed_sloth'), ('n02480495', 'orangutan'), ('n02480855', 'gorilla'), ('n02481823', 'chimpanzee'), ('n02483362', 'gibbon'), ('n02483708', 'siamang'), ('n02484975', 'guenon'), ('n02486261', 'patas'), ('n02486410', 'baboon'), ('n02487347', 'macaque'), ('n02488291', 'langur'), ('n02488702', 'colobus'), ('n02489166', 'proboscis_monkey'), ('n02490219', 'marmoset'), ('n02492035', 'capuchin'), ('n02492660', 'howler_monkey'), ('n02493509', 'titi'), ('n02493793', 'spider_monkey'), ('n02494079', 'squirrel_monkey'), ('n02497673', 'Madagascar_cat'), ('n02500267', 'indri'), ('n02504013', 'Indian_elephant'), ('n02504458', 'African_elephant'), ('n02509815', 'lesser_panda'), ('n02510455', 'giant_panda'), ('n02514041', 'barracouta'), ('n02526121', 'eel'), ('n02536864', 'coho'), ('n02606052', 'rock_beauty'), ('n02607072', 'anemone_fish'), ('n02640242', 'sturgeon'), ('n02641379', 'gar'), ('n02643566', 'lionfish'), ('n02655020', 'puffer'), ('n02666196', 'abacus'), ('n02667093', 'abaya'), ('n02669723', 'academic_gown'), ('n02672831', 'accordion'), ('n02676566', 'acoustic_guitar'), ('n02687172', 'aircraft_carrier'), ('n02690373', 'airliner'), ('n02692877', 'airship'), ('n02699494', 'altar'), ('n02701002', 'ambulance'), ('n02704792', 'amphibian'), ('n02708093', 'analog_clock'), ('n02727426', 'apiary'), ('n02730930', 'apron'), ('n02747177', 'ashcan'), ('n02749479', 'assault_rifle'), ('n02769748', 'backpack'), ('n02776631', 'bakery'), ('n02777292', 'balance_beam'), ('n02782093', 'balloon'), ('n02783161', 'ballpoint'), ('n02786058', 'Band_Aid'), ('n02787622', 'banjo'), ('n02788148', 'bannister'), ('n02790996', 'barbell'), ('n02791124', 'barber_chair'), ('n02791270', 'barbershop'), ('n02793495', 'barn'), ('n02794156', 'barometer'), ('n02795169', 'barrel'), ('n02797295', 'barrow'), ('n02799071', 'baseball'), ('n02802426', 'basketball'), ('n02804414', 'bassinet'), ('n02804610', 'bassoon'), ('n02807133', 'bathing_cap'), ('n02808304', 'bath_towel'), ('n02808440', 'bathtub'), ('n02814533', 'beach_wagon'), ('n02814860', 'beacon'), ('n02815834', 'beaker'), ('n02817516', 'bearskin'), ('n02823428', 'beer_bottle'), ('n02823750', 'beer_glass'), ('n02825657', 'bell_cote'), ('n02834397', 'bib'), ('n02835271', 'bicycle-built-for-two'), ('n02837789', 'bikini'), ('n02840245', 'binder'), ('n02841315', 'binoculars'), ('n02843684', 'birdhouse'), ('n02859443', 'boathouse'), ('n02860847', 'bobsled'), ('n02865351', 'bolo_tie'), ('n02869837', 'bonnet'), ('n02870880', 'bookcase'), ('n02871525', 'bookshop'), ('n02877765', 'bottlecap'), ('n02879718', 'bow'), ('n02883205', 'bow_tie'), ('n02892201', 'brass'), ('n02892767', 'brassiere'), ('n02894605', 'breakwater'), ('n02895154', 'breastplate'), ('n02906734', 'broom'), ('n02909870', 'bucket'), ('n02910353', 'buckle'), ('n02916936', 'bulletproof_vest'), ('n02917067', 'bullet_train'), ('n02927161', 'butcher_shop'), ('n02930766', 'cab'), ('n02939185', 'caldron'), ('n02948072', 'candle'), ('n02950826', 'cannon'), ('n02951358', 'canoe'), ('n02951585', 'can_opener'), ('n02963159', 'cardigan'), ('n02965783', 'car_mirror'), ('n02966193', 'carousel'), ('n02966687', \"carpenter's_kit\"), ('n02971356', 'carton'), ('n02974003', 'car_wheel'), ('n02977058', 'cash_machine'), ('n02978881', 'cassette'), ('n02979186', 'cassette_player'), ('n02980441', 'castle'), ('n02981792', 'catamaran'), ('n02988304', 'CD_player'), ('n02992211', 'cello'), ('n02992529', 'cellular_telephone'), ('n02999410', 'chain'), ('n03000134', 'chainlink_fence'), ('n03000247', 'chain_mail'), ('n03000684', 'chain_saw'), ('n03014705', 'chest'), ('n03016953', 'chiffonier'), ('n03017168', 'chime'), ('n03018349', 'china_cabinet'), ('n03026506', 'Christmas_stocking'), ('n03028079', 'church'), ('n03032252', 'cinema'), ('n03041632', 'cleaver'), ('n03042490', 'cliff_dwelling'), ('n03045698', 'cloak'), ('n03047690', 'clog'), ('n03062245', 'cocktail_shaker'), ('n03063599', 'coffee_mug'), ('n03063689', 'coffeepot'), ('n03065424', 'coil'), ('n03075370', 'combination_lock'), ('n03085013', 'computer_keyboard'), ('n03089624', 'confectionery'), ('n03095699', 'container_ship'), ('n03100240', 'convertible'), ('n03109150', 'corkscrew'), ('n03110669', 'cornet'), ('n03124043', 'cowboy_boot'), ('n03124170', 'cowboy_hat'), ('n03125729', 'cradle'), ('n03126707', 'crane'), ('n03127747', 'crash_helmet'), ('n03127925', 'crate'), ('n03131574', 'crib'), ('n03133878', 'Crock_Pot'), ('n03134739', 'croquet_ball'), ('n03141823', 'crutch'), ('n03146219', 'cuirass'), ('n03160309', 'dam'), ('n03179701', 'desk'), ('n03180011', 'desktop_computer'), ('n03187595', 'dial_telephone'), ('n03188531', 'diaper'), ('n03196217', 'digital_clock'), ('n03197337', 'digital_watch'), ('n03201208', 'dining_table'), ('n03207743', 'dishrag'), ('n03207941', 'dishwasher'), ('n03208938', 'disk_brake'), ('n03216828', 'dock'), ('n03218198', 'dogsled'), ('n03220513', 'dome'), ('n03223299', 'doormat'), ('n03240683', 'drilling_platform'), ('n03249569', 'drum'), ('n03250847', 'drumstick'), ('n03255030', 'dumbbell'), ('n03259280', 'Dutch_oven'), ('n03271574', 'electric_fan'), ('n03272010', 'electric_guitar'), ('n03272562', 'electric_locomotive'), ('n03290653', 'entertainment_center'), ('n03291819', 'envelope'), ('n03297495', 'espresso_maker'), ('n03314780', 'face_powder'), ('n03325584', 'feather_boa'), ('n03337140', 'file'), ('n03344393', 'fireboat'), ('n03345487', 'fire_engine'), ('n03347037', 'fire_screen'), ('n03355925', 'flagpole'), ('n03372029', 'flute'), ('n03376595', 'folding_chair'), ('n03379051', 'football_helmet'), ('n03384352', 'forklift'), ('n03388043', 'fountain'), ('n03388183', 'fountain_pen'), ('n03388549', 'four-poster'), ('n03393912', 'freight_car'), ('n03394916', 'French_horn'), ('n03400231', 'frying_pan'), ('n03404251', 'fur_coat'), ('n03417042', 'garbage_truck'), ('n03424325', 'gasmask'), ('n03425413', 'gas_pump'), ('n03443371', 'goblet'), ('n03444034', 'go-kart'), ('n03445777', 'golf_ball'), ('n03445924', 'golfcart'), ('n03447447', 'gondola'), ('n03447721', 'gong'), ('n03450230', 'gown'), ('n03452741', 'grand_piano'), ('n03457902', 'greenhouse'), ('n03459775', 'grille'), ('n03461385', 'grocery_store'), ('n03467068', 'guillotine'), ('n03476684', 'hair_slide'), ('n03476991', 'hair_spray'), ('n03478589', 'half_track'), ('n03481172', 'hammer'), ('n03482405', 'hamper'), ('n03483316', 'hand_blower'), ('n03485407', 'hand-held_computer'), ('n03485794', 'handkerchief'), ('n03492542', 'hard_disc'), ('n03494278', 'harmonica'), ('n03495258', 'harp'), ('n03496892', 'harvester'), ('n03498962', 'hatchet'), ('n03527444', 'holster'), ('n03529860', 'home_theater'), ('n03530642', 'honeycomb'), ('n03532672', 'hook'), ('n03534580', 'hoopskirt'), ('n03535780', 'horizontal_bar'), ('n03538406', 'horse_cart'), ('n03544143', 'hourglass'), ('n03584254', 'iPod'), ('n03584829', 'iron'), ('n03590841', \"jack-o'-lantern\"), ('n03594734', 'jean'), ('n03594945', 'jeep'), ('n03595614', 'jersey'), ('n03598930', 'jigsaw_puzzle'), ('n03599486', 'jinrikisha'), ('n03602883', 'joystick'), ('n03617480', 'kimono'), ('n03623198', 'knee_pad'), ('n03627232', 'knot'), ('n03630383', 'lab_coat'), ('n03633091', 'ladle'), ('n03637318', 'lampshade'), ('n03642806', 'laptop'), ('n03649909', 'lawn_mower'), ('n03657121', 'lens_cap'), ('n03658185', 'letter_opener'), ('n03661043', 'library'), ('n03662601', 'lifeboat'), ('n03666591', 'lighter'), ('n03670208', 'limousine'), ('n03673027', 'liner'), ('n03676483', 'lipstick'), ('n03680355', 'Loafer'), ('n03690938', 'lotion'), ('n03691459', 'loudspeaker'), ('n03692522', 'loupe'), ('n03697007', 'lumbermill'), ('n03706229', 'magnetic_compass'), ('n03709823', 'mailbag'), ('n03710193', 'mailbox'), ('n03710637', 'maillot'), ('n03710721', 'maillot'), ('n03717622', 'manhole_cover'), ('n03720891', 'maraca'), ('n03721384', 'marimba'), ('n03724870', 'mask'), ('n03729826', 'matchstick'), ('n03733131', 'maypole'), ('n03733281', 'maze'), ('n03733805', 'measuring_cup'), ('n03742115', 'medicine_chest'), ('n03743016', 'megalith'), ('n03759954', 'microphone'), ('n03761084', 'microwave'), ('n03763968', 'military_uniform'), ('n03764736', 'milk_can'), ('n03769881', 'minibus'), ('n03770439', 'miniskirt'), ('n03770679', 'minivan'), ('n03773504', 'missile'), ('n03775071', 'mitten'), ('n03775546', 'mixing_bowl'), ('n03776460', 'mobile_home'), ('n03777568', 'Model_T'), ('n03777754', 'modem'), ('n03781244', 'monastery'), ('n03782006', 'monitor'), ('n03785016', 'moped'), ('n03786901', 'mortar'), ('n03787032', 'mortarboard'), ('n03788195', 'mosque'), ('n03788365', 'mosquito_net'), ('n03791053', 'motor_scooter'), ('n03792782', 'mountain_bike'), ('n03792972', 'mountain_tent'), ('n03793489', 'mouse'), ('n03794056', 'mousetrap'), ('n03796401', 'moving_van'), ('n03803284', 'muzzle'), ('n03804744', 'nail'), ('n03814639', 'neck_brace'), ('n03814906', 'necklace'), ('n03825788', 'nipple'), ('n03832673', 'notebook'), ('n03837869', 'obelisk'), ('n03838899', 'oboe'), ('n03840681', 'ocarina'), ('n03841143', 'odometer'), ('n03843555', 'oil_filter'), ('n03854065', 'organ'), ('n03857828', 'oscilloscope'), ('n03866082', 'overskirt'), ('n03868242', 'oxcart'), ('n03868863', 'oxygen_mask'), ('n03871628', 'packet'), ('n03873416', 'paddle'), ('n03874293', 'paddlewheel'), ('n03874599', 'padlock'), ('n03876231', 'paintbrush'), ('n03877472', 'pajama'), ('n03877845', 'palace'), ('n03884397', 'panpipe'), ('n03887697', 'paper_towel'), ('n03888257', 'parachute'), ('n03888605', 'parallel_bars'), ('n03891251', 'park_bench'), ('n03891332', 'parking_meter'), ('n03895866', 'passenger_car'), ('n03899768', 'patio'), ('n03902125', 'pay-phone'), ('n03903868', 'pedestal'), ('n03908618', 'pencil_box'), ('n03908714', 'pencil_sharpener'), ('n03916031', 'perfume'), ('n03920288', 'Petri_dish'), ('n03924679', 'photocopier'), ('n03929660', 'pick'), ('n03929855', 'pickelhaube'), ('n03930313', 'picket_fence'), ('n03930630', 'pickup'), ('n03933933', 'pier'), ('n03935335', 'piggy_bank'), ('n03937543', 'pill_bottle'), ('n03938244', 'pillow'), ('n03942813', 'ping-pong_ball'), ('n03944341', 'pinwheel'), ('n03947888', 'pirate'), ('n03950228', 'pitcher'), ('n03954731', 'plane'), ('n03956157', 'planetarium'), ('n03958227', 'plastic_bag'), ('n03961711', 'plate_rack'), ('n03967562', 'plow'), ('n03970156', 'plunger'), ('n03976467', 'Polaroid_camera'), ('n03976657', 'pole'), ('n03977966', 'police_van'), ('n03980874', 'poncho'), ('n03982430', 'pool_table'), ('n03983396', 'pop_bottle'), ('n03991062', 'pot'), ('n03992509', \"potter's_wheel\"), ('n03995372', 'power_drill'), ('n03998194', 'prayer_rug'), ('n04004767', 'printer'), ('n04005630', 'prison'), ('n04008634', 'projectile'), ('n04009552', 'projector'), ('n04019541', 'puck'), ('n04023962', 'punching_bag'), ('n04026417', 'purse'), ('n04033901', 'quill'), ('n04033995', 'quilt'), ('n04037443', 'racer'), ('n04039381', 'racket'), ('n04040759', 'radiator'), ('n04041544', 'radio'), ('n04044716', 'radio_telescope'), ('n04049303', 'rain_barrel'), ('n04065272', 'recreational_vehicle'), ('n04067472', 'reel'), ('n04069434', 'reflex_camera'), ('n04070727', 'refrigerator'), ('n04074963', 'remote_control'), ('n04081281', 'restaurant'), ('n04086273', 'revolver'), ('n04090263', 'rifle'), ('n04099969', 'rocking_chair'), ('n04111531', 'rotisserie'), ('n04116512', 'rubber_eraser'), ('n04118538', 'rugby_ball'), ('n04118776', 'rule'), ('n04120489', 'running_shoe'), ('n04125021', 'safe'), ('n04127249', 'safety_pin'), ('n04131690', 'saltshaker'), ('n04133789', 'sandal'), ('n04136333', 'sarong'), ('n04141076', 'sax'), ('n04141327', 'scabbard'), ('n04141975', 'scale'), ('n04146614', 'school_bus'), ('n04147183', 'schooner'), ('n04149813', 'scoreboard'), ('n04152593', 'screen'), ('n04153751', 'screw'), ('n04154565', 'screwdriver'), ('n04162706', 'seat_belt'), ('n04179913', 'sewing_machine'), ('n04192698', 'shield'), ('n04200800', 'shoe_shop'), ('n04201297', 'shoji'), ('n04204238', 'shopping_basket'), ('n04204347', 'shopping_cart'), ('n04208210', 'shovel'), ('n04209133', 'shower_cap'), ('n04209239', 'shower_curtain'), ('n04228054', 'ski'), ('n04229816', 'ski_mask'), ('n04235860', 'sleeping_bag'), ('n04238763', 'slide_rule'), ('n04239074', 'sliding_door'), ('n04243546', 'slot'), ('n04251144', 'snorkel'), ('n04252077', 'snowmobile'), ('n04252225', 'snowplow'), ('n04254120', 'soap_dispenser'), ('n04254680', 'soccer_ball'), ('n04254777', 'sock'), ('n04258138', 'solar_dish'), ('n04259630', 'sombrero'), ('n04263257', 'soup_bowl'), ('n04264628', 'space_bar'), ('n04265275', 'space_heater'), ('n04266014', 'space_shuttle'), ('n04270147', 'spatula'), ('n04273569', 'speedboat'), ('n04275548', 'spider_web'), ('n04277352', 'spindle'), ('n04285008', 'sports_car'), ('n04286575', 'spotlight'), ('n04296562', 'stage'), ('n04310018', 'steam_locomotive'), ('n04311004', 'steel_arch_bridge'), ('n04311174', 'steel_drum'), ('n04317175', 'stethoscope'), ('n04325704', 'stole'), ('n04326547', 'stone_wall'), ('n04328186', 'stopwatch'), ('n04330267', 'stove'), ('n04332243', 'strainer'), ('n04335435', 'streetcar'), ('n04336792', 'stretcher'), ('n04344873', 'studio_couch'), ('n04346328', 'stupa'), ('n04347754', 'submarine'), ('n04350905', 'suit'), ('n04355338', 'sundial'), ('n04355933', 'sunglass'), ('n04356056', 'sunglasses'), ('n04357314', 'sunscreen'), ('n04366367', 'suspension_bridge'), ('n04367480', 'swab'), ('n04370456', 'sweatshirt'), ('n04371430', 'swimming_trunks'), ('n04371774', 'swing'), ('n04372370', 'switch'), ('n04376876', 'syringe'), ('n04380533', 'table_lamp'), ('n04389033', 'tank'), ('n04392985', 'tape_player'), ('n04398044', 'teapot'), ('n04399382', 'teddy'), ('n04404412', 'television'), ('n04409515', 'tennis_ball'), ('n04417672', 'thatch'), ('n04418357', 'theater_curtain'), ('n04423845', 'thimble'), ('n04428191', 'thresher'), ('n04429376', 'throne'), ('n04435653', 'tile_roof'), ('n04442312', 'toaster'), ('n04443257', 'tobacco_shop'), ('n04447861', 'toilet_seat'), ('n04456115', 'torch'), ('n04458633', 'totem_pole'), ('n04461696', 'tow_truck'), ('n04462240', 'toyshop'), ('n04465501', 'tractor'), ('n04467665', 'trailer_truck'), ('n04476259', 'tray'), ('n04479046', 'trench_coat'), ('n04482393', 'tricycle'), ('n04483307', 'trimaran'), ('n04485082', 'tripod'), ('n04486054', 'triumphal_arch'), ('n04487081', 'trolleybus'), ('n04487394', 'trombone'), ('n04493381', 'tub'), ('n04501370', 'turnstile'), ('n04505470', 'typewriter_keyboard'), ('n04507155', 'umbrella'), ('n04509417', 'unicycle'), ('n04515003', 'upright'), ('n04517823', 'vacuum'), ('n04522168', 'vase'), ('n04523525', 'vault'), ('n04525038', 'velvet'), ('n04525305', 'vending_machine'), ('n04532106', 'vestment'), ('n04532670', 'viaduct'), ('n04536866', 'violin'), ('n04540053', 'volleyball'), ('n04542943', 'waffle_iron'), ('n04548280', 'wall_clock'), ('n04548362', 'wallet'), ('n04550184', 'wardrobe'), ('n04552348', 'warplane'), ('n04553703', 'washbasin'), ('n04554684', 'washer'), ('n04557648', 'water_bottle'), ('n04560804', 'water_jug'), ('n04562935', 'water_tower'), ('n04579145', 'whiskey_jug'), ('n04579432', 'whistle'), ('n04584207', 'wig'), ('n04589890', 'window_screen'), ('n04590129', 'window_shade'), ('n04591157', 'Windsor_tie'), ('n04591713', 'wine_bottle'), ('n04592741', 'wing'), ('n04596742', 'wok'), ('n04597913', 'wooden_spoon'), ('n04599235', 'wool'), ('n04604644', 'worm_fence'), ('n04606251', 'wreck'), ('n04612504', 'yawl'), ('n04613696', 'yurt'), ('n06359193', 'web_site'), ('n06596364', 'comic_book'), ('n06785654', 'crossword_puzzle'), ('n06794110', 'street_sign'), ('n06874185', 'traffic_light'), ('n07248320', 'book_jacket'), ('n07565083', 'menu'), ('n07579787', 'plate'), ('n07583066', 'guacamole'), ('n07584110', 'consomme'), ('n07590611', 'hot_pot'), ('n07613480', 'trifle'), ('n07614500', 'ice_cream'), ('n07615774', 'ice_lolly'), ('n07684084', 'French_loaf'), ('n07693725', 'bagel'), ('n07695742', 'pretzel'), ('n07697313', 'cheeseburger'), ('n07697537', 'hotdog'), ('n07711569', 'mashed_potato'), ('n07714571', 'head_cabbage'), ('n07714990', 'broccoli'), ('n07715103', 'cauliflower'), ('n07716358', 'zucchini'), ('n07716906', 'spaghetti_squash'), ('n07717410', 'acorn_squash'), ('n07717556', 'butternut_squash'), ('n07718472', 'cucumber'), ('n07718747', 'artichoke'), ('n07720875', 'bell_pepper'), ('n07730033', 'cardoon'), ('n07734744', 'mushroom'), ('n07742313', 'Granny_Smith'), ('n07745940', 'strawberry'), ('n07747607', 'orange'), ('n07749582', 'lemon'), ('n07753113', 'fig'), ('n07753275', 'pineapple'), ('n07753592', 'banana'), ('n07754684', 'jackfruit'), ('n07760859', 'custard_apple'), ('n07768694', 'pomegranate'), ('n07802026', 'hay'), ('n07831146', 'carbonara'), ('n07836838', 'chocolate_sauce'), ('n07860988', 'dough'), ('n07871810', 'meat_loaf'), ('n07873807', 'pizza'), ('n07875152', 'potpie'), ('n07880968', 'burrito'), ('n07892512', 'red_wine'), ('n07920052', 'espresso'), ('n07930864', 'cup'), ('n07932039', 'eggnog'), ('n09193705', 'alp'), ('n09229709', 'bubble'), ('n09246464', 'cliff'), ('n09256479', 'coral_reef'), ('n09288635', 'geyser'), ('n09332890', 'lakeside'), ('n09399592', 'promontory'), ('n09421951', 'sandbar'), ('n09428293', 'seashore'), ('n09468604', 'valley'), ('n09472597', 'volcano'), ('n09835506', 'ballplayer'), ('n10148035', 'groom'), ('n10565667', 'scuba_diver'), ('n11879895', 'rapeseed'), ('n11939491', 'daisy'), ('n12057211', \"yellow_lady's_slipper\"), ('n12144580', 'corn'), ('n12267677', 'acorn'), ('n12620546', 'hip'), ('n12768682', 'buckeye'), ('n12985857', 'coral_fungus'), ('n12998815', 'agaric'), ('n13037406', 'gyromitra'), ('n13040303', 'stinkhorn'), ('n13044778', 'earthstar'), ('n13052670', 'hen-of-the-woods'), ('n13054560', 'bolete'), ('n13133613', 'ear'), ('n15075141', 'toilet_tissue')]\n", + "classes_human_readable = {v0: v1 for (v0, v1) in classes}\n", + "classes_id = {v0: int(k) for k, (v0, _) in enumerate(classes)}\n", + "\n", + "df = df.with_column(\n", + " \"class_human_readable\",\n", + " df[\"object\"].list.get(0).struct.get(\"name\").apply(\n", + " lambda name: classes_human_readable[name], return_dtype=daft.DataType.string()\n", + " ),\n", + ")\n", + "df = df.with_column(\n", + " \"class_id\",\n", + " df[\"object\"].list.get(0).struct.get(\"name\").apply(\n", + " lambda name: classes_id[name], return_dtype=daft.DataType.int64()\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "947a0175-151b-4929-a733-3840276041c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
image
Image[MIXED]
arr
FixedShapeTensor(UInt8, [256, 256, 3])
class_human_readable
Utf8
class_id
Int64
\"<Image\" />
<FixedShapeTensor>
sea_snake
65
\"<Image\" />
<FixedShapeTensor>
alp
970
\"<Image\" />
<FixedShapeTensor>
Shetland_sheepdog
230
\"<Image\" />
<FixedShapeTensor>
soup_bowl
809
\n", + "(Showing first 4 rows)\n", + "
" + ], + "text/plain": [ + "╭──────────────┬────────────────────────────────────────┬──────────────────────┬──────────╮\n", + "│ image ┆ arr ┆ class_human_readable ┆ class_id │\n", + "│ --- ┆ --- ┆ --- ┆ --- │\n", + "│ Image[MIXED] ┆ FixedShapeTensor(UInt8, [256, 256, 3]) ┆ Utf8 ┆ Int64 │\n", + "╞══════════════╪════════════════════════════════════════╪══════════════════════╪══════════╡\n", + "│ ┆ sea_snake ┆ 65 │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ┆ alp ┆ 970 │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ┆ Shetland_sheepdog ┆ 230 │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ┆ soup_bowl ┆ 809 │\n", + "╰──────────────┴────────────────────────────────────────┴──────────────────────┴──────────╯\n", + "\n", + "(Showing first 4 rows)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df = df.select(\"image\", \"arr\", \"class_human_readable\", \"class_id\")\n", + "df.show(4)" + ] + }, + { + "cell_type": "markdown", + "id": "be1af108-1b02-47d4-b849-ac3ebc80b7bf", + "metadata": {}, + "source": [ + "## Convert to Ray Dataset\n", + "\n", + "Now we can convert our dataframe to a Ray Dataset, which is a great API and framework for ingesting data into ML training" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ef14cef1-6029-4890-87fd-d0f232b01730", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "ScanWithTask-LocalLimit-LocalLimit-Project-Project-Filter-Project [Stage:3]: 0%| | 0/1 [00:00\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
image
Image[MIXED]
arr
FixedShapeTensor(UInt8, [256, 256, 3])
class_human_readable
Utf8
class_id
Int64
model_predictions
Python
\"<Image\" />
<FixedShapeTensor>
sea_snake
65
tensor([[-2.7218e-01, -3.9199e-01, -6.0069e-01, -1.3214e+00, -7.2138e-01,
-2.2742e+00, -9.6090e-01, 1.5565e-01, 2.1052e-03, -8.7728e-01,
-8.4119e-01, -7.8948e-01, -4.2467e-01, -1.1369e+00, -9.2177e-01,
-7.4507e-01, -6.8149e-01, -2.7634e-01, -7.3824e-01, -8.3110e-01,
-1.4968e+00, 7.0988e+00, 2.4761e+00, -2.3017e+01, -1.1977e+00,
-1.1058e+00, -1.1112e+00, -1.0467e+00, 1.2516e+01, 8.9705e+00,
-9.9994e-01, -9.8104e-01, -1.7171e+01, -1.3706e+00, -1.9668e+00,
-6.7701e-01, -2.1457e-01, -6.8726e-01, -8.5024e-01, -2.1465e-01,
-7.7041e-01, -1.0962e+00, -1.0028e+00, -3.2649e-01, -8.0604e-01,
-6.8337e-01, 6.4011e+00, -6.6844e-01, -1.2203e+00, -2.2910e+00,
8.5518e+00, 1.6674e-01, -4.6913e+00, -2.8240e+00, -1.2394e+00,
-1.1207e+01, -4.1234e-01, -1.6717e+01, -2.7463e-01, -1.0010e+00,
-2.3535e+00, -4.7213e-01, -3.9813e+00, -4.1223e-01, -8.9204e-01,
4.7543e+00, -1.5626e+00, -1.5489e+00, -2.1282e+00, -1.1224e+00,
-1.4203e+00, -2.4671e-01, -1.1183e+00, -2.6329e-01, 5.6322e+00,
-1.1023e+00, -2.2263e-01, -6.0349e-01, -2.0382e-01, -2.0846e-01,
8.0481e+00, -1.4509e+00, -1.7574e-01, -3.3088e-01, -5.7435e-01,
-4.0449e-01, -6.2366e-03, -2.1108e-01, -1.4353e-01, -6.2005e-01,
-9.4666e-01, -7.8964e-01, 6.5133e+00, -5.3910e-01, 6.0052e-02,
-1.4843e+00, -8.4534e-01, -1.0062e-01, -1.1493e+00, -3.7695e-01,
-1.0982e+00, 7.8310e+00, -9.2716e-01, -3.3702e-01, -6.0828e-01,
-6.8946e-01, -7.3448e-01, 2.6687e+00, 5.6855e+00, -1.3891e+01,
-1.8712e+00, -1.3486e+00, 6.0427e-01, -7.6204e-02, -1.3723e+00,
-1.0717e+00, -1.0682e+00, -5.0470e-01, 1.3474e-01, -5.8448e-01,
-9.3026e-01, -4.2922e-01, 5.5625e+00, -1.5869e-01, 5.7220e-01,
-3.7398e-01, -6.7703e-02, -1.2750e+00, -8.9673e-01, 1.4321e+00,
-1.3602e+00, -1.3918e+00, -1.0868e+00, -1.0081e+00, -7.7643e-01,
-9.8061e-01, -8.5948e-01, -1.0274e+00, -1.0425e+00, -1.1698e+00,
-1.3543e+00, -1.2258e+00, 5.2071e+00, -1.4204e+00, -5.8002e-01,
-5.6669e-01, -8.6916e-01, 2.5076e+00, -1.0873e+00, 1.0240e+01,
5.7582e+00, 1.0144e+00, -1.0974e+00, -6.9455e-01, -1.8479e-01,
4.6612e-02, -7.4393e-01, -3.3550e+00, 4.6889e-03, 2.2762e+01,
1.0435e+01, 3.7995e-01, -6.1795e-02, 1.9692e-01, 8.0856e-02,
-2.6889e-01, -3.3845e-01, -6.3070e+00, 1.7832e-01, -9.8300e-01,
-4.8971e-01, 3.9528e-01, 1.7481e-01, -2.4623e+01, -3.3910e-02,
8.1921e+00, -1.3139e-01, -4.1811e-01, 1.8810e-01, 1.3478e-01,
1.5791e-01, -3.5686e-01, 6.8664e-02, 5.9554e+00, 2.1844e-01,
2.7276e-01, 2.8488e-01, -1.1633e-01, -1.2988e-01, 4.6866e-02,
-6.6081e-01, 8.1372e-03, 3.4176e-02, 3.3784e-02, -6.6022e-01,
5.1170e-01, -7.6427e-02, -2.3861e-01, 3.2628e+00, 3.3807e-01,
-1.3836e-01, 8.5680e+00, 1.4327e-01, 1.6371e-01, -1.8617e-01,
5.0539e-02, -9.6662e-02, 2.0577e-01, 6.3415e-01, 3.9985e-02,
-2.3564e-01, 7.4005e-02, 1.6075e-01, -2.8846e-01, -2.0495e-01,
-1.6295e-01, -2.6908e-01, -3.4792e-02, -7.2707e-01, 2.2754e-01,
-1.9363e-01, -2.7528e-01, -1.5432e-01, 2.0404e-01, -1.6314e-02,
2.2121e-01, -1.6963e-01, 5.6767e+00, -6.0694e-01, -4.4068e-01,
8.9364e+00, -4.6164e+00, -3.3140e-01, -3.0048e-01, 3.9635e-01,
4.5472e-01, 2.9910e-01, 2.3377e-01, 2.8540e-01, -3.0505e-01,
2.9635e-01, -1.9184e-01, 1.8154e-01, 7.9804e-02, -4.2487e-01,
3.2846e-01, 1.3669e-01, -1.3106e-01, 4.3115e-01, 6.4543e-02,
2.6967e-01, 1.8888e-01, -8.2692e-01, 2.5494e-01, 4.7320e-01,
-7.3103e-01, 1.0469e+01, -3.2224e-02, -3.8894e-01, 9.1315e+00,
-4.4979e-01, -8.5652e-01, -2.8508e-01, -2.5630e-01, 3.2376e-01,
2.8337e-01, 6.2566e-02, 2.5271e-02, -6.4055e-02, -5.2655e-01,
9.8856e+00, -9.3613e-01, -5.6440e-01, -1.4580e-01, -1.0994e+00,
8.0291e+00, -1.0691e+00, -7.1980e-01, -1.0271e+00, -6.7273e-01,
-4.6598e-01, 1.1383e-01, -1.4973e+00, -4.3521e-01, 7.3579e-02,
-7.5883e-02, 9.6552e+00, -4.8396e+00, -9.6121e-01, -1.4415e+00,
-1.8126e+00, -6.1558e+00, -9.1588e-01, -9.7442e-01, -1.2681e+00,
-7.8147e-01, -7.9471e-01, -1.1873e+00, -8.4203e-01, -1.0959e+00,
-1.5060e+00, -8.2959e-01, -6.0841e-01, -3.4820e-01, -1.0868e+00,
-1.3578e+00, -7.5302e-02, -1.0509e+00, -1.4035e+00, -5.7627e-01,
-1.8200e-01, 3.6591e+00, -4.1301e-01, -1.7634e-01, 8.7630e-02,
-5.8917e-01, -8.8327e-01, -1.1651e+00, -1.0661e+00, -6.1558e-01,
-1.2304e+00, -9.2926e-01, -9.6682e-01, -1.2946e+00, 7.7499e+00,
-1.2954e+00, -1.1388e+00, -3.9786e-01, -1.1407e+00, -8.1483e-01,
8.0324e+00, -2.7111e-01, 2.6719e+00, 6.1825e-03, -3.0653e+01,
-1.1672e+00, -1.6488e+00, -1.0722e+00, -6.8668e-01, -6.9518e-01,
-2.1088e-01, -2.3844e-01, -1.0247e+00, -1.1067e+00, -4.3101e-01,
3.9490e-01, -9.0675e-01, -1.3142e+00, -2.2741e-01, 6.6318e+00,
-1.0125e+00, -1.6010e+00, 3.1983e+00, -4.6579e-01, -4.4227e-01,
1.8837e-01, 6.8256e-01, -2.1065e-01, 6.8557e+00, 1.7923e-01,
-8.4758e-01, 5.1670e-03, -1.4428e+00, -8.6775e-02, -4.5402e-01,
-6.2526e-01, 8.0519e+00, -9.4258e-01, -6.6032e-01, 8.3520e+00,
7.9748e+00, -6.1636e+00, -1.3699e+00, 7.9199e+00, -1.5636e+00,
-1.5538e+00, -8.6365e-01, -3.2521e+00, -5.8446e-01, -6.3183e-01,
-6.0908e-01, -5.5044e-01, -4.3848e-01, 1.6095e+01, -5.7093e-01,
-5.1839e-01, -9.4339e-01, -9.8410e-01, -9.7510e-01, -9.1060e-03,
-2.6986e+00, 8.7224e-02, 7.3503e+00, -1.5795e+00, 7.1197e+00,
-1.7444e-01, -1.0286e+00, -1.2707e+00, 6.9452e+00, 3.3508e-02,
-5.6076e-01, 4.0854e-01, 6.8882e-02, -1.6711e+00, -1.2866e+00,
-4.6440e-01, -3.3918e-01, -4.1093e-01, -5.1589e-01, 6.9805e-01,
-9.2784e-02, -5.3806e-03, 1.1749e+00, 3.2381e-01, 9.3087e-02,
-5.2370e+01, -3.9389e-01, -8.6826e-02, -3.6716e-02, -5.6528e-01,
6.2077e-01, 6.5678e-01, -2.5474e-01, -1.2546e+00, 7.4020e+00,
-1.0467e+00, -4.5434e-02, 8.2333e-01, 8.3208e-01, 2.3128e-03,
-8.2706e-01, -5.0666e+00, -5.4116e-02, -2.3150e-01, -5.9189e-01,
7.7825e-01, -3.4248e-01, -6.9473e-01, 2.8248e-01, 2.3578e-01,
4.9211e-01, -1.4448e-01, -1.7877e-01, -1.0182e+00, -4.5573e-01,
3.7252e-02, 1.7568e-01, 7.3505e-01, 3.7696e-01, -5.2635e-01,
-1.0219e-01, 2.9260e-01, 2.0781e-01, -4.8262e-01, -3.5518e-01,
-1.1886e-02, 1.0787e+00, 8.5548e-01, -3.1228e-01, 3.6009e-01,
-7.9261e-01, 1.3557e+01, 7.2666e-01, 1.4181e+00, 5.8066e-01,
-2.8873e-01, -9.9615e-01, 8.9555e+00, -1.7767e+00, 6.3621e-01,
-1.1300e+00, 1.1247e-01, 1.5563e-01, 8.8448e+00, 6.5351e+00,
1.3532e-02, 7.3117e+00, 1.8166e-01, -2.2641e+01, 1.5273e+01,
-2.4683e-01, -9.5049e-03, 1.1630e-01, 3.0633e+00, -1.1981e+00,
-2.5942e-01, 1.8629e+01, 7.5353e+00, 5.1561e-01, 1.4080e-01,
1.4515e-01, 3.1030e-01, 4.2499e-01, -9.5339e-01, 3.4819e-01,
-1.0911e+00, -5.5979e-01, -5.4063e-01, 7.2251e+00, 7.4376e-01,
-1.2931e+00, 3.6317e-01, 7.8293e-01, 2.1842e-01, 2.6177e-01,
6.3840e-01, 2.7178e-01, 2.2662e-01, 8.2165e-02, -9.4395e-02,
-9.9010e-01, -6.1838e-01, 3.7146e-01, 5.7994e-02, 9.5550e-01,
1.2946e+00, -1.1042e+01, 4.3105e+00, 8.7146e-01, 2.8734e-01,
-6.0668e+00, -3.6015e-02, 5.6721e-01, 8.6468e-01, 3.0424e-01,
-6.8888e-01, -1.5222e-01, -3.8488e-01, -9.8998e-02, -2.3635e-02,
4.8019e-01, -3.3171e-02, -5.5059e-01, -8.8969e-01, 3.8396e-02,
-8.0375e-01, -3.3322e-01, -7.1269e-01, -8.8020e-02, 4.8284e-01,
-1.0943e+00, 9.1922e-01, 4.9380e-01, 3.0667e-01, 8.5801e-02,
4.7345e-01, 4.5078e-01, -1.5427e+00, -1.1097e+00, -1.4702e-01,
-5.4698e-01, -3.7909e-01, 3.7092e-01, -2.5831e-01, -1.3308e+00,
-5.2416e-01, 9.8743e-02, 8.8819e-02, 8.2802e-01, 2.5360e-01,
-3.2341e-01, -4.4850e-01, 4.2854e-01, -5.3071e-02, -2.8175e+00,
9.6548e+00, -1.2190e-01, 4.6243e-01, 2.0567e-01, -1.0811e+00,
6.8205e-01, -2.0924e-01, 5.1765e-01, -4.9734e+00, -3.0300e-01,
-4.1095e-01, -6.7360e-01, 3.2369e-01, -2.3787e-02, -4.4443e-01,
-9.3292e-03, -2.2638e-01, 3.4301e-01, 1.0489e-01, 2.4126e-01,
-2.9834e-01, 1.1082e+01, 1.0232e+00, 5.0191e+00, 6.3104e-01,
-3.5633e-01, 4.0038e+00, -5.5478e-01, 5.4387e-01, 1.5682e-01,
7.6247e+00, 7.6770e-01, -1.3988e-01, -5.5349e-01, -1.6953e-01,
9.5465e-01, -2.2591e-01, -2.3700e-01, -4.8072e-01, 2.5551e-01,
-9.6264e-02, 6.6338e-01, -6.0307e-01, 5.4596e+00, -5.3585e-01,
6.3952e-01, 3.0489e-01, -7.5147e-01, 2.9045e-01, -2.0272e-01,
-6.9456e-02, 6.5666e-01, 2.7208e-01, 1.1204e+00, 3.9470e-01,
3.8175e-01, -1.6587e+00, -1.3396e-01, 1.9864e-01, -7.9700e-02,
-9.1202e-01, 4.2452e-01, -4.5784e-01, -9.0505e-01, -2.5147e+00,
-1.5909e-01, -3.4615e-01, 5.5582e+00, 5.7411e-01, -1.8700e-01,
7.2076e-02, 7.8569e-01, -4.1314e-03, 4.2262e+00, 9.9424e-02,
-1.6063e+00, 6.2787e-01, -2.5430e-01, 9.1142e-01, 1.5559e-01,
3.0832e+00, 2.4105e-01, -2.3489e-01, -4.6815e-01, -1.2299e+00,
4.1385e-01, 6.5344e-03, 3.8777e-01, 5.8309e-01, -1.4908e+00,
2.8903e-01, -7.0837e-01, -5.3410e-01, -3.4958e-03, -2.8071e+00,
-2.0681e-01, -1.0029e+00, -1.2692e-01, -6.9856e-01, 4.4065e+00,
-7.0137e-02, 6.0135e-01, 1.1621e-01, -7.2703e-01, -5.9359e-01,
1.5025e-01, -5.2114e-01, -6.2804e-01, 2.1222e-01, -3.1270e+01,
-1.5322e+00, 1.1399e+00, 1.4421e-01, 6.7192e-01, 5.5183e-02,
-5.6790e-01, 3.0249e-01, -5.6765e-01, 1.3607e-01, 5.9236e-01,
-1.3364e+00, 4.9346e+00, -1.0410e+00, -4.6063e-01, -7.7269e-01,
-4.4073e-01, 5.4013e-01, -8.6504e-01, 4.5763e-01, -1.0415e+00,
-9.1738e-02, 9.6063e-01, -2.8823e-01, -4.2659e-01, 5.7030e-01,
-4.7005e-02, -4.5133e-01, -2.2809e-01, 1.3206e-01, 4.5811e-02,
9.3358e+00, -3.9568e-01, -4.4094e-02, 6.6723e-01, -4.0498e-01,
-2.6842e-01, 2.6176e-02, -1.1839e-01, -3.9041e-01, 7.0255e-04,
-3.3377e-01, 1.9423e-01, 3.8088e+00, 3.5921e+00, 1.5202e-01,
4.9591e+00, 1.2007e-01, 4.8731e-01, 5.7460e-02, -6.8261e-01,
-8.3474e+00, -7.1905e-01, 1.2885e+01, 5.4042e-01, -5.3001e-01,
-1.8281e+00, 1.1641e+00, -6.4115e-01, 1.1523e+00, -1.4470e+00,
-1.7008e-01, -2.1594e-01, -2.9089e-02, 5.8464e-01, -1.4488e-01,
7.1962e-01, -6.4103e-02, 2.5908e-01, -1.0145e-01, -2.4248e-02,
-2.0624e-01, -4.1559e-01, 4.3878e-01, -2.6849e-02, 6.6292e-01,
-1.9844e-01, -3.2533e-01, -1.0234e-01, 9.2657e-02, 4.8087e-01,
-7.2569e-01, 1.0871e+01, -4.4782e+00, 1.1555e+00, -1.8037e-01,
-1.3585e-01, 4.1578e-01, 1.1390e-01, 6.3028e-01, 7.7623e-01,
-1.9664e+00, 1.5135e-01, -3.6405e+00, -3.8648e-01, 6.1547e-01,
1.5643e-01, 9.3263e-02, 8.1692e-01, 6.7228e-02, -7.1655e-02,
1.9427e-01, 3.3246e-01, -1.5778e+01, 6.7203e-01, -8.6169e-01,
9.0010e+00, -8.7306e-01, 7.3237e-01, 3.8963e-01, 7.3310e-01,
4.5097e-02, 1.8724e-01, 9.1519e-01, 7.6903e+00, -3.1040e-01,
2.6296e-01, 3.5329e-01, 7.2229e-01, 1.8164e-01, 1.8765e-01,
-2.9655e+01, 5.6940e-01, 3.4993e-01, -4.9079e-01, -3.6331e-02,
-8.2086e-01, 1.2464e-01, -1.0329e+00, -1.0741e+00, 3.2376e-01,
6.0191e-01, 5.4995e-02, -1.9888e-01, 1.0859e+00, -1.1630e+01,
-3.3229e-01, 4.3768e-01, -3.7276e-01, 7.2431e-01, -8.3915e-01,
-4.2646e-01, -4.7301e-02, -8.3625e-01, 5.7841e-01, 1.3629e-01,
-1.2626e+00, -7.2264e-01, 2.0603e-01, 3.6301e-01, 4.0489e-01,
-7.4101e-01, 1.0360e-01, 5.6793e-01, 1.1387e-01, -5.0085e-01,
5.5292e-01, -2.8517e-01, -5.3055e-01, -8.4060e-01, -8.0749e-02,
1.4585e-01, 1.0964e+00, 9.8573e-01, 1.6353e-01, -5.3257e-01,
8.4816e-01, 6.0900e+00, 5.6142e+00, 3.4809e-01, 5.2744e+00,
8.3104e-01, 1.1439e+01, -6.9536e-01, 1.2885e-01, 5.1153e-01,
3.8695e-01, 7.0928e-01, 3.5301e-01, -5.0566e-01, -4.7296e-01,
4.5524e-01, -1.4350e+00, -2.9241e+00, 9.6179e+00, 6.3856e-01,
6.2121e-02, 8.4120e-01, 4.4685e-01, -6.8144e-02, -4.6556e-01,
2.3636e-01, -5.5843e-01, -1.1836e+00, 8.1906e-01, -4.6677e-01,
6.5789e+00, -1.0363e+00, 7.7462e+00, -7.6452e-01, -1.8418e+00,
1.4656e-01, 1.1439e+00, -5.2971e-01, 1.1212e+01, 7.5197e-01,
4.4209e-01, -3.5320e-01, 4.1087e-01, 8.0785e-01, -2.2465e-01,
-1.2364e-01, -4.1837e-01, -2.5483e+00, -7.6983e-01, 1.4111e-01,
-6.2874e-01, -8.8027e-02, 4.7306e-01, 1.3312e-02, -6.9450e-01,
-1.1118e+00, 7.6371e-01, 4.7137e-01, -3.7205e-01, 1.2002e+00,
-9.6311e-01, -5.4649e-01, 3.2844e-01, 3.1530e-01, 5.2128e-01,
3.1231e-02, 5.0209e+00, 8.2332e-01, -8.3126e-01, 4.4870e-01,
5.0346e-01, 6.7402e-01, 4.2432e-01, -4.3691e-01, -1.3725e+00,
-5.2472e-01, 5.2077e-04, 1.5176e-01, -1.7166e-02, -1.4921e-01,
-2.7750e-01, 6.0583e-01, -6.9361e-01, 3.7252e-02, -6.6764e-01,
-4.8876e+00, -1.0858e+00, -8.7510e-01, -5.6715e-01, -2.4048e+00,
-5.2062e-01, -4.1056e-01, -3.1579e-01, -1.9164e+00, -4.0713e+00,
-7.9799e-01, 9.4610e-02, -2.9357e-01, -3.6255e-01, -6.1240e-01,
-4.8822e-01, -8.6886e-01, -6.7513e-01, -6.2841e-01, -7.8575e-01,
-1.0141e+00, -1.0293e+00, 7.3600e-02, -1.9057e+01, 7.8292e+00,
-5.7524e-01, -8.7335e-01, -8.3840e-01, -2.4755e-01, 3.1148e-01,
-1.3108e+00, -5.7400e-01, -5.6589e-01, -1.4413e-01, 1.3141e+01,
-1.4567e-01, -3.5524e-01, -5.4817e-01, -8.4205e-01, -1.1962e+00,
-2.0332e+00, 2.4699e-01, -2.1461e+00, -1.6782e+00, -4.9797e+00,
6.7784e+00, 5.6267e-01, -3.2829e-01, -1.1577e+00, -1.7021e+00,
2.8842e-01, -1.4424e+00, -1.7387e-01, -1.3089e-01, -8.3507e-01,
-9.2152e-01, -1.7488e-01, 6.3576e-02, -5.1483e-01, -6.1928e-01,
-1.2056e+00, -2.0253e+00, 8.1032e-01, -4.6983e-01, -1.1170e+00,
-3.6558e-01, -1.2480e+00, -9.8161e-01, -1.4049e+00, 7.6402e+00,
-3.6551e-01, -7.8730e-03, -7.3599e-01, 7.3325e-01, -1.2999e-01]],
requires_grad=True)
\"<Image\" />
<FixedShapeTensor>
alp
970
tensor([[-2.9406e-01, -3.9664e-01, -5.8837e-01, -1.4149e+00, -7.2320e-01,
-2.2410e+00, -9.7316e-01, 1.5885e-01, -3.0665e-02, -7.7977e-01,
-8.4206e-01, -7.4727e-01, -4.1816e-01, -1.1171e+00, -8.3957e-01,
-7.2551e-01, -6.5612e-01, -2.1734e-01, -7.5877e-01, -8.1655e-01,
-1.4666e+00, 6.8534e+00, 2.3918e+00, -2.2120e+01, -1.1626e+00,
-1.0394e+00, -1.1378e+00, -1.0831e+00, 1.2214e+01, 8.3887e+00,
-9.4926e-01, -9.2964e-01, -1.6574e+01, -1.3692e+00, -1.8884e+00,
-6.0111e-01, -1.7153e-01, -5.9309e-01, -7.9982e-01, -1.8660e-01,
-7.6009e-01, -1.0750e+00, -9.2130e-01, -2.9906e-01, -7.9046e-01,
-5.9484e-01, 6.1356e+00, -6.4141e-01, -1.1772e+00, -2.1657e+00,
8.2907e+00, 1.6004e-01, -4.5558e+00, -2.7258e+00, -1.1597e+00,
-1.0933e+01, -3.5250e-01, -1.6114e+01, -2.6516e-01, -9.5088e-01,
-2.2213e+00, -3.5355e-01, -3.8035e+00, -3.0580e-01, -8.6134e-01,
4.6872e+00, -1.4660e+00, -1.3926e+00, -2.0184e+00, -1.1133e+00,
-1.4065e+00, -2.1029e-01, -1.0242e+00, -2.1178e-01, 5.4390e+00,
-1.0446e+00, -1.9681e-01, -5.6447e-01, -1.9908e-01, -1.8673e-01,
7.7422e+00, -1.4272e+00, -1.7023e-01, -2.9570e-01, -5.2039e-01,
-3.1763e-01, -2.0674e-02, -1.6901e-01, -1.0512e-02, -5.1825e-01,
-8.3698e-01, -7.0063e-01, 6.3533e+00, -5.1435e-01, 1.2796e-01,
-1.4354e+00, -7.6603e-01, -1.0521e-01, -1.1031e+00, -3.5712e-01,
-1.0332e+00, 7.5147e+00, -8.5460e-01, -2.5403e-01, -5.4898e-01,
-6.0829e-01, -7.3291e-01, 2.6935e+00, 5.4835e+00, -1.3480e+01,
-1.8402e+00, -1.2724e+00, 6.1110e-01, -3.5563e-02, -1.2901e+00,
-1.0908e+00, -1.0241e+00, -5.1424e-01, 3.2534e-02, -6.2045e-01,
-9.6891e-01, -5.4501e-01, 5.3297e+00, -2.3010e-01, 5.1841e-01,
-3.5784e-01, -9.4067e-02, -1.2004e+00, -9.1853e-01, 1.4070e+00,
-1.2968e+00, -1.2999e+00, -1.0496e+00, -9.5197e-01, -6.6701e-01,
-8.9681e-01, -7.8379e-01, -9.4913e-01, -9.0083e-01, -1.0882e+00,
-1.2324e+00, -1.1427e+00, 5.0570e+00, -1.4333e+00, -5.4541e-01,
-6.1113e-01, -8.6848e-01, 2.4377e+00, -1.0680e+00, 9.8447e+00,
5.6174e+00, 1.0218e+00, -1.1402e+00, -7.2895e-01, -1.8470e-01,
5.6801e-02, -7.3009e-01, -3.2811e+00, -6.1589e-02, 2.1895e+01,
1.0091e+01, 4.4335e-01, -3.3916e-02, 2.0015e-01, 1.3278e-01,
-2.9462e-01, -3.3158e-01, -5.9422e+00, 1.7484e-01, -9.6596e-01,
-4.7054e-01, 3.6560e-01, 2.5628e-01, -2.3975e+01, -8.7247e-02,
7.8706e+00, -1.0385e-01, -4.3048e-01, 1.5711e-01, 7.7143e-02,
7.8516e-02, -3.4594e-01, 4.5937e-02, 5.6425e+00, 1.5112e-01,
2.2575e-01, 2.4955e-01, -1.6244e-01, -1.8184e-01, -2.3581e-02,
-6.7364e-01, -2.7870e-02, 1.8746e-02, -2.8437e-02, -6.0419e-01,
4.8259e-01, -8.7599e-02, -2.6273e-01, 3.2028e+00, 2.3763e-01,
-2.6528e-01, 8.2496e+00, 1.1074e-01, 1.4258e-01, -2.2528e-01,
9.3379e-03, -1.7632e-01, 2.0691e-01, 5.8571e-01, -1.6069e-02,
-2.6504e-01, 2.4425e-02, 9.9887e-02, -3.3414e-01, -2.0628e-01,
-2.1807e-01, -2.9011e-01, -1.0363e-01, -7.1037e-01, 1.7000e-01,
-2.7459e-01, -3.6257e-01, -1.3321e-01, 1.2090e-01, -1.1051e-01,
1.8445e-01, -2.5708e-01, 5.4282e+00, -5.8523e-01, -4.4985e-01,
8.5578e+00, -4.5061e+00, -4.0017e-01, -4.1432e-01, 3.5181e-01,
3.9117e-01, 2.9906e-01, 2.2990e-01, 2.9776e-01, -3.1685e-01,
2.5615e-01, -1.8224e-01, 1.6206e-01, 2.7632e-02, -4.6059e-01,
2.9663e-01, 1.1910e-01, -1.9236e-01, 4.1063e-01, 9.8644e-02,
2.6365e-01, 1.5021e-01, -9.2747e-01, 2.5778e-01, 4.7184e-01,
-7.9703e-01, 1.0053e+01, -5.5267e-03, -4.5335e-01, 8.7502e+00,
-5.0980e-01, -8.8995e-01, -3.1298e-01, -2.5574e-01, 2.9981e-01,
2.3983e-01, -1.0101e-02, -1.8801e-02, -5.3076e-02, -5.2083e-01,
9.5328e+00, -8.4371e-01, -5.4568e-01, -1.2985e-01, -1.0995e+00,
7.8161e+00, -9.3372e-01, -6.6218e-01, -9.0596e-01, -6.4735e-01,
-4.4208e-01, 1.2425e-01, -1.4217e+00, -4.1171e-01, 1.3034e-01,
-7.3342e-02, 9.3341e+00, -4.6729e+00, -9.0944e-01, -1.3831e+00,
-1.7844e+00, -5.9477e+00, -8.5478e-01, -8.7530e-01, -1.1970e+00,
-7.5070e-01, -7.1503e-01, -1.1938e+00, -7.8174e-01, -1.0038e+00,
-1.4613e+00, -7.4007e-01, -5.7681e-01, -3.6986e-01, -9.8656e-01,
-1.2252e+00, 1.1815e-02, -9.2557e-01, -1.3181e+00, -5.0359e-01,
-1.4871e-01, 3.5457e+00, -4.3417e-01, -1.8000e-01, 1.1858e-01,
-5.7208e-01, -8.2794e-01, -1.1515e+00, -9.9724e-01, -4.8915e-01,
-1.1499e+00, -9.0271e-01, -9.7713e-01, -1.1906e+00, 7.4759e+00,
-1.2325e+00, -1.1113e+00, -3.7537e-01, -1.1247e+00, -8.1317e-01,
7.7153e+00, -3.0876e-01, 2.4818e+00, -1.8819e-02, -2.9745e+01,
-1.0812e+00, -1.5324e+00, -1.0727e+00, -6.3278e-01, -6.5594e-01,
-2.1141e-01, -2.9419e-01, -1.0257e+00, -1.0231e+00, -4.4509e-01,
3.7156e-01, -8.5344e-01, -1.2985e+00, -2.2088e-01, 6.4307e+00,
-1.0552e+00, -1.4340e+00, 3.1215e+00, -4.4853e-01, -4.1941e-01,
1.6904e-01, 6.6213e-01, -2.4516e-01, 6.6184e+00, 1.8444e-01,
-7.6232e-01, -9.0120e-02, -1.3869e+00, -9.8926e-02, -3.9975e-01,
-6.4051e-01, 7.7341e+00, -9.9642e-01, -5.8463e-01, 8.0843e+00,
7.7446e+00, -5.9572e+00, -1.3195e+00, 7.6834e+00, -1.4794e+00,
-1.4984e+00, -8.2104e-01, -3.1359e+00, -5.0445e-01, -5.7738e-01,
-6.2605e-01, -4.7661e-01, -4.0145e-01, 1.5603e+01, -5.4976e-01,
-4.8561e-01, -8.7848e-01, -9.2070e-01, -9.4457e-01, -5.9788e-02,
-2.5845e+00, 7.1403e-02, 7.0118e+00, -1.5812e+00, 6.7361e+00,
-1.8427e-01, -1.0359e+00, -1.2749e+00, 6.7171e+00, 8.5411e-03,
-5.9113e-01, 3.3010e-01, 1.2502e-01, -1.6013e+00, -1.2181e+00,
-3.5198e-01, -3.2387e-01, -3.7935e-01, -4.9484e-01, 6.2369e-01,
-5.5559e-02, 2.8182e-02, 1.1607e+00, 3.1603e-01, 9.1170e-02,
-5.0581e+01, -3.3354e-01, -6.9272e-03, 8.9704e-04, -5.2492e-01,
6.6158e-01, 6.1548e-01, -2.6688e-01, -1.1560e+00, 7.1814e+00,
-1.0232e+00, 3.0500e-02, 7.7873e-01, 7.7198e-01, 8.5599e-02,
-8.0316e-01, -4.9364e+00, -9.1021e-03, -2.1112e-01, -6.0259e-01,
7.8204e-01, -3.4360e-01, -6.4694e-01, 2.0318e-01, 2.0348e-01,
5.0687e-01, -1.3695e-01, -2.0527e-01, -8.3965e-01, -3.8869e-01,
-1.5978e-02, 1.9347e-01, 7.0880e-01, 4.0714e-01, -5.1326e-01,
-1.4493e-01, 3.1495e-01, 2.0955e-01, -4.6648e-01, -3.4225e-01,
3.4637e-02, 1.0108e+00, 9.0532e-01, -2.0762e-01, 3.5621e-01,
-7.2445e-01, 1.3067e+01, 6.7726e-01, 1.3835e+00, 6.2843e-01,
-2.8452e-01, -1.0192e+00, 8.6062e+00, -1.6392e+00, 5.7381e-01,
-1.0435e+00, 1.5714e-01, 1.2279e-01, 8.6283e+00, 6.2561e+00,
-4.0468e-02, 7.0681e+00, 1.2585e-01, -2.1681e+01, 1.4611e+01,
-2.3411e-01, -7.2368e-02, -1.8365e-03, 2.9596e+00, -1.1411e+00,
-2.9730e-01, 1.7938e+01, 7.2424e+00, 5.1016e-01, 1.0051e-01,
1.0628e-01, 3.2571e-01, 4.6448e-01, -8.7698e-01, 2.7590e-01,
-1.0133e+00, -4.5238e-01, -5.4223e-01, 7.0200e+00, 7.4948e-01,
-1.1842e+00, 3.3506e-01, 7.5488e-01, 2.4372e-01, 3.0494e-01,
6.4743e-01, 2.3148e-01, 1.7616e-01, 1.3851e-02, -9.6435e-02,
-8.9037e-01, -6.0654e-01, 4.2858e-01, 1.5258e-01, 9.8812e-01,
1.3079e+00, -1.0711e+01, 4.2166e+00, 8.7998e-01, 3.3054e-01,
-5.8444e+00, -3.4548e-02, 4.9459e-01, 7.7938e-01, 2.9191e-01,
-7.0197e-01, -7.9523e-02, -3.6578e-01, -1.2519e-01, -3.2218e-02,
5.0998e-01, -5.6472e-02, -5.2412e-01, -8.2180e-01, 5.2325e-02,
-8.4480e-01, -2.6161e-01, -7.4100e-01, -8.2730e-02, 4.7658e-01,
-1.0537e+00, 8.9414e-01, 4.8828e-01, 3.0081e-01, 1.5052e-01,
4.5074e-01, 4.4070e-01, -1.4736e+00, -1.0278e+00, -7.8714e-02,
-5.8056e-01, -3.5484e-01, 3.6458e-01, -1.6018e-01, -1.2974e+00,
-4.8609e-01, 1.2365e-01, 1.4940e-01, 8.0415e-01, 2.2631e-01,
-3.0853e-01, -5.1093e-01, 4.4652e-01, -5.7071e-02, -2.6952e+00,
9.4816e+00, -1.2769e-01, 4.3953e-01, 1.5600e-01, -1.0600e+00,
6.0080e-01, -2.0380e-01, 5.0884e-01, -4.7676e+00, -3.3775e-01,
-4.2323e-01, -6.1404e-01, 2.8504e-01, -2.0254e-02, -4.0116e-01,
-4.9001e-02, -2.4928e-01, 3.4356e-01, 8.4512e-02, 2.8335e-01,
-2.1168e-01, 1.0649e+01, 1.0420e+00, 4.8394e+00, 6.1964e-01,
-3.4492e-01, 3.8942e+00, -4.9420e-01, 5.2617e-01, 1.2750e-01,
7.2947e+00, 8.1702e-01, -4.1560e-02, -5.4786e-01, -2.0883e-01,
9.6875e-01, -2.9027e-01, -2.2928e-01, -4.4895e-01, 1.4488e-01,
-1.6438e-01, 5.9487e-01, -5.8163e-01, 5.3420e+00, -5.8452e-01,
6.9544e-01, 3.6611e-01, -7.5611e-01, 3.0533e-01, -1.3884e-01,
-1.0706e-01, 6.9872e-01, 2.7047e-01, 1.1215e+00, 4.6718e-01,
4.0828e-01, -1.7092e+00, -1.6398e-01, 2.5839e-01, -6.7712e-02,
-8.7694e-01, 4.0819e-01, -4.9471e-01, -8.6026e-01, -2.4154e+00,
-1.2276e-01, -2.7824e-01, 5.3658e+00, 4.9554e-01, -1.9481e-01,
1.2292e-01, 7.8525e-01, 2.3872e-02, 4.0939e+00, 3.4132e-02,
-1.5645e+00, 7.0096e-01, -2.1828e-01, 8.5402e-01, 1.5863e-01,
2.9813e+00, 2.3696e-01, -2.3722e-01, -3.8006e-01, -1.1631e+00,
4.3912e-01, 3.7513e-02, 4.3560e-01, 5.2420e-01, -1.4707e+00,
2.7144e-01, -7.4455e-01, -5.1493e-01, 5.2521e-02, -2.6730e+00,
-1.8987e-01, -9.9107e-01, -1.6863e-01, -7.1332e-01, 4.2669e+00,
-5.3371e-02, 5.9373e-01, 1.4868e-01, -7.0538e-01, -6.0082e-01,
1.4704e-01, -4.9876e-01, -6.0873e-01, 1.7014e-01, -3.0275e+01,
-1.4344e+00, 1.0841e+00, 1.7418e-01, 5.6397e-01, 8.2630e-02,
-5.9019e-01, 2.8368e-01, -4.4285e-01, 1.4122e-01, 6.4251e-01,
-1.3243e+00, 4.8020e+00, -1.0624e+00, -4.6566e-01, -7.7769e-01,
-4.0430e-01, 4.7527e-01, -8.2211e-01, 4.6503e-01, -1.0446e+00,
-8.7541e-02, 9.1078e-01, -1.9699e-01, -4.1365e-01, 5.1298e-01,
1.4619e-02, -4.1418e-01, -2.2742e-01, 1.1421e-01, -6.2088e-04,
9.0550e+00, -3.6152e-01, -4.1244e-02, 6.7183e-01, -3.5516e-01,
-2.8442e-01, 5.1599e-02, -1.5489e-01, -4.2900e-01, -5.0945e-02,
-2.3836e-01, 1.9734e-01, 3.5747e+00, 3.5048e+00, 1.2286e-01,
4.7291e+00, 1.6677e-01, 5.0529e-01, 1.1175e-01, -5.6003e-01,
-8.1337e+00, -7.2567e-01, 1.2407e+01, 4.9879e-01, -5.0294e-01,
-1.8211e+00, 1.1139e+00, -7.3934e-01, 1.1462e+00, -1.4798e+00,
-2.1136e-01, -1.6329e-01, 1.8851e-03, 5.9336e-01, -2.1684e-01,
7.2945e-01, -4.1354e-02, 1.7159e-01, -1.9134e-01, -3.2639e-02,
-2.3387e-01, -4.5282e-01, 5.6362e-01, -2.1310e-03, 5.7659e-01,
-1.9335e-01, -4.2932e-01, -6.7643e-02, 3.1747e-02, 4.1841e-01,
-6.5551e-01, 1.0555e+01, -4.1258e+00, 1.1800e+00, -2.4468e-01,
-8.0830e-02, 3.7979e-01, 1.4558e-01, 7.2083e-01, 8.2094e-01,
-1.9364e+00, 1.5049e-01, -3.4493e+00, -3.7798e-01, 6.4016e-01,
1.2758e-01, 1.2351e-01, 8.7173e-01, 3.6028e-02, -8.3858e-02,
2.3584e-01, 3.5734e-01, -1.5258e+01, 6.5326e-01, -8.1408e-01,
8.7982e+00, -8.2415e-01, 7.5384e-01, 4.1488e-01, 6.9907e-01,
2.6800e-02, 1.5982e-01, 9.0493e-01, 7.4287e+00, -2.6007e-01,
2.3836e-01, 3.3552e-01, 6.1951e-01, 1.7672e-01, 1.9357e-01,
-2.8489e+01, 5.6765e-01, 3.8560e-01, -3.6292e-01, 2.4188e-02,
-7.4524e-01, 1.0474e-01, -1.0021e+00, -1.0982e+00, 3.6997e-01,
5.4891e-01, 9.9697e-02, -2.0172e-01, 1.1126e+00, -1.1220e+01,
-3.8723e-01, 3.9403e-01, -4.2756e-01, 7.1580e-01, -8.3587e-01,
-3.9499e-01, 1.9091e-02, -8.8970e-01, 5.4678e-01, 2.2064e-01,
-1.2714e+00, -7.0085e-01, 1.8650e-01, 3.7933e-01, 3.4788e-01,
-6.9564e-01, 5.8728e-02, 5.8410e-01, 7.7375e-02, -4.6896e-01,
4.7287e-01, -2.9092e-01, -5.1333e-01, -8.0839e-01, -1.2330e-01,
1.8131e-01, 1.0093e+00, 9.5369e-01, 2.0513e-01, -4.9450e-01,
8.6447e-01, 5.8131e+00, 5.3772e+00, 3.4961e-01, 5.0210e+00,
7.7519e-01, 1.1220e+01, -6.5384e-01, 3.6045e-02, 4.9234e-01,
4.1998e-01, 6.7939e-01, 4.1427e-01, -5.4165e-01, -4.1638e-01,
4.2398e-01, -1.4143e+00, -2.7998e+00, 9.3041e+00, 5.3229e-01,
4.6536e-02, 7.9408e-01, 4.8043e-01, -6.7394e-02, -4.8614e-01,
2.1090e-01, -5.7115e-01, -1.1627e+00, 8.5690e-01, -4.7997e-01,
6.2639e+00, -9.6856e-01, 7.4864e+00, -7.2197e-01, -1.8133e+00,
1.7946e-01, 1.1411e+00, -5.2402e-01, 1.0813e+01, 7.0472e-01,
4.8512e-01, -3.3854e-01, 3.4711e-01, 8.1417e-01, -3.0601e-01,
-1.8852e-01, -3.2081e-01, -2.4498e+00, -7.9404e-01, 1.6746e-01,
-5.9490e-01, -1.7251e-01, 4.8110e-01, 2.6063e-02, -6.2740e-01,
-1.0655e+00, 7.7957e-01, 4.3022e-01, -3.2064e-01, 1.1387e+00,
-8.4732e-01, -5.2953e-01, 3.1137e-01, 2.9283e-01, 6.0483e-01,
7.9638e-02, 4.8899e+00, 8.5033e-01, -8.3357e-01, 4.8676e-01,
4.6297e-01, 6.0738e-01, 3.7526e-01, -4.6644e-01, -1.2772e+00,
-5.1409e-01, 4.9370e-02, 1.7685e-01, -2.2113e-02, -1.4460e-01,
-3.0095e-01, 5.9490e-01, -5.7839e-01, 1.0648e-02, -6.1280e-01,
-4.7614e+00, -1.1388e+00, -8.7777e-01, -5.2101e-01, -2.3257e+00,
-5.4177e-01, -3.8826e-01, -3.5506e-01, -1.8402e+00, -3.9653e+00,
-7.6208e-01, 8.7070e-02, -3.3776e-01, -3.3950e-01, -6.2380e-01,
-5.6249e-01, -9.0860e-01, -7.7653e-01, -6.5774e-01, -8.0787e-01,
-9.6694e-01, -1.0197e+00, 9.2538e-02, -1.8464e+01, 7.5323e+00,
-5.4676e-01, -8.1616e-01, -8.7753e-01, -2.5963e-01, 3.4584e-01,
-1.2955e+00, -5.4957e-01, -5.1894e-01, -1.5269e-01, 1.2589e+01,
-1.3563e-01, -3.4128e-01, -5.7195e-01, -9.0047e-01, -1.2300e+00,
-1.9702e+00, 2.6821e-01, -2.0906e+00, -1.6163e+00, -4.8408e+00,
6.6520e+00, 4.2927e-01, -2.9416e-01, -1.1562e+00, -1.6190e+00,
3.1745e-01, -1.3883e+00, -1.8319e-01, -9.5329e-02, -8.3665e-01,
-8.5588e-01, -1.4261e-01, 6.0884e-02, -5.4049e-01, -5.5738e-01,
-1.2308e+00, -1.9899e+00, 7.7201e-01, -4.1223e-01, -1.0523e+00,
-3.3784e-01, -1.2651e+00, -9.0390e-01, -1.3715e+00, 7.3094e+00,
-3.9702e-01, -8.1356e-02, -6.8597e-01, 7.0341e-01, -1.4777e-01]],
requires_grad=True)
\n", + "(Showing first 2 rows)\n", + "" + ], + "text/plain": [ + "╭──────────────┬────────────────────────────────────┬──────────────────────┬──────────┬────────────────────────────────╮\n", + "│ image ┆ arr ┆ class_human_readable ┆ class_id ┆ model_predictions │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ Image[MIXED] ┆ FixedShapeTensor(UInt8, [256, 256, ┆ Utf8 ┆ Int64 ┆ Python │\n", + "│ ┆ 3]) ┆ ┆ ┆ │\n", + "╞══════════════╪════════════════════════════════════╪══════════════════════╪══════════╪════════════════════════════════╡\n", + "│ ┆ sea_snake ┆ 65 ┆ tensor([[-2.7218e-01, -3.9199… │\n", + "├╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤\n", + "│ ┆ alp ┆ 970 ┆ tensor([[-2.9406e-01, -3.9664… │\n", + "╰──────────────┴────────────────────────────────────┴──────────────────────┴──────────┴────────────────────────────────╯\n", + "\n", + "(Showing first 2 rows)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df.with_column(\n", + " \"model_predictions\",\n", + " df[\"arr\"].apply(lambda arr: model(torch.tensor(arr).permute(2,0,1).unsqueeze(0).float()), return_dtype=daft.DataType.python())\n", + ").show(2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6dfb44fa-3c29-4e16-9b23-b02bb3276f06", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}