Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datasets] Add from_torch #29588

Merged
merged 12 commits into from
Nov 7, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ werkzeug
wandb
tensorflow; sys_platform != 'darwin' or platform_machine != 'arm64'
tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'
torch
torchvision
transformers

# Ray libraries
Expand Down
8 changes: 5 additions & 3 deletions doc/source/data/api/input_output.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ Mars
.. automethod:: ray.data.Dataset.to_mars
:noindex:

Torch
-----

.. autofunction:: ray.data.from_torch

HuggingFace
------------

Expand Down Expand Up @@ -193,9 +198,6 @@ Built-in Datasources
.. autoclass:: ray.data.datasource.SimpleTensorFlowDatasource
:members:

.. autoclass:: ray.data.datasource.SimpleTorchDatasource
:members:

.. autoclass:: ray.data.datasource.TFRecordDatasource
:members:

Expand Down
13 changes: 5 additions & 8 deletions doc/source/data/creating-datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -476,23 +476,20 @@ From Torch and TensorFlow
.. tabbed:: PyTorch

If you already have a Torch dataset available, you can create a Ray Dataset using
:py:class:`~ray.data.datasource.SimpleTorchDatasource`.
:class:`~ray.data.from_torch`.

.. warning::
:py:class:`~ray.data.datasource.SimpleTorchDatasource` doesn't support parallel
:py:class:`~ray.data.datasource.from_torch` doesn't support parallel
reads. You should only use this datasource for small datasets like MNIST or
CIFAR.

.. code-block:: python

import ray.data
from ray.data.datasource import SimpleTorchDatasource
import ray
import torchvision

dataset_factory = lambda: torchvision.datasets.MNIST("data", download=True)
dataset = ray.data.read_datasource(
SimpleTorchDatasource(), parallelism=1, dataset_factory=dataset_factory
)
dataset = torchvision.datasets.MNIST("data", download=True)
dataset = ray.data.from_torch(dataset)
dataset.take(1)
# (<PIL.Image.Image image mode=L size=28x28 at 0x1142CCA60>, 5)

Expand Down
52 changes: 17 additions & 35 deletions doc/source/ray-air/examples/torch_image_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,67 +63,49 @@
"metadata": {},
"outputs": [
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"2022-08-30 15:30:36,678\tINFO worker.py:1510 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n",
"2022-08-30 15:30:37,791\tWARNING read_api.py:291 -- ⚠️ The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n",
"\u001b[2m\u001b[36m(_get_read_tasks pid=3958)\u001b[0m 2022-08-30 15:30:37,789\tWARNING torch_datasource.py:55 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n"
"Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz\n"
]
},
{
"name": "stdout",
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(_execute_read_task pid=3958)\u001b[0m Using downloaded and verified file: ./data/cifar-10-python.tar.gz\n",
"\u001b[2m\u001b[36m(_execute_read_task pid=3958)\u001b[0m Extracting ./data/cifar-10-python.tar.gz to ./data\n"
"100%|██████████| 170498071/170498071 [00:21<00:00, 7792736.24it/s]\n"
]
},
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"2022-08-30 15:30:44,508\tWARNING read_api.py:291 -- ⚠️ The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.\n",
"\u001b[2m\u001b[36m(_get_read_tasks pid=3958)\u001b[0m 2022-08-30 15:30:44,507\tWARNING torch_datasource.py:55 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.\n"
"Extracting data/cifar-10-python.tar.gz to data\n",
"Files already downloaded and verified\n"
]
},
{
"name": "stdout",
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[2m\u001b[36m(_execute_read_task pid=3958)\u001b[0m Files already downloaded and verified\n"
"2022-10-23 10:33:48,403\tINFO worker.py:1518 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n"
]
}
],
"source": [
"import ray\n",
"from ray.data.datasource import SimpleTorchDatasource\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"\n",
"transform = transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n",
")\n",
"\n",
"train_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=True, transform=transform)\n",
"test_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=False, transform=transform)\n",
"\n",
"def train_dataset_factory():\n",
" return torchvision.datasets.CIFAR10(\n",
" root=\"./data\", download=True, train=True, transform=transform\n",
" )\n",
"\n",
"\n",
"def test_dataset_factory():\n",
" return torchvision.datasets.CIFAR10(\n",
" root=\"./data\", download=True, train=False, transform=transform\n",
" )\n",
"\n",
"\n",
"train_dataset: ray.data.Dataset = ray.data.read_datasource(\n",
" SimpleTorchDatasource(), dataset_factory=train_dataset_factory\n",
")\n",
"test_dataset: ray.data.Dataset = ray.data.read_datasource(\n",
" SimpleTorchDatasource(), dataset_factory=test_dataset_factory\n",
")"
"train_dataset: ray.data.Dataset = ray.data.from_torch(train_dataset)\n",
"test_dataset: ray.data.Dataset = ray.data.from_torch(test_dataset)"
]
},
{
Expand Down Expand Up @@ -156,7 +138,7 @@
"id": "a89b59e8",
"metadata": {},
"source": [
"{py:class}`SimpleTorchDatasource <ray.data.datasource.SimpleTorchDatasource>` doesn't parallelize reads, so you shouldn't use it with larger datasets.\n",
"{py:class}`from_torch <ray.data.from_torch>` doesn't parallelize reads, so you shouldn't use it with larger datasets.\n",
"\n",
"Next, let's represent our data using a dictionary of ndarrays instead of tuples. This lets us call {py:meth}`Dataset.iter_torch_batches <ray.data.Dataset.iter_torch_batches>` later in the tutorial."
]
Expand Down Expand Up @@ -828,7 +810,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('.venv': venv)",
"display_name": "Python 3.10.8 ('.venv': venv)",
"language": "python",
"name": "python3"
},
Expand All @@ -842,11 +824,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.8"
},
"vscode": {
"interpreter": {
"hash": "a658351b4133f922c5967ed6133cfc05c9f16c53a5161e5843ace3f528fccaf5"
"hash": "c704e19737f24b51bc631dadcac7a7e356bb35d1c5cd7766248d8a6946059909"
}
}
},
Expand Down
Loading