diff --git a/.github/scripts/check_complete_doc.py b/.github/scripts/check_complete_doc.py index fbf2bae05..6d174a245 100644 --- a/.github/scripts/check_complete_doc.py +++ b/.github/scripts/check_complete_doc.py @@ -38,14 +38,13 @@ def compare_sets(set_a, set_b, ignore_set=None): def main(): - datapipes_folder = os.path.join("torchdata", "datapipes") init_file = "__init__.py" docs_source_folder = os.path.join("docs", "source") exit_code = 0 - for target, ignore_set in zip(["iter", "map", "utils"], [{"IterDataPipe", "Extractor"}, {"MapDataPipe"}, {}]): - init_path = os.path.join(datapipes_folder, target, init_file) - rst_path = os.path.join(docs_source_folder, "torchdata.datapipes." + target + ".rst") + for target, ignore_set in [("stateful_dataloader", {})]: + init_path = os.path.join("torchdata", target, init_file) + rst_path = os.path.join(docs_source_folder, "torchdata." + target + ".rst") init_set = collect_init_dps(init_path) rst_set = collect_rst_dps(rst_path) diff --git a/.github/workflows/aistore_ci.yml b/.github/workflows/aistore_ci.yml deleted file mode 100644 index f5b0abb2d..000000000 --- a/.github/workflows/aistore_ci.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: Run AIStore Datapipe Test -on: - push: - branches: - - main - - release/* - tags: - pull_request: - types: [opened, synchronize, reopened, labeled] - branches: - - main - # For PR created by ghstack - - gh/*/*/base - - release/* - -jobs: - test: - if: ${{ github.repository_owner == 'pytorch' }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: - - macos-latest - - ubuntu-latest - python-version: - - 3.9 - steps: - - name: Get PyTorch Channel - shell: bash - run: | - if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then - PT_CHANNEL="https://download.pytorch.org/whl/test/cpu/torch_test.html" - else - PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" - fi - echo "value=$PT_CHANNEL" >> $GITHUB_OUTPUT - id: pytorch_channel - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Check out source repository - uses: actions/checkout@v3 - - name: Install dependencies - run: | - pip3 install -r requirements.txt - pip3 install --pre torch -f "${{ steps.pytorch_channel.outputs.value }}" - - name: Run AIStore local deployment - uses: NVIDIA/aistore@main - - name: Build TorchData - run: | - pip3 install . 
- - name: Install test requirements - run: pip3 install -r test/requirements_aistore.txt - - name: Run AIStore DataPipe tests with pytest - run: pytest --no-header -v test/test_aistore.py diff --git a/.github/workflows/build_wheels_linux.yml b/.github/workflows/build_wheels_linux.yml index c8be185c5..1f15f8d40 100644 --- a/.github/workflows/build_wheels_linux.yml +++ b/.github/workflows/build_wheels_linux.yml @@ -35,7 +35,7 @@ jobs: include: - repository: pytorch/data pre-script: packaging/pre_build_script_linux.sh - post-script: packaging/post_build_script_linux.sh + post-script: "" smoke-test-script: test/smoke_test/smoke_test.py package-name: torchdata name: ${{ matrix.repository }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index d31b17556..000000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,93 +0,0 @@ -name: Run DataPipes Tests -on: - push: - branches: - - main - - release/* - tags: - pull_request: - types: [opened, synchronize, reopened, labeled] - branches: - - main - # For PR created by ghstack - - gh/*/*/base - - release/* - -jobs: - test: - if: ${{ github.repository_owner == 'pytorch' }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: - - macos-latest - - ubuntu-latest - - windows-latest - python-version: - - 3.9 - - "3.10" - - "3.11" - - "3.12" - steps: - - name: Get PyTorch Channel - shell: bash - run: | - if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then - PT_CHANNEL="https://download.pytorch.org/whl/test/cpu/torch_test.html" - else - PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" - fi - echo "value=$PT_CHANNEL" >> $GITHUB_OUTPUT - id: pytorch_channel - - name: Setup additional system libraries - if: startsWith( matrix.os, 'ubuntu' ) - run: | - sudo add-apt-repository multiverse - sudo apt update - sudo apt install rar unrar libssl-dev libcurl4-openssl-dev zlib1g-dev - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Setup msbuild on Windows - if: matrix.os == 'windows-latest' - uses: microsoft/setup-msbuild@v1.1 - - name: Set up Visual Studio shell - if: matrix.os == 'windows-latest' - uses: egor-tensin/vs-shell@v2 - with: - arch: x64 - - name: Check out source repository - uses: actions/checkout@v4 - with: - submodules: recursive - - name: Install dependencies - run: | - pip3 install -r requirements.txt - pip3 install networkx - pip3 install --pre torch -f "${{ steps.pytorch_channel.outputs.value }}" - pip3 install cmake ninja - echo "/home/runner/.local/bin" >> $GITHUB_PATH - - name: Build TorchData - run: | - pip3 install . - env: - BUILD_S3: 1 - - name: Install test requirements - run: pip3 install -r test/requirements.txt - - name: Test documentation examples - if: matrix.os != 'windows-latest' - run: | - cd ./docs - pip3 install -r requirements.txt - make doctest - cd .. - - name: Run DataPipes tests with pytest - if: ${{ ! 
contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} - run: - pytest --durations=0 --no-header -v test --ignore=test/test_period.py --ignore=test/test_text_examples.py - --ignore=test/test_audio_examples.py --ignore=test/test_aistore.py - --ignore=test/dataloader2/test_dataloader2.py --ignore=test/dataloader2/test_mprs.py - --ignore=test/test_distributed.py --ignore=test/stateful_dataloader/test_dataloader.py - --ignore=test/stateful_dataloader/test_state_dict.py diff --git a/.github/workflows/domain_ci.yml b/.github/workflows/domain_ci.yml deleted file mode 100644 index c7c2961c7..000000000 --- a/.github/workflows/domain_ci.yml +++ /dev/null @@ -1,96 +0,0 @@ -name: Run Domain Tests -on: - push: - branches: - - main - pull_request: - branches: - - main - # For PR created by ghstack - - gh/*/*/base - -jobs: - torchaudio: - if: ${{ github.repository_owner == 'pytorch' }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: - # - macos-latest - - ubuntu-latest - # - windows-latest - python-version: - - 3.9 - steps: - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Install torch and torchaudio from nightlies - run: | - pip install networkx - pip install --pre torch torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu - - - name: Check out torchdata repository - uses: actions/checkout@v3 - - - name: Install torchdata - run: | - pip install -r requirements.txt - python setup.py install - - - name: Install test requirements - run: pip install dill expecttest numpy pytest - - - name: Run torchaudio example datasets tests - if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} - run: pytest --no-header -v test/test_audio_examples.py - - name: Run torchaudio example datasets (including slow tests) - if: ${{ contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} - run: pytest --no-header -v test/test_audio_examples.py - env: - PYTORCH_TEST_WITH_SLOW: 1 - - torcharrow: - if: ${{ github.repository_owner == 'pytorch' }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: - # - macos-latest - - ubuntu-latest - # - windows-latest - python-version: - - 3.9 - steps: - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Install torch and torcharrow from nightlies - run: pip install --pre torch torcharrow -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - - - name: Check out torchdata repository - uses: actions/checkout@v3 - - - name: Install torchdata - run: | - pip install -r requirements.txt - python setup.py install - - - name: Install test requirements - run: pip install dill expecttest numpy pytest - - - name: Run torcharrow example datasets tests - if: ${{ ! 
contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} - run: pytest --no-header -v test/test_dataframe.py - - - name: Run torcharrow example datasets (including slow tests) - if: ${{ contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} - run: pytest --no-header -v test/test_dataframe.py - env: - PYTORCH_TEST_WITH_SLOW: 1 diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index e938d9e99..000000000 --- a/.gitmodules +++ /dev/null @@ -1,6 +0,0 @@ -[submodule "third_party/pybind11"] - path = third_party/pybind11 - url = https://github.com/pybind/pybind11.git -[submodule "third_party/aws-sdk-cpp"] - path = third_party/aws-sdk-cpp - url = https://github.com/aws/aws-sdk-cpp.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a3cf65a4..6696de652 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,6 +34,6 @@ repos: - usort == 1.0.0 - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 5.0.4 hooks: - id: flake8 diff --git a/CMakeLists.txt b/CMakeLists.txt deleted file mode 100644 index 06136dace..000000000 --- a/CMakeLists.txt +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -cmake_minimum_required(VERSION 3.13 FATAL_ERROR) - -# Most of the configurations are taken from PyTorch -# https://github.com/pytorch/pytorch/blob/0c9fb4aff0d60eaadb04e4d5d099fb1e1d5701a9/CMakeLists.txt - -# Use compiler ID "AppleClang" instead of "Clang" for XCode. -# Not setting this sometimes makes XCode C compiler gets detected as "Clang", -# even when the C++ one is detected as "AppleClang". -cmake_policy(SET CMP0010 NEW) -cmake_policy(SET CMP0025 NEW) - -# Suppress warning flags in default MSVC configuration. It's not -# mandatory that we do this (and we don't if cmake is old), but it's -# nice when it's possible, and it's possible on our Windows configs. -if(NOT CMAKE_VERSION VERSION_LESS 3.15.0) - cmake_policy(SET CMP0092 NEW) -endif() - -project(torchdata) - -# check and set CMAKE_CXX_STANDARD -string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard) -if(env_cxx_standard GREATER -1) - message( - WARNING "C++ standard version definition detected in environment variable." - "PyTorch requires -std=c++17. Please remove -std=c++ settings in your environment.") -endif() - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_C_STANDARD 11) - - -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_POSITION_INDEPENDENT_CODE ON) - -# Apple specific -if(APPLE) - # Get clang version on macOS - execute_process( COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string ) - string(REGEX REPLACE "Apple LLVM version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION_STRING ${clang_full_version_string}) - message( STATUS "CLANG_VERSION_STRING: " ${CLANG_VERSION_STRING} ) - - # RPATH stuff - set(CMAKE_MACOSX_RPATH ON) - - set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") -endif() - -# Options -option(USE_SYSTEM_AWS_SDK_CPP "Use system-provided aws-sdk-cpp." OFF) -option(USE_SYSTEM_PYBIND11 "Use system-provided PyBind11." 
OFF) -if(USE_SYSTEM_LIBS) - set(USE_SYSTEM_AWS_SDK_CPP ON) - set(USE_SYSTEM_PYBIND11 ON) -endif() - -option(BUILD_S3 "Build s3 io functionality" OFF) - -if(BUILD_S3) -include(third_party/CMakeLists.txt) -add_subdirectory(torchdata/csrc) -endif() diff --git a/docs/source/conf.py b/docs/source/conf.py index b3a903058..3b5bfab1a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -112,21 +112,10 @@ "": "SequentialSampler", "torch.utils.data.datapipes.iter.combining.T_co": "T_co", "torch.utils.data.datapipes.iter.combinatorics.T_co": "T_co", - "torchdata.datapipes.iter.transform.bucketbatcher.T_co": "T_co", "torch.utils.data.datapipes.map.grouping.T": "T", "torch.utils.data.datapipes.map.combining.T_co": "T_co", "torch.utils.data.datapipes.map.combinatorics.T_co": "T_co", - "torchdata.datapipes.iter.util.cycler.T_co": "T_co", - "torchdata.datapipes.iter.util.paragraphaggregator.T_co": "T_co", - "torchdata.datapipes.map.util.cacheholder.T_co": "T_co", - "Sequence[torchdata.datapipes.map.util.unzipper.T]": "Sequence[T]", - "torchdata.datapipes.iter.util.samplemultiplexer.T_co": "T_co", - "torchdata.datapipes.iter.util.indexadder.K": "K", - "torchdata.datapipes.iter.util.unzipper.T": "T", "torch.utils.data.datapipes.iter.grouping.T_co": "T_co", - "torchdata.datapipes.iter.util.dataframemaker.T_co": "T_co", - "torchdata.datapipes.iter.util.cacheholder.T_co": "T_co", - "torchdata.datapipes.iter.util.header.T_co": "T_co", "": "List", "typing.": "", "Union[IterDataPipe, MapDataPipe]": "DataPipe", diff --git a/docs/source/dataloader2.rst b/docs/source/dataloader2.rst deleted file mode 100644 index 3f52f9d9e..000000000 --- a/docs/source/dataloader2.rst +++ /dev/null @@ -1,70 +0,0 @@ -:tocdepth: 3 - -DataLoader2 -============ - -.. automodule:: torchdata.dataloader2 - -A new, light-weight :class:`DataLoader2` is introduced to decouple the overloaded data-manipulation functionalities from ``torch.utils.data.DataLoader`` to ``DataPipe`` operations. Besides, certain features can only be achieved with :class:`DataLoader2` like snapshotting and switching backend services to perform high-performant operations. - -DataLoader2 ------------- - -.. autoclass:: DataLoader2 - :special-members: __iter__ - :members: - -Note: -:class:`DataLoader2` doesn't support ``torch.utils.data.Dataset`` or ``torch.utils.data.IterableDataset``. Please wrap each of them with the corresponding ``DataPipe`` below: - -- :class:`torchdata.datapipes.map.SequenceWrapper`: ``torch.utils.data.Dataset`` -- :class:`torchdata.datapipes.iter.IterableWrapper`: ``torch.utils.data.IterableDataset`` - -ReadingService ---------------- - -``ReadingService`` specifies the execution backend for the data-processing graph. There are three types of ``ReadingServices`` provided in TorchData: - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_method_template.rst - - DistributedReadingService - InProcessReadingService - MultiProcessingReadingService - SequentialReadingService - -Each ``ReadingServices`` would take the ``DataPipe`` graph and rewrite it to achieve a few features like dynamic sharding, sharing random seeds and snapshoting for multi-/distributed processes. For more detail about those features, please refer to `the documentation `_. - -Adapter --------- - -``Adapter`` is used to configure, modify and extend the ``DataPipe`` graph in :class:`DataLoader2`. It allows in-place -modification or replace the pre-assembled ``DataPipe`` graph provided by PyTorch domains. 
For example, ``Shuffle(False)`` can be -provided to :class:`DataLoader2`, which would disable any ``shuffle`` operations in the ``DataPipes`` graph. - -.. module:: torchdata.dataloader2.adapter - -.. autoclass:: Adapter - :special-members: __call__ - -Here are the list of :class:`Adapter` provided by TorchData in ``torchdata.dataloader2.adapter``: - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - Shuffle - CacheTimeout - -And, we will provide more ``Adapters`` to cover data-processing options: - -- ``PinMemory``: Attach a ``DataPipe`` at the end of the data-processing graph that coverts output data to ``torch.Tensor`` in pinned memory. -- ``FullSync``: Attach a ``DataPipe`` to make sure the data-processing graph synchronized between distributed processes to prevent hanging. -- ``ShardingPolicy``: Modify sharding policy if ``sharding_filter`` is presented in the ``DataPipe`` graph. -- ``PrefetchPolicy``, ``InvalidateCache``, etc. - -If you have feature requests about the ``Adapters`` you'd like to be provided, please open a GitHub issue. For specific -needs, ``DataLoader2`` also accepts any custom ``Adapter`` as long as it inherits from the ``Adapter`` class. diff --git a/docs/source/dlv2_tutorial.rst b/docs/source/dlv2_tutorial.rst deleted file mode 100644 index b7c645e87..000000000 --- a/docs/source/dlv2_tutorial.rst +++ /dev/null @@ -1,70 +0,0 @@ -DataLoader2 Tutorial -===================== - -This is the tutorial for users to create a ``DataPipe`` graph and load data via ``DataLoader2`` with different backend systems (``ReadingService``). An usage example can be found in `this colab notebook `_. - -DataPipe ---------- - -Please refer to `DataPipe Tutorial `_ for more details. Here are the most important caveats necessary: -to make sure the data pipeline has different order per epoch and data shards are mutually exclusive and collectively exhaustive: - -- Place ``sharding_filter`` or ``sharding_round_robin_dispatch`` as early as possible in the pipeline to avoid repeating expensive operations in worker/distributed processes. -- Add a ``shuffle`` DataPipe before sharding to achieve inter-shard shuffling. ``ReadingService`` will handle synchronization of those ``shuffle`` operations to ensure the order of data are the same before sharding so that all shards are mutually exclusive and collectively exhaustive. - -Here is an example of a ``DataPipe`` graph: - -.. code:: python - - datapipe = IterableWrapper(["./train1.csv", "./train2.csv"]) - datapipe = datapipe.open_files(encoding="utf-8").parse_csv() - datapipe = datapipe.shuffle().sharding_filter() - datapipe = datapipe.map(fn).batch(8) - -Multiprocessing ----------------- - -``MultiProcessingReadingService`` handles multiprocessing sharding at the point of ``sharding_filter`` and synchronizes the seeds across worker processes. - -.. code:: python - - rs = MultiProcessingReadingService(num_workers=4) - dl = DataLoader2(datapipe, reading_service=rs) - for epoch in range(10): - dl.seed(epoch) - for d in dl: - model(d) - dl.shutdown() - -Distributed ------------- - -``DistributedReadingService`` handles distributed sharding at the point of ``sharding_filter`` and synchronizes the seeds across distributed processes. And, in order to balance the data shards across distributed nodes, a ``fullsync`` ``DataPipe`` will be attached to the ``DataPipe`` graph to align the number of batches across distributed ranks. This would prevent hanging issue caused by uneven shards in distributed training. - -.. 
code:: python - - rs = DistributedReadingService() - dl = DataLoader2(datapipe, reading_service=rs) - for epoch in range(10): - dl.seed(epoch) - for d in dl: - model(d) - dl.shutdown() - -Multiprocessing + Distributed ------------------------------- - -``SequentialReadingService`` can be used to combine both ``ReadingServices`` together to achieve multiprocessing and distributed training at the same time. - -.. code:: python - - mp_rs = MultiProcessingReadingService(num_workers=4) - dist_rs = DistributedReadingService() - rs = SequentialReadingService(dist_rs, mp_rs) - - dl = DataLoader2(datapipe, reading_service=rs) - for epoch in range(10): - dl.seed(epoch) - for d in dl: - model(d) - dl.shutdown() diff --git a/docs/source/dp_tutorial.rst b/docs/source/dp_tutorial.rst deleted file mode 100644 index b1d0965df..000000000 --- a/docs/source/dp_tutorial.rst +++ /dev/null @@ -1,544 +0,0 @@ -DataPipe Tutorial -================== - -Using DataPipes ---------------------------------------------- - -Suppose that we want to load data from CSV files with the following steps: - -- List all CSV files in a directory -- Load CSV files -- Parse CSV file and yield rows -- Split our dataset into training and validation sets - -There are a few `built-in DataPipes `_ that can help us with the above operations. - -- ``FileLister`` - `lists out files in a directory `_ -- ``Filter`` - `filters the elements in DataPipe based on a given - function `_ -- ``FileOpener`` - `consumes file paths and returns opened file - streams `_ -- ``CSVParser`` - `consumes file streams, parses the CSV contents, and returns one parsed line at a - time `_ -- ``RandomSplitter`` - `randomly split samples from a source DataPipe into - groups `_ - -As an example, the source code for ``CSVParser`` looks something like this: - -.. code:: python - - @functional_datapipe("parse_csv") - class CSVParserIterDataPipe(IterDataPipe): - def __init__(self, dp, **fmtparams) -> None: - self.dp = dp - self.fmtparams = fmtparams - - def __iter__(self) -> Iterator[Union[Str_Or_Bytes, Tuple[str, Str_Or_Bytes]]]: - for path, file in self.source_datapipe: - stream = self._helper.skip_lines(file) - stream = self._helper.strip_newline(stream) - stream = self._helper.decode(stream) - yield from self._helper.return_path(stream, path=path) # Returns 1 line at a time as List[str or bytes] - -As mentioned in a different section, DataPipes can be invoked using their functional forms (recommended) or their -class constructors. A pipeline can be assembled as the following: - -.. code:: python - - import torchdata.datapipes as dp - - FOLDER = 'path/2/csv/folder' - datapipe = dp.iter.FileLister([FOLDER]).filter(filter_fn=lambda filename: filename.endswith('.csv')) - datapipe = dp.iter.FileOpener(datapipe, mode='rt') - datapipe = datapipe.parse_csv(delimiter=',') - N_ROWS = 10000 # total number of rows of data - train, valid = datapipe.random_split(total_length=N_ROWS, weights={"train": 0.5, "valid": 0.5}, seed=0) - - for x in train: # Iterating through the training dataset - pass - - for y in valid: # Iterating through the validation dataset - pass - -You can find the full list of built-in `IterDataPipes here `_ and -`MapDataPipes here `_. - -Working with DataLoader ---------------------------------------------- - -In this section, we will demonstrate how you can use ``DataPipe`` with ``DataLoader``. -For the most part, you should be able to use it just by passing ``dataset=datapipe`` as an input argument -into the ``DataLoader``. 
For detailed documentation related to ``DataLoader``, -please visit `this PyTorch Core page `_. - - -Please refer to :doc:`this page ` about using ``DataPipe`` with ``DataLoader2``. - - -For this example, we will first have a helper function that generates some CSV files with random label and data. - -.. code:: python - - import csv - import random - - def generate_csv(file_label, num_rows: int = 5000, num_features: int = 20) -> None: - fieldnames = ['label'] + [f'c{i}' for i in range(num_features)] - writer = csv.DictWriter(open(f"sample_data{file_label}.csv", "w", newline=''), fieldnames=fieldnames) - writer.writeheader() - for i in range(num_rows): - row_data = {col: random.random() for col in fieldnames} - row_data['label'] = random.randint(0, 9) - writer.writerow(row_data) - -Next, we will build our DataPipes to read and parse through the generated CSV files. Note that we prefer to have -pass defined functions to DataPipes rather than lambda functions because the formers are serializable with `pickle`. - -.. code:: python - - import numpy as np - import torchdata.datapipes as dp - - def filter_for_data(filename): - return "sample_data" in filename and filename.endswith(".csv") - - def row_processor(row): - return {"label": np.array(row[0], np.int32), "data": np.array(row[1:], dtype=np.float64)} - - def build_datapipes(root_dir="."): - datapipe = dp.iter.FileLister(root_dir) - datapipe = datapipe.filter(filter_fn=filter_for_data) - datapipe = datapipe.open_files(mode='rt') - datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1) - # Shuffle will happen as long as you do NOT set `shuffle=False` later in the DataLoader - datapipe = datapipe.shuffle() - datapipe = datapipe.map(row_processor) - return datapipe - -Lastly, we will put everything together in ``'__main__'`` and pass the DataPipe into the DataLoader. Note that -if you choose to use ``Batcher`` while setting ``batch_size > 1`` for DataLoader, your samples will be -batched more than once. You should choose one or the other. - -.. code:: python - - from torch.utils.data import DataLoader - - if __name__ == '__main__': - num_files_to_generate = 3 - for i in range(num_files_to_generate): - generate_csv(file_label=i, num_rows=10, num_features=3) - datapipe = build_datapipes() - dl = DataLoader(dataset=datapipe, batch_size=5, num_workers=2) - first = next(iter(dl)) - labels, features = first['label'], first['data'] - print(f"Labels batch shape: {labels.size()}") - print(f"Feature batch shape: {features.size()}") - print(f"{labels = }\n{features = }") - n_sample = 0 - for row in iter(dl): - n_sample += 1 - print(f"{n_sample = }") - -The following statements will be printed to show the shapes of a single batch of labels and features. - -.. code:: - - Labels batch shape: torch.Size([5]) - Feature batch shape: torch.Size([5, 3]) - labels = tensor([8, 9, 5, 9, 7], dtype=torch.int32) - features = tensor([[0.2867, 0.5973, 0.0730], - [0.7890, 0.9279, 0.7392], - [0.8930, 0.7434, 0.0780], - [0.8225, 0.4047, 0.0800], - [0.1655, 0.0323, 0.5561]], dtype=torch.float64) - n_sample = 12 - -The reason why ``n_sample = 12`` is because ``ShardingFilter`` (``datapipe.sharding_filter()``) was not used, such that -each worker will independently return all samples. In this case, there are 10 rows per file and 3 files, with a -batch size of 5, that gives us 6 batches per worker. With 2 workers, we get 12 total batches from the ``DataLoader``. - -In order for DataPipe sharding to work with ``DataLoader``, we need to add the following. - -.. 
code:: python - - def build_datapipes(root_dir="."): - datapipe = ... - # Add the following line to `build_datapipes` - # Note that it is somewhere after `Shuffler` in the DataPipe line, but before expensive operations - datapipe = datapipe.sharding_filter() - return datapipe - -When we re-run, we will get: - -.. code:: - - ... - n_sample = 6 - -Note: - -- Place ``ShardingFilter`` (``datapipe.sharding_filter``) as early as possible in the pipeline, especially before expensive - operations such as decoding, in order to avoid repeating these expensive operations across worker/distributed processes. -- For the data source that needs to be sharded, it is crucial to add ``Shuffler`` before ``ShardingFilter`` - to ensure data are globally shuffled before being split into shards. Otherwise, each worker process would - always process the same shard of data for all epochs. And, it means each batch would only consist of data - from the same shard, which leads to low accuracy during training. However, it doesn't apply to the data - source that has already been sharded for each multi-/distributed process, since ``ShardingFilter`` is no - longer required to be presented in the pipeline. -- There may be cases where placing ``Shuffler`` earlier in the pipeline lead to worse performance, because some - operations (e.g. decompression) are faster with sequential reading. In those cases, we recommend decompressing - the files prior to shuffling (potentially prior to any data loading). - - -You can find more DataPipe implementation examples for various research domains `on this page `_. - - -Implementing a Custom DataPipe ---------------------------------------------- -Currently, we already have a large number of built-in DataPipes and we expect them to cover most necessary -data processing operations. If none of them supports your need, you can create your own custom DataPipe. - -As a guiding example, let us implement an ``IterDataPipe`` that applies a callable to the input iterator. For -``MapDataPipe``, take a look at the -`map `_ -folder for examples, and follow the steps below for the ``__getitem__`` method instead of the ``__iter__`` method. - -Naming -^^^^^^^^^^^^^^^^^^ -The naming convention for ``DataPipe`` is "Operation"-er, followed by ``IterDataPipe`` or ``MapDataPipe``, as each -DataPipe is essentially a container to apply an operation to data yielded from a source ``DataPipe``. For succinctness, -we alias to just "Operation-er" in **init** files. For our ``IterDataPipe`` example, we'll name the module -``MapperIterDataPipe`` and alias it as ``iter.Mapper`` under ``torchdata.datapipes``. - -For the functional method name, the naming convention is ``datapipe.``. For instance, -the functional method name of ``Mapper`` is ``map``, such that it can be invoked by ``datapipe.map(...)``. - - -Constructor -^^^^^^^^^^^^^^^^^^ - -DataSets are now generally constructed as stacks of ``DataPipes``, so each ``DataPipe`` typically takes a -source ``DataPipe`` as its first argument. Here is a simplified version of `Mapper` as an example: - -.. code:: python - - from torchdata.datapipes.iter import IterDataPipe - - class MapperIterDataPipe(IterDataPipe): - def __init__(self, source_dp: IterDataPipe, fn) -> None: - super().__init__() - self.source_dp = source_dp - self.fn = fn - -Note: - -- Avoid loading data from the source DataPipe in ``__init__`` function, in order to support lazy data loading and save - memory. 
- -- If ``IterDataPipe`` instance holds data in memory, please be ware of the in-place modification of data. When second - iterator is created from the instance, the data may have already changed. Please take ``IterableWrapper`` - `class `_ - as reference to ``deepcopy`` data for each iterator. - -- Avoid variables names that are taken by the functional names of existing DataPipes. For instance, ``.filter`` is - the functional name that can be used to invoke ``FilterIterDataPipe``. Having a variable named ``filter`` inside - another ``IterDataPipe`` can lead to confusion. - - -Iterator -^^^^^^^^^^^^^^^^^^ -For ``IterDataPipes``, an ``__iter__`` function is needed to consume data from the source ``IterDataPipe`` then -apply the operation over the data before ``yield``. - -.. code:: python - - class MapperIterDataPipe(IterDataPipe): - # ... See __init__() defined above - - def __iter__(self): - for d in self.dp: - yield self.fn(d) - -Length -^^^^^^^^^^^^^^^^^^ -In many cases, as in our ``MapperIterDataPipe`` example, the ``__len__`` method of a DataPipe returns the length of the -source DataPipe. - -.. code:: python - - class MapperIterDataPipe(IterDataPipe): - # ... See __iter__() defined above - - def __len__(self): - return len(self.dp) - -However, note that ``__len__`` is optional for ``IterDataPipe`` and often inadvisable. For ``CSVParserIterDataPipe`` -in the using DataPipes section below, ``__len__`` is not implemented because the number of rows in each file -is unknown before loading it. In some special cases, ``__len__`` can be made to either return an integer or raise -an Error depending on the input. In those cases, the Error must be a ``TypeError`` to support Python's -build-in functions like ``list(dp)``. - -Registering DataPipes with the functional API -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Each DataPipe can be registered to support functional invocation using the decorator ``functional_datapipe``. - -.. code:: python - - @functional_datapipe("map") - class MapperIterDataPipe(IterDataPipe): - # ... - -The stack of DataPipes can then be constructed using their functional forms (recommended) or class constructors: - -.. code:: python - - import torchdata.datapipes as dp - - # Using functional form (recommended) - datapipes1 = dp.iter.FileOpener(['a.file', 'b.file']).map(fn=decoder).shuffle().batch(2) - # Using class constructors - datapipes2 = dp.iter.FileOpener(['a.file', 'b.file']) - datapipes2 = dp.iter.Mapper(datapipes2, fn=decoder) - datapipes2 = dp.iter.Shuffler(datapipes2) - datapipes2 = dp.iter.Batcher(datapipes2, 2) - -In the above example, ``datapipes1`` and ``datapipes2`` represent the exact same stack of ``IterDataPipe``\s. We -recommend using the functional form of DataPipes. - -Working with Cloud Storage Providers ---------------------------------------------- - -In this section, we show examples accessing AWS S3, Google Cloud Storage, and Azure Cloud Storage with built-in ``fsspec`` DataPipes. -Although only those two providers are discussed here, with additional libraries, ``fsspec`` DataPipes -should allow you to connect with other storage systems as well (`list of known -implementations `_). - -Let us know on GitHub if you have a request for support for other cloud storage providers, -or you have code examples to share with the community. 
- -Accessing AWS S3 with ``fsspec`` DataPipes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This requires the installation of the libraries ``fsspec`` -(`documentation `__) and ``s3fs`` -(`s3fs GitHub repo `_). - -You can list out the files within a S3 bucket directory by passing a path that starts -with ``"s3://BUCKET_NAME"`` to -`FSSpecFileLister `_ (``.list_files_by_fsspec(...)``). - -.. code:: python - - from torchdata.datapipes.iter import IterableWrapper - - dp = IterableWrapper(["s3://BUCKET_NAME"]).list_files_by_fsspec() - -You can also open files using `FSSpecFileOpener `_ -(``.open_files_by_fsspec(...)``) and stream them -(if supported by the file format). - -Note that you can also provide additional parameters via -the argument ``kwargs_for_open``. This can be useful for purposes such as accessing specific -bucket version, which you can do so by passing in ``{version_id: 'SOMEVERSIONID'}`` (more `details -about S3 bucket version awareness `_ -by ``s3fs``). The supported arguments vary by the (cloud) file system that you are accessing. - -In the example below, we are streaming the archive by using -`TarArchiveLoader `_ (``.load_from_tar(mode="r|")``), -in contrast with the usual ``mode="r:"``. This allows us to begin processing data inside the archive -without downloading the whole archive into memory first. - -.. code:: python - - from torchdata.datapipes.iter import IterableWrapper - dp = IterableWrapper(["s3://BUCKET_NAME/DIRECTORY/1.tar"]) - dp = dp.open_files_by_fsspec(mode="rb", anon=True).load_from_tar(mode="r|") # Streaming version - # The rest of data processing logic goes here - - -Finally, `FSSpecFileSaver `_ -is also available for writing data to cloud. - -Accessing Google Cloud Storage (GCS) with ``fsspec`` DataPipes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -This requires the installation of the libraries ``fsspec`` -(`documentation `__) and ``gcsfs`` -(`gcsfs GitHub repo `_). - -You can list out the files within a GCS bucket directory by specifying a path that starts -with ``"gcs://BUCKET_NAME"``. The bucket name in the example below is ``uspto-pair``. - -.. code:: python - - from torchdata.datapipes.iter import IterableWrapper - - dp = IterableWrapper(["gcs://uspto-pair/"]).list_files_by_fsspec() - print(list(dp)) - # ['gcs://uspto-pair/applications', 'gcs://uspto-pair/docs', 'gcs://uspto-pair/prosecution-history-docs'] - -Here is an example of loading a zip file ``05900035.zip`` from a bucket named ``uspto-pair`` inside the -directory ``applications``. - -.. 
code:: python - - from torchdata.datapipes.iter import IterableWrapper - - dp = IterableWrapper(["gcs://uspto-pair/applications/05900035.zip"]) \ - .open_files_by_fsspec(mode="rb") \ - .load_from_zip() - # Logic to process those archive files comes after - for path, filestream in dp: - print(path, filestream) - # gcs:/uspto-pair/applications/05900035.zip/05900035/README.txt, StreamWrapper<...> - # gcs:/uspto-pair/applications/05900035.zip/05900035/05900035-address_and_attorney_agent.tsv, StreamWrapper<...> - # gcs:/uspto-pair/applications/05900035.zip/05900035/05900035-application_data.tsv, StreamWrapper<...> - # gcs:/uspto-pair/applications/05900035.zip/05900035/05900035-continuity_data.tsv, StreamWrapper<...> - # gcs:/uspto-pair/applications/05900035.zip/05900035/05900035-transaction_history.tsv, StreamWrapper<...> - -Accessing Azure Blob storage with ``fsspec`` DataPipes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -This requires the installation of the libraries ``fsspec`` -(`documentation `__) and ``adlfs`` -(`adlfs GitHub repo `_). -You can access data in Azure Data Lake Storage Gen2 by providing URIs staring with ``abfs://``. -For example, -`FSSpecFileLister `_ (``.list_files_by_fsspec(...)``) -can be used to list files in a directory in a container: - -.. code:: python - - from torchdata.datapipes.iter import IterableWrapper - - storage_options={'account_name': ACCOUNT_NAME, 'account_key': ACCOUNT_KEY} - dp = IterableWrapper(['abfs://CONTAINER/DIRECTORY']).list_files_by_fsspec(**storage_options) - print(list(dp)) - # ['abfs://container/directory/file1.txt', 'abfs://container/directory/file2.txt', ...] - -You can also open files using `FSSpecFileOpener `_ -(``.open_files_by_fsspec(...)``) and stream them -(if supported by the file format). - -Here is an example of loading a CSV file ``ecdc_cases.csv`` from a public container inside the -directory ``curated/covid-19/ecdc_cases/latest``, belonging to account ``pandemicdatalake``. - -.. code:: python - - from torchdata.datapipes.iter import IterableWrapper - dp = IterableWrapper(['abfs://public/curated/covid-19/ecdc_cases/latest/ecdc_cases.csv']) \ - .open_files_by_fsspec(account_name='pandemicdatalake') \ - .parse_csv() - print(list(dp)[:3]) - # [['date_rep', 'day', ..., 'iso_country', 'daterep'], - # ['2020-12-14', '14', ..., 'AF', '2020-12-14'], - # ['2020-12-13', '13', ..., 'AF', '2020-12-13']] - -If necessary, you can also access data in Azure Data Lake Storage Gen1 by using URIs staring with -``adl://`` and ``abfs://``, as described in `README of adlfs repo `_ - -Accessing Azure ML Datastores with ``fsspec`` DataPipes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -An Azure ML datastore is a *reference* to an existing storage account on Azure. The key benefits of creating and using an Azure ML datastore are: - -- A common and easy-to-use API to interact with different storage types in Azure (Blob/Files/). -- Easier to discover useful datastores when working as a team. -- Authentication is automatically handled - both *credential-based* access (service principal/SAS/key) and *identity-based* access (Azure Active Directory/managed identity) are supported. When using credential-based authentication, you do not need to expose secrets in your code. - -This requires the installation of the library ``azureml-fsspec`` -(`documentation `__). - -You can access data in an Azure ML datastore by providing URIs staring with ``azureml://``. 
-For example, -`FSSpecFileLister `_ (``.list_files_by_fsspec(...)``) -can be used to list files in a directory in a container: - -.. code:: python - - from torchdata.datapipes.iter import IterableWrapper - - # set the subscription_id, resource_group, and AzureML workspace_name - subscription_id = "" - resource_group = "" - workspace_name = "" - - # set the datastore name and path on the datastore - datastore_name = "" - path_on_datastore = "" - - uri = f"azureml://subscriptions/{subscription_id}/resourcegroups/{resource_group}/workspaces/{workspace_name}/datastores/{datastore_name}/paths/{path_on_datastore}" - - dp = IterableWrapper([uri]).list_files_by_fsspec() - print(list(dp)) - # ['azureml:////resourcegroups//workspaces//datastores//paths//file1.txt', - # 'azureml:////resourcegroups//workspaces//datastores//paths//file2.txt', ...] - -You can also open files using `FSSpecFileOpener `_ -(``.open_files_by_fsspec(...)``) and stream them -(if supported by the file format). - -Here is an example of loading a tar file from the default Azure ML datastore ``workspaceblobstore`` where the path is ``/cifar-10-python.tar.gz`` (top-level folder). - -.. code:: python - - from torchdata.datapipes.iter import IterableWrapper - - # set the subscription_id, resource_group, and AzureML workspace_name - subscription_id = "" - resource_group = "" - workspace_name = "" - - # set the datastore name and path on the datastore - datastore_name = "workspaceblobstore" - path_on_datastore = "cifar-10-python.tar.gz" - - uri = f"azureml://subscriptions/{subscription_id}/resourcegroups/{resource_group}/workspaces/{workspace_name}/datastores/{datastore_name}/paths/{path_on_datastore}" - - dp = IterableWrapper([uri]) \ - .open_files_by_fsspec(mode="rb") \ - .load_from_tar() - - for path, filestream in dp: - print(path) - # ['azureml:/subscriptions//resourcegroups//workspaces//datastores//paths/cifar-10-python.tar.gz/cifar-10-batches-py/data_batch_4', - # 'azureml:/subscriptions//resourcegroups//workspaces//datastores//paths/cifar-10-python.tar.gz/cifar-10-batches-py/readme.html', - # 'azureml:/subscriptions//resourcegroups//workspaces//datastores//paths/cifar-10-python.tar.gz/cifar-10-batches-py/test_batch', - # 'azureml:/subscriptions//resourcegroups//workspaces//datastores//paths/cifar-10-python.tar.gz/cifar-10-batches-py/data_batch_3', - # 'azureml:/subscriptions//resourcegroups//workspaces//datastores//paths/cifar-10-python.tar.gz/cifar-10-batches-py/batches.meta', - # 'azureml:/subscriptions//resourcegroups//workspaces//datastores//paths/cifar-10-python.tar.gz/cifar-10-batches-py/data_batch_2', - # 'azureml:/subscriptions//resourcegroups//workspaces//datastores//paths/cifar-10-python.tar.gz/cifar-10-batches-py/data_batch_5', - # 'azureml:/subscriptions//resourcegroups//workspaces//datastores//paths/cifar-10-python.tar.gz/cifar-10-batches-py/data_batch_1] - -Here is an example of loading a CSV file - the famous Titanic dataset (`download `_) - from the Azure ML datastore ``workspaceblobstore`` where the path is ``/titanic.csv`` (top-level folder). - -.. 
code:: python - - from torchdata.datapipes.iter import IterableWrapper - - # set the subscription_id, resource_group, and AzureML workspace_name - subscription_id = "" - resource_group = "" - workspace_name = "" - - # set the datastore name and path on the datastore - datastore_name = "workspaceblobstore" - path_on_datastore = "titanic.csv" - - uri = f"azureml://subscriptions/{subscription_id}/resourcegroups/{resource_group}/workspaces/{workspace_name}/datastores/{datastore_name}/paths/{path_on_datastore}" - - def row_processer(row): - # return the label and data (the class and age of the passenger) - # if missing age, set to 50 - if row[5] == "": - row[5] = 50.0 - return {"label": np.array(row[1], np.int32), "data": np.array([row[2],row[5]], dtype=np.float32)} - - dp = IterableWrapper([uri]) \ - .open_files_by_fsspec() \ - .parse_csv(delimiter=",", skip_lines=1) \ - .map(row_processer) - - print(list(dp)[:3]) - # [{'label': array(0, dtype=int32), 'data': array([ 3., 22.], dtype=float32)}, - # {'label': array(1, dtype=int32), 'data': array([ 1., 38.], dtype=float32)}, - # {'label': array(1, dtype=int32), 'data': array([ 3., 26.], dtype=float32)}] diff --git a/docs/source/examples.rst b/docs/source/examples.rst deleted file mode 100644 index ed5535551..000000000 --- a/docs/source/examples.rst +++ /dev/null @@ -1,148 +0,0 @@ -Examples -================ - -.. currentmodule:: examples - -In this section, you will find the data loading implementations (using DataPipes) of various -popular datasets across different research domains. Some of the examples are implements by the PyTorch team and the -implementation codes are maintained within PyTorch libraries. Others are created by members of the PyTorch community. - -Audio ------------ - -LibriSpeech -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -`LibriSpeech dataset `_ is corpus of approximately 1000 hours of 16kHz read -English speech. Here is the -`DataPipe implementation of LibriSpeech `_ -to load the data. - -Text ------------ - -Amazon Review Polarity -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The Amazon reviews dataset contains reviews from Amazon. Its purpose is to train text/sentiment classification models. -In our DataPipe -`implementation of the dataset `_, -we described every step with detailed comments to help you understand what each DataPipe is doing. We recommend -having a look at this example. - - -IMDB -^^^^^^^^^^^^^^^^^^^^^^^^^^ -This is a `large movie review dataset `_ for binary sentiment -classification containing 25,000 highly polar movie reviews for training and 25,00 for testing. Here is the -`DataPipe implementation to load the data `_. - - -SQuAD -^^^^^^^^^^^^^^^^^^^^^^^^^^ -`SQuAD (Stanford Question Answering Dataset) `_ is a dataset for -reading comprehension. It consists of a list of questions by crowdworkers on a set of Wikipedia articles. Here are the -DataPipe implementations for `version 1.1 `_ -is here and `version 2.0 `_. - -Additional Datasets in TorchText -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In a separate PyTorch domain library `TorchText `_, you will find some of the most -popular datasets in the NLP field implemented as loadable datasets using DataPipes. You can find -all of those `NLP datasets here `_. - - -Vision ------------ - -Caltech 101 -^^^^^^^^^^^^^^^^^^^^^^^^^^ -The `Caltech 101 dataset `_ contains pictures of objects -belonging to 101 categories. Here is the -`DataPipe implementation of Caltech 101 `_. 
- -Caltech 256 -^^^^^^^^^^^^^^^^^^^^^^^^^^ -The `Caltech 256 dataset `_ contains 30607 images -from 256 categories. Here is the -`DataPipe implementation of Caltech 256 `_. - -CamVid - Semantic Segmentation (community example) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The `Cambridge-driving Labeled Video Database (CamVid) `_ is a collection of videos with object class semantic -labels, complete with metadata. The database provides ground truth labels that associate each pixel with one of 32 -semantic classes. Here is a -`DataPipe implementation of CamVid -`_ -created by our community. - -laion2B-en-joined -^^^^^^^^^^^^^^^^^^^^^^ -The `laion2B-en-joined dataset `_ is a subset of the `LAION-5B dataset `_ containing english captions, URls pointing to images, -and other metadata. It contains around 2.32 billion entries. -Currently (February 2023) around 86% of the URLs still point to valid images. Here is a `DataPipe implementation of laion2B-en-joined -`_ that filters out unsafe images and images with watermarks and loads the images from the URLs. - -Additional Datasets in TorchVision -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In a separate PyTorch domain library `TorchVision `_, you will find some of the most -popular datasets in the computer vision field implemented as loadable datasets using DataPipes. You can find all of -those `vision datasets here `_. - -Note that these implementations are currently in the prototype phase, but they should be fully supported -in the coming months. Nonetheless, they demonstrate the different ways DataPipes can be used for data loading. - -Recommender System ---------------------------------- - -Criteo 1TB Click Logs -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The `Criteo dataset `_ contains feature values -and click feedback for millions of display advertisements. It aims to benchmark algorithms for -click through rate (CTR) prediction. You can find a prototype stage implementation of the -`dataset with DataPipes in TorchRec `_. - -Graphs, Meshes and Point Clouds -------------------------------- - -TigerGraph (community example) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -TigerGraph is a scalable graph data platform for AI and ML. You can find an `implementation `_ of graph feature engineering and machine learning with DataPipes in TorchData and data stored in a TigerGraph database, which includes computing PageRank scores in-database, pulling graph data and features with multiple DataPipes, and training a neural network using graph features in PyTorch. - -MoleculeNet (community example) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -`MoleculeNet `_ is a benchmark specially designed for testing machine learning methods of -molecular properties. You can find an implementation of the -`HIV dataset with DataPipes in PyTorch Geometric `_, -which includes converting SMILES strings into molecular graph representations. - -Princeton ModelNet (community example) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The Princeton ModelNet project provides a comprehensive and clean collection of 3D CAD models across various object types. -You can find an implementation of the -`ModelNet10 dataset with DataPipes in PyTorch Geometric `_, -which includes reading in meshes via `meshio `_, and sampling of points from object surfaces and dynamic -graph generation via `PyG's functional transformations `_. 
- -Timeseries ---------------------------------- - -Custom DataPipe for Timeseries rolling window (community example) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Implementing a rolling window custom `DataPipe` for timeseries forecasting tasks. -Here is the -`DataPipe implementation of a rolling window -`_. - - -Using AIStore -------------------------- - -Caltech 256 and Microsoft COCO (community example) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Listing and loading data from AIS buckets (buckets that are not 3rd party backend-based) and remote cloud buckets (3rd party -backend-based cloud buckets) using `AISFileLister `_ and `AISFileLoader `_. - -Here is an `example which uses AISIO DataPipe `_ for the `Caltech-256 Object Category Dataset `_ containing 256 object categories and a total -of 30607 images stored on an AIS bucket and the `Microsoft COCO Dataset `_ which has 330K images with over 200K -labels of more than 1.5 million object instances across 80 object categories stored on Google Cloud. diff --git a/docs/source/index.rst b/docs/source/index.rst index cec30b4ad..18a8d21fe 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,11 +37,6 @@ Features described in this documentation are classified by release status: :caption: API Reference: torchdata.stateful_dataloader.rst - torchdata.datapipes.iter.rst - torchdata.datapipes.map.rst - torchdata.datapipes.utils.rst - dataloader2.rst - reading_service.rst .. toctree:: @@ -49,9 +44,6 @@ Features described in this documentation are classified by release status: :caption: Tutorial and Examples: stateful_dataloader_tutorial.rst - dp_tutorial.rst - dlv2_tutorial.rst - examples.rst .. toctree:: @@ -59,8 +51,8 @@ Features described in this documentation are classified by release status: :caption: PyTorch Libraries PyTorch + torchtune torchaudio - torchtext torchvision TorchElastic TorchServe diff --git a/docs/source/reading_service.rst b/docs/source/reading_service.rst deleted file mode 100644 index e5836a407..000000000 --- a/docs/source/reading_service.rst +++ /dev/null @@ -1,163 +0,0 @@ -:tocdepth: 3 - -.. currentmodule:: torchdata.datapipes.iter - -ReadingService -=============== - -``ReadingService`` handles in-place modification of ``DataPipe`` graph based on different use cases. - -Features ---------- - -Dynamic Sharding -^^^^^^^^^^^^^^^^ - -Dynamic sharding is achieved by ``MultiProcessingReadingService`` and ``DistributedReadingService`` to shard the pipeline based on the information of corresponding multiprocessing and distributed workers. And, TorchData offers two types of ``DataPipe`` letting users define the sharding place within the pipeline. - -- ``sharding_filter`` (:class:`ShardingFilter`): When the pipeline is replicable, each distributed/multiprocessing worker loads data from its own replica of the ``DataPipe`` graph, while skipping samples that do not belong to the corresponding worker at the point where ``sharding_filter`` is placed. - -- ``sharding_round_robin_dispatch`` (:class:`ShardingRoundRobinDispatcher`): When there is any ``sharding_round_robin_dispatch`` ``DataPipe`` in the pipeline, that branch (i.e. all DataPipes prior to ``sharding_round_robin_dispatch``) will be treated as a non-replicable branch (in the context of multiprocessing). A single dispatching process will be created to load data from the non-replicable branch and distribute data to the subsequent worker processes. 
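In code, both sharding strategies are attached to a pipeline through their functional forms. Below is a minimal, hedged sketch (added for illustration; the file names and the final ``zip`` are purely illustrative, and it assumes ``SHARDING_PRIORITIES`` is importable from ``torch.utils.data.datapipes.iter.sharding`` as in current PyTorch releases):

.. code:: python

    from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
    from torchdata.datapipes.iter import IterableWrapper

    # Replicable branch: every worker holds its own copy of this part of the
    # graph and drops samples belonging to other workers at ``sharding_filter``.
    replicable_dp = IterableWrapper(["a.csv", "b.csv"]).shuffle().sharding_filter()

    # Non-replicable branch: a single dispatching process loads the data and
    # hands samples to the worker processes round-robin.
    dispatched_dp = (
        IterableWrapper(["c.csv", "d.csv"])
        .shuffle()
        .sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
    )

    # Join the two branches; DP5/DP6 in the graphs below play this role.
    datapipe = replicable_dp.zip(dispatched_dp)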
- -The following is an example of having two types of sharding strategies in the pipeline. - -.. graphviz:: - - digraph Example { - subgraph cluster_replicable { - label="Replicable" - a -> b -> c -> d -> l; - color=blue; - } - - subgraph cluster_non_replicable { - style=filled; - color=lightgrey; - node [style=filled,color=white]; - label="Non-Replicable" - e -> f -> g -> k; - h -> i -> j -> k; - } - - k -> l -> fullsync -> end; - - a [label="DP1"]; - b [label="shuffle"]; - c [label="sharding_filter", color=blue]; - d [label="DP4"]; - e [label="DP2"]; - f [label="shuffle"]; - g [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white]; - h [label="DP3"]; - i [label="shuffle"]; - j [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white]; - k [label="DP5 (Lowest common ancestor)"]; - l [label="DP6"]; - fullsync; - end [shape=box]; - } - -When multiprocessing takes place, the graph becomes: - -.. graphviz:: - - digraph Example { - subgraph cluster_worker_0 { - label="Worker 0" - a0 -> b0 -> c0 -> d0 -> l0; - m0 -> l0; - color=blue; - } - - subgraph cluster_worker_1 { - label="Worker 1" - a1 -> b1 -> c1 -> d1 -> l1; - m1 -> l1; - color=blue; - } - - subgraph cluster_non_replicable { - style=filled; - color=lightgrey; - node [style=filled,color=white]; - label="Non-Replicable" - e -> f -> g -> k; - h -> i -> j -> k; - k -> round_robin_demux; - } - - round_robin_demux -> m0; - round_robin_demux -> m1; - l0 -> n; - l1 -> n; - n -> fullsync -> end; - - a0 [label="DP1"]; - b0 [label="shuffle"]; - c0 [label="sharding_filter", color=blue]; - d0 [label="DP4"]; - a1 [label="DP1"]; - b1 [label="shuffle"]; - c1 [label="sharding_filter", color=blue]; - d1 [label="DP4"]; - e [label="DP2"]; - f [label="shuffle"]; - g [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white]; - h [label="DP3"]; - i [label="shuffle"]; - j [label="sharding_round_robin_dispatch", style="filled,rounded", color=red, fillcolor=white]; - k [label="DP5 (Lowest common ancestor)"]; - fullsync; - l0 [label="DP6"]; - l1 [label="DP6"]; - m0 [label="Client"] - m1 [label="Client"] - n [label="Client"] - end [shape=box]; - } - -``Client`` in the graph is a ``DataPipe`` that sends a request and receives a response from multiprocessing queues. - -.. module:: torchdata.dataloader2 - -Determinism -^^^^^^^^^^^^ - -In ``DataLoader2``, a ``SeedGenerator`` becomes a single source of randomness and each ``ReadingService`` would access it via ``initialize_iteration()`` and generate corresponding random seeds for random ``DataPipe`` operations. - -In order to make sure that the Dataset shards are mutually exclusive and collectively exhaustive on multiprocessing processes and distributed nodes, ``MultiProcessingReadingService`` and ``DistributedReadingService`` would help :class:`DataLoader2` to synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, unique random states are generated based on the distributed rank and worker process id by each ``ReadingService``, in order to perform different random transformations. - -Graph Mode -^^^^^^^^^^^ - -This also allows easier transition of data-preprocessing pipeline from research to production. 
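Concretely, switching backends only changes which ``ReadingService`` is handed to :class:`DataLoader2`; the ``DataPipe`` graph itself stays untouched. A hedged sketch follows, where ``ProductionReadingService`` is a hypothetical stand-in for an infrastructure-specific backend (not something shipped by TorchData) and ``datapipe`` is assumed to be a graph built as in the tutorial:

.. code:: python

    from torchdata.dataloader2 import (
        DataLoader2,
        MultiProcessingReadingService,
        ReadingServiceInterface,
    )

    class ProductionReadingService(ReadingServiceInterface):
        """Hypothetical backend that would connect to production infrastructure."""

        def initialize(self, datapipe):
            # A real implementation would rewrite or delegate parts of the
            # graph here; this stub simply returns it unchanged.
            return datapipe

    # Research: run the graph with multiprocessing workers.
    dl = DataLoader2(datapipe, reading_service=MultiProcessingReadingService(num_workers=4))

    # Production: the same, unmodified graph handed to the custom backend.
    dl = DataLoader2(datapipe, reading_service=ProductionReadingService())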
After the ``DataPipe`` graph is created and validated with the ``ReadingServices``, a different ``ReadingService`` that configures and connects to the production service/infrastructure such as ``AIStore`` can be provided to :class:`DataLoader2` as a drop-in replacement. The ``ReadingService`` could potentially search the graph, and find ``DataPipe`` operations that can be delegated to the production service/infrastructure, then modify the graph correspondingly to achieve higher-performant execution. - -Extend ReadingService ----------------------- - -The followings are interfaces for custom ``ReadingService``. - -.. autoclass:: ReadingServiceInterface - :members: - -The checkpoint/snapshotting feature is a work in progress. Here is the preliminary interface (small changes are likely): - -.. autoclass:: CheckpointableReadingServiceInterface - :members: - -Graph Functions -^^^^^^^^^^^^^^^^ -And, graph utility functions are provided in ``torchdata.dataloader.graph`` to help users to do ``DataPipe`` graph rewrite for custom ``ReadingService``: - -.. module:: torchdata.dataloader2.graph - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: function.rst - - traverse_dps - find_dps - list_dps - remove_dp - replace_dp diff --git a/docs/source/torchdata.datapipes.iter.rst b/docs/source/torchdata.datapipes.iter.rst deleted file mode 100644 index 10866ec73..000000000 --- a/docs/source/torchdata.datapipes.iter.rst +++ /dev/null @@ -1,230 +0,0 @@ - -Iterable-style DataPipes -========================== - -.. currentmodule:: torchdata.datapipes.iter - -An iterable-style dataset is an instance of a subclass of IterableDataset that implements the ``__iter__()`` protocol, -and represents an iterable over data samples. This type of datasets is particularly suitable for cases where random -reads are expensive or even improbable, and where the batch size depends on the fetched data. - -For example, such a dataset, when called ``iter(iterdatapipe)``, could return a stream of data reading from a database, -a remote server, or even logs generated in real time. - -This is an updated version of ``IterableDataset`` in ``torch``. - -.. autoclass:: IterDataPipe - - -We have different types of Iterable DataPipes: - -1. Archive - open and decompress archive files of different formats. - -2. Augmenting - augment your samples (e.g. adding index, or cycle through indefinitely). - -3. Combinatorial - perform combinatorial operations (e.g. sampling, shuffling). - -4. Combining/Splitting - interact with multiple DataPipes by combining them or splitting one to many. - -5. Grouping - group samples within a DataPipe - -6. IO - interacting with the file systems or remote server (e.g. downloading, opening, - saving files, and listing the files in directories). - -7. Mapping - apply the a given function to each element in the DataPipe. - -8. Others - perform miscellaneous set of operations. - -9. Selecting - select specific samples within a DataPipe. - -10. Text - parse, read, and transform text files and data - -Archive DataPipes -------------------------- - -These DataPipes help opening and decompressing archive files of different formats. - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - Bz2FileLoader - Decompressor - RarArchiveLoader - TarArchiveLoader - TFRecordLoader - WebDataset - XzFileLoader - ZipArchiveLoader - -Augmenting DataPipes ------------------------------ -These DataPipes help to augment your samples. - -.. 
autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - Cycler - Enumerator - IndexAdder - Repeater - -Combinatorial DataPipes ------------------------------ -These DataPipes help to perform combinatorial operations. - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - InBatchShuffler - Sampler - Shuffler - -Combining/Splitting DataPipes ------------------------------ -These tend to involve multiple DataPipes, combining them or splitting one to many. - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - Concater - Demultiplexer - Forker - IterKeyZipper - MapKeyZipper - Multiplexer - MultiplexerLongest - RoundRobinDemultiplexer - SampleMultiplexer - UnZipper - Zipper - ZipperLongest - -Grouping DataPipes ------------------------------ -These DataPipes have you group samples within a DataPipe. - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - Batcher - BucketBatcher - Collator - Grouper - MaxTokenBucketizer - UnBatcher - -IO DataPipes -------------------------- - -These DataPipes help interacting with the file systems or remote server (e.g. downloading, opening, -saving files, and listing the files in directories). - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - AISFileLister - AISFileLoader - FSSpecFileLister - FSSpecFileOpener - FSSpecSaver - FileLister - FileOpener - GDriveReader - HttpReader - HuggingFaceHubReader - IoPathFileLister - IoPathFileOpener - IoPathSaver - OnlineReader - ParquetDataFrameLoader - S3FileLister - S3FileLoader - Saver - -Mapping DataPipes -------------------------- - -These DataPipes apply the a given function to each element in the DataPipe. - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - BatchAsyncMapper - BatchMapper - FlatMapper - Mapper - ShuffledFlatMapper - ThreadPoolMapper - -Other DataPipes -------------------------- -A miscellaneous set of DataPipes with different functionalities. - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - DataFrameMaker - EndOnDiskCacheHolder - FullSync - HashChecker - InMemoryCacheHolder - IterableWrapper - LengthSetter - MapToIterConverter - OnDiskCacheHolder - PinMemory - Prefetcher - RandomSplitter - ShardExpander - ShardingFilter - ShardingRoundRobinDispatcher - -Selecting DataPipes -------------------------- - -These DataPipes helps you select specific samples within a DataPipe. - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - Filter - Header - Dropper - Slicer - Flattener - -Text DataPipes ------------------------------ -These DataPipes help you parse, read, and transform text files and data. - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - CSVDictParser - CSVParser - JsonParser - LineReader - ParagraphAggregator - RoutedDecoder - Rows2Columnar - StreamReader diff --git a/docs/source/torchdata.datapipes.map.rst b/docs/source/torchdata.datapipes.map.rst deleted file mode 100644 index 37365a81c..000000000 --- a/docs/source/torchdata.datapipes.map.rst +++ /dev/null @@ -1,45 +0,0 @@ -Map-style DataPipes -=========================== - -.. 
currentmodule:: torchdata.datapipes.map - -A Map-style DataPipe is one that implements the ``__getitem__()`` and ``__len__()`` protocols, and represents a map -from (possibly non-integral) indices/keys to data samples. This is a close equivalent of ``Dataset`` from the PyTorch -core library. - -For example, when accessed with ``mapdatapipe[idx]``, could read the ``idx``-th image and its -corresponding label from a folder on the disk. - -.. autoclass:: MapDataPipe - - -By design, there are fewer ``MapDataPipe`` than ``IterDataPipe`` to avoid duplicate implementations of the same -functionalities as ``MapDataPipe``. We encourage users to use the built-in ``IterDataPipe`` for various functionalities, -and convert it to ``MapDataPipe`` as needed using :class:`.IterToMapConverter` or ``.to_map_datapipe()``. -If you have any question about usage or best practices while using ``MapDataPipe``, feel free to ask on the PyTorch -forum under the `'data' category `_. - -We are open to add additional ``MapDataPipe`` where the operations can be lazily executed and ``__len__`` can be -known in advance. Feel free to make suggestions with description of your use case in -`this Github issue `_. Feedback about our design choice is also -welcomed in that Github issue. - -Here is the list of available Map-style DataPipes: - -List of MapDataPipes -------------------------- - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - Batcher - Concater - InMemoryCacheHolder - IterToMapConverter - Mapper - SequenceWrapper - Shuffler - UnZipper - Zipper diff --git a/docs/source/torchdata.datapipes.utils.rst b/docs/source/torchdata.datapipes.utils.rst deleted file mode 100644 index ab06f82a1..000000000 --- a/docs/source/torchdata.datapipes.utils.rst +++ /dev/null @@ -1,60 +0,0 @@ -Utility Functions -=========================== - -.. - Comment: the next section will become "DataPipe Graph Visualization and Linter" when linter is added. - -DataPipe Graph Visualization -------------------------------------- -.. currentmodule:: torchdata.datapipes.utils - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: function.rst - - to_graph - -Common Utility Functions --------------------------------------- -.. currentmodule:: torchdata.datapipes.utils - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: function.rst - - janitor - pin_memory_fn - - -File Object and Stream Utility -------------------------------------- - -.. currentmodule:: torchdata.datapipes.utils - -.. autosummary:: - :nosignatures: - :toctree: generated/ - :template: class_template.rst - - StreamWrapper - - -DataLoader -------------------------------------- - -For documentation related to DataLoader, please refer to the -``torch.utils.data`` `documentation `_. Or, more specifically, the -`DataLoader API section `_. - -DataLoader v2 is currently in development. For more information please refer to :doc:`dataloader2`. - - -Sampler -------------------------------------- - -For documentation related to Sampler, please refer to the -``torch.utils.data`` -`documentation on Data Loading order `_. -The Sampler API section is `here `_. diff --git a/examples/dataloader2/train_loop.py b/examples/dataloader2/train_loop.py deleted file mode 100644 index f38a47fce..000000000 --- a/examples/dataloader2/train_loop.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch - -from torchdata.dataloader2 import DataLoader2 -from torchdata.datapipes.iter import IterableWrapper - - -class ToyModel(torch.nn.Module): - def __init__(self) -> None: - """ - In the model constructor, we instantiate four parameters and use them - as member parameters. - """ - super().__init__() - self.a = torch.nn.Parameter(torch.randn(())) - self.b = torch.nn.Parameter(torch.randn(())) - self.c = torch.nn.Parameter(torch.randn(())) - self.d = torch.nn.Parameter(torch.randn(())) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Simple model forward function - """ - return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 - - -def main() -> None: - model = ToyModel() - - train_features = IterableWrapper([torch.rand(3) for _ in range(20000)]) - train_labels = IterableWrapper([torch.rand(3) for _ in range(20000)]) - train_data_pipe = train_features.zip(train_labels).shuffle() - - # DataLoader2 wraps an iterable around the Datapipe to enable easy access to - # the features and labels. - data_loader = DataLoader2(datapipe=train_data_pipe) - - # Construct the loss function and the optimizer. - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) - - # Loop over the dataset multiple times. Here we are doing only 3 training - # epochs - that is, three passes over the training datapipes. - for epoch in range(3): - - # Set manual seed per epoch to control the randomness for shuffle. - torch.manual_seed(epoch) - - running_loss = 0.0 - for step, data in enumerate(data_loader): - # Obtain the inputs and labels from data. - train_feature, train_label = data - - # Zero the parameter gradients. - optimizer.zero_grad() - - # Train step: forward + backward + optimize. - predicted_outputs = model(train_feature) - loss = criterion(predicted_outputs, train_label) - loss.backward() - optimizer.step() - - # Calculate the statistics. - running_loss += loss.item() - # Print the loss every 2000 mini-batches. - if step % 2000 == 1999: - print("[epoch: %d, %5d] loss: %.3f" % (epoch + 1, step + 1, running_loss / 2000)) - running_loss = 0.0 - - print("Finished Training") - - -if __name__ == "__main__": - main() # pragma: no cover diff --git a/examples/dataloader2/train_loop_distributed_reading_service.py b/examples/dataloader2/train_loop_distributed_reading_service.py deleted file mode 100644 index b9d60953b..000000000 --- a/examples/dataloader2/train_loop_distributed_reading_service.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import os - -import torch -import torch.distributed as dist -from torch import nn - -from torchdata.dataloader2 import DataLoader2, DistributedReadingService -from torchdata.datapipes.iter import IterableWrapper - - -class ToyModel(nn.Module): - def __init__(self) -> None: - """ - In the model constructor, we instantiate four parameters and use them - as member parameters. 
- """ - super().__init__() - self.a = nn.Parameter(torch.randn(())) - self.b = nn.Parameter(torch.randn(())) - self.c = nn.Parameter(torch.randn(())) - self.d = nn.Parameter(torch.randn(())) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Simple model forward function - """ - return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 - - -def main() -> None: - model = ToyModel() - - os.environ["RANK"] = str(0) - os.environ["WORLD_SIZE"] = str(2) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "0" - - dist.init_process_group("gloo") - - # Use a prime number to make sure uneven data sharding and let - # DistributedReadingService prevent hanging with the unbalanced data shard - data_length = 19997 - - train_features = IterableWrapper([torch.rand(3) for _ in range(data_length)]) - train_labels = IterableWrapper([torch.rand(3) for _ in range(data_length)]) - - # sharding_filter will automatically shard the data based on the - # distributed ranks - train_data_pipe = train_features.zip(train_labels).shuffle().sharding_filter() - - # Torch Distributed is required to use DistributedReadingService - reading_service = DistributedReadingService() - - # Create DataLoader2 with DistributedReadingService - data_loader2 = DataLoader2( - datapipe=train_data_pipe, - reading_service=reading_service, - ) - - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) - - for epoch in range(5): - - # Set manual seed per epoch to control the randomness for shuffle. - torch.manual_seed(epoch) - - running_loss = 0.0 - for step, data in enumerate(data_loader2): - train_feature, train_label = data - optimizer.zero_grad() - - predicted_outputs = model(train_feature) - loss = criterion(predicted_outputs, train_label) - loss.backward() - optimizer.step() - - running_loss += loss.item() - if step % 2000 == 1999: - print("[epoch: %d, %5d] loss: %.3f" % (epoch + 1, step + 1, running_loss / 2000)) - running_loss = 0.0 - - print("Finished Training") - - """ - Training Output: - - [epoch: 1, 2000] loss: 0.860 - [epoch: 1, 4000] loss: 0.823 - [epoch: 1, 6000] loss: 0.809 - [epoch: 1, 8000] loss: 0.778 - [epoch: 1, 10000] loss: 0.753 - [epoch: 1, 12000] loss: 0.756 - [epoch: 1, 14000] loss: 0.730 - [epoch: 1, 16000] loss: 0.727 - [epoch: 1, 18000] loss: 0.704 - [epoch: 1, 20000] loss: 0.703 - [epoch: 2, 2000] loss: 0.677 - [epoch: 2, 4000] loss: 0.649 - [epoch: 2, 6000] loss: 0.648 - [epoch: 2, 8000] loss: 0.629 - [epoch: 2, 10000] loss: 0.623 - [epoch: 2, 12000] loss: 0.593 - [epoch: 2, 14000] loss: 0.586 - [epoch: 2, 16000] loss: 0.584 - [epoch: 2, 18000] loss: 0.571 - [epoch: 2, 20000] loss: 0.558 - [epoch: 3, 2000] loss: 0.537 - [epoch: 3, 4000] loss: 0.540 - [epoch: 3, 6000] loss: 0.544 - [epoch: 3, 8000] loss: 0.512 - [epoch: 3, 10000] loss: 0.496 - [epoch: 3, 12000] loss: 0.506 - [epoch: 3, 14000] loss: 0.486 - [epoch: 3, 16000] loss: 0.489 - [epoch: 3, 18000] loss: 0.489 - [epoch: 3, 20000] loss: 0.456 - [epoch: 4, 2000] loss: 0.474 - [epoch: 4, 4000] loss: 0.445 - [epoch: 4, 6000] loss: 0.442 - [epoch: 4, 8000] loss: 0.440 - [epoch: 4, 10000] loss: 0.434 - [epoch: 4, 12000] loss: 0.421 - [epoch: 4, 14000] loss: 0.415 - [epoch: 4, 16000] loss: 0.404 - [epoch: 4, 18000] loss: 0.427 - [epoch: 4, 20000] loss: 0.410 - [epoch: 5, 2000] loss: 0.395 - [epoch: 5, 4000] loss: 0.393 - [epoch: 5, 6000] loss: 0.389 - [epoch: 5, 8000] loss: 0.397 - [epoch: 5, 10000] loss: 0.375 - [epoch: 5, 12000] loss: 0.375 - [epoch: 5, 14000] loss: 0.372 - 
[epoch: 5, 16000] loss: 0.365 - [epoch: 5, 18000] loss: 0.371 - [epoch: 5, 20000] loss: 0.359 - Finished Training - - """ - - -if __name__ == "__main__": - main() # pragma: no cover diff --git a/examples/dataloader2/train_loop_reading_service.py b/examples/dataloader2/train_loop_reading_service.py deleted file mode 100644 index e66cb4101..000000000 --- a/examples/dataloader2/train_loop_reading_service.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import torch - -from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService -from torchdata.datapipes.iter import IterableWrapper - - -class ToyModel(torch.nn.Module): - def __init__(self) -> None: - """ - In the model constructor, we instantiate four parameters and use them - as member parameters. - """ - super().__init__() - self.a = torch.nn.Parameter(torch.randn(())) - self.b = torch.nn.Parameter(torch.randn(())) - self.c = torch.nn.Parameter(torch.randn(())) - self.d = torch.nn.Parameter(torch.randn(())) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Simple model forward function - """ - return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 - - -def main() -> None: - model = ToyModel() - - train_features = IterableWrapper([torch.rand(3) for _ in range(20000)]) - train_labels = IterableWrapper([torch.rand(3) for _ in range(20000)]) - train_data_pipe = train_features.zip(train_labels).shuffle().sharding_filter() - - # Create DataLoader2 with MultiProcessingReadingService - data_loader = DataLoader2( - datapipe=train_data_pipe, - reading_service=MultiProcessingReadingService(num_workers=2), - ) - - criterion = torch.nn.MSELoss(reduction="sum") - optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) - - for epoch in range(3): - - # Set manual seed per epoch to control the randomness for shuffle. - torch.manual_seed(epoch) - - running_loss = 0.0 - for step, data in enumerate(data_loader): - train_feature, train_label = data - optimizer.zero_grad() - - predicted_outputs = model(train_feature) - loss = criterion(predicted_outputs, train_label) - loss.backward() - optimizer.step() - - running_loss += loss.item() - if step % 2000 == 1999: - print("[epoch: %d, %5d] loss: %.3f" % (epoch + 1, step + 1, running_loss / 2000)) - running_loss = 0.0 - - print("Finished Training") - - -if __name__ == "__main__": - main() # pragma: no cover diff --git a/examples/dataloader2/train_loop_torchtext.py b/examples/dataloader2/train_loop_torchtext.py deleted file mode 100644 index 3b77c094a..000000000 --- a/examples/dataloader2/train_loop_torchtext.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import torch -import torch.nn as nn -import torchtext -import torchtext.functional as F - -import torchtext.transforms as T -from torch.hub import load_state_dict_from_url -from torch.optim import AdamW -from torchdata.dataloader2 import DataLoader2 -from torchtext.datasets import SST2 - - -LEARNING_RATE = 1e-5 -PADDING_IDX = 1 -BOS_IDX = 0 -EOS_IDX = 2 -MAX_SEQ_LEN = 256 - - -XLMR_VOCAB_PATH = r"https://download.pytorch.org/models/text/xlmr.vocab.pt" -XLMR_SPM_MODEL_PATH = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model" - -DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - -text_transform = T.Sequential( - T.SentencePieceTokenizer(XLMR_SPM_MODEL_PATH), - T.VocabTransform(load_state_dict_from_url(XLMR_VOCAB_PATH)), - T.Truncate(MAX_SEQ_LEN - 2), - T.AddToken(token=BOS_IDX, begin=True), - T.AddToken(token=EOS_IDX, begin=False), -) - -NUM_EPOCHS = 1 -BATCH_SIZE = 8 -NUM_CLASSES = 2 -INPUT_DIM = 768 - - -def apply_transform(x): - return text_transform(x[0]), x[1] - - -def train_step(input: torch.Tensor, target: torch.Tensor) -> None: - output = model(input) - loss = criteria(output, target) - optim.zero_grad() - loss.backward() - optim.step() - - -def eval_step(input: torch.Tensor, target: torch.Tensor) -> None: - output = model(input) - loss = criteria(output, target).item() - return float(loss), (output.argmax(1) == target).type(torch.float).sum().item() - - -def evaluate() -> None: - model.eval() - total_loss = 0 - correct_predictions = 0 - total_predictions = 0 - counter = 0 - with torch.no_grad(): - for batch in eval_dataloader: - input = F.to_tensor(batch["token_ids"], padding_value=PADDING_IDX).to(DEVICE) - target = torch.tensor(batch["target"]).to(DEVICE) - loss, predictions = eval_step(input, target) - total_loss += loss - correct_predictions += predictions - total_predictions += len(target) - counter += 1 - - return total_loss / counter, correct_predictions / total_predictions - - -def main() -> None: - global eval_dataloader, model, train_datapipe, criteria, optim - - train_datapipe = SST2(split="train") - eval_datapipe = SST2(split="dev") - - train_datapipe = train_datapipe.map(apply_transform) - train_datapipe = train_datapipe.batch(BATCH_SIZE) - train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"]) - train_dataloader = DataLoader2(datapipe=train_datapipe) - print("Created train dataloader") - - eval_datapipe = eval_datapipe.map(apply_transform) - eval_datapipe = eval_datapipe.batch(BATCH_SIZE) - eval_datapipe = eval_datapipe.rows2columnar(["token_ids", "target"]) - eval_dataloader = DataLoader2(datapipe=eval_datapipe) - print("Created eval dataloader") - - classifier_head = torchtext.models.RobertaClassificationHead(num_classes=NUM_CLASSES, input_dim=INPUT_DIM) - model = torchtext.models.XLMR_BASE_ENCODER.get_model(head=classifier_head) - model.to(DEVICE) - - optim = AdamW(model.parameters(), lr=LEARNING_RATE) - criteria = nn.CrossEntropyLoss() - - for epoch in range(NUM_EPOCHS): - for step, batch in enumerate(train_dataloader): - input = F.to_tensor(batch["token_ids"], padding_value=PADDING_IDX).to(DEVICE) - target = torch.tensor(batch["target"]).to(DEVICE) - train_step(input, target) - - # stop early for example purpose - if step == 10: - break - - loss, accuracy = evaluate() - print(f"Epoch: {epoch}, loss: {loss}, accuracy: {accuracy}") - - print("Finished Training") - - -if __name__ == "__main__": - main() # pragma: no cover diff --git a/examples/text/CC100.ipynb b/examples/text/CC100.ipynb index 
943bfbf12..7532b0f4a 100644 --- a/examples/text/CC100.ipynb +++ b/examples/text/CC100.ipynb @@ -1,247 +1,251 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "source": [ - "import torch\n", - "import os\n", - "\n", - "from functools import partial\n", - "from operator import itemgetter\n", - "from torchdata.datapipes.iter import (\n", - " FileOpener,\n", - " HttpReader,\n", - " IterableWrapper,\n", - " SampleMultiplexer,\n", - ")\n", - "\n", - "ROOT_DIR = os.path.expanduser('~/.torchdata/CC100') # This directory needs to be crated and set\n", - "\n", - "\n", - "def _path_fn(root, x):\n", - " return os.path.join(root, os.path.basename(x).rstrip(\".xz\"))\n", - "\n", - "def _process_tuple(language_code, t):\n", - " return language_code, t[1].decode()" - ], - "outputs": [], - "metadata": { - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 2, - "source": [ - "# CC100 support (http://data.statmt.org/cc-100/)\n", - "\n", - "URL=\"http://data.statmt.org/cc-100/%s.txt.xz\"\n", - "VALID_CODES = [\n", - " \"am\", \"ar\", \"as\", \"az\", \"be\", \"bg\", \"bn\", \"bn_rom\", \"br\", \"bs\", \"ca\", \"cs\", \"cy\", \"da\", \"de\", \n", - " \"el\", \"en\", \"eo\", \"es\", \"et\", \"eu\", \"fa\", \"ff\", \"fi\", \"fr\", \"fy\", \"ga\", \"gd\", \"gl\", \"gn\", \"gu\", \n", - " \"ha\", \"he\", \"hi\", \"hi_rom\", \"hr\", \"ht\", \"hu\", \"hy\", \"id\", \"ig\", \"is\", \"it\", \"ja\", \"jv\", \"ka\", \n", - " \"kk\", \"km\", \"kn\", \"ko\", \"ku\", \"ky\", \"la\", \"lg\", \"li\", \"ln\", \"lo\", \"lt\", \"lv\", \"mg\", \"mk\", \"ml\", \n", - " \"mn\", \"mr\", \"ms\", \"my\", \"my_zaw\", \"ne\", \"nl\", \"no\", \"ns\", \"om\", \"or\", \"pa\", \"pl\", \"ps\", \"pt\", \n", - " \"qu\", \"rm\", \"ro\", \"ru\", \"sa\", \"si\", \"sc\", \"sd\", \"sk\", \"sl\", \"so\", \"sq\", \"sr\", \"ss\", \"su\", \"sv\", \n", - " \"sw\", \"ta\", \"ta_rom\", \"te\", \"te_rom\", \"th\", \"tl\", \"tn\", \"tr\", \"ug\", \"uk\", \"ur\", \"ur_rom\", \"uz\", \n", - " \"vi\", \"wo\", \"xh\", \"yi\", \"yo\", \"zh-Hans\", \"zh-Hant\", \"zu\", \n", - "]\n", - "\n", - "def CC100(root, language_code, use_caching=True):\n", - " if language_code not in VALID_CODES:\n", - " raise ValueError(f\"Invalid language code {language_code}\")\n", - " url = URL % language_code\n", - " if use_caching:\n", - " cache_compressed_dp = HttpReader(cache_compressed_dp).map(itemgetter(0))\n", - " cache_compressed_dp = cache_compressed_dp.end_caching(mode=\"wb\", same_filepath_fn=True)\n", - " cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_path_fn, root))\n", - " cache_decompressed_dp = FileOpener(cache_decompressed_dp).read_from_xz()\n", - " cache_decompressed_dp = cache_decompressed_dp.end_caching(mode=\"wb\")\n", - " data_dp = FileOpener(cache_decompressed_dp)\n", - " else:\n", - " data_dp = HttpReader([url]).read_from_xz()\n", - " units_dp = data_dp.readlines().map(partial(_process_tuple, language_code))\n", - " return units_dp\n" - ], - "outputs": [], - "metadata": { - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 3, - "source": [ - "# Sample from multi-gigabyte-size compressed dataset without downloading the whole thing\n", - "# This executes very fast\n", - "import time\n", - "start_time = time.time()\n", - "for i, x in enumerate(CC100(ROOT_DIR, 'en', use_caching=False)):\n", - " print(x)\n", - " if i > 5:\n", - " break\n", - "print(f\"Execution time {(time.time() - start_time):.2f} secs\")" - ], - "outputs": [ + 
"cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "('en', 'Belmont Estate is on the market for $63 million and boasts roughly 22,000 square feet of luxurious finishes and elaborate architecture on 1.28 acres. Listed on Thursday, the home is being sold by high-end real estate firm Sotheby’s International Realty Canada.')\n", - "('en', '“Within the city we’ve had homes that have sold for $56 million, $33 million, $31 million but this will be the record of the offering price,” listing agent Christa Frosch of Sotheby’s tells BuzzBuzzNews.')\n", - "('en', 'The three-storey home has five bedrooms, twelve bathrooms and an elevator in the west wing. Built to entertain, two main gallery halls can seat up to 100 guests. The Italian-inspired kitchen includes a fireplace and walls and ceilings throughout the home feature murals and artwork. Lavish amenities include an indoor pool and sauna, a six-car garage and a private entrance in-law’s suite.')\n", - "('en', 'Surrounding the property is a Versailles-inspired garden with a variety of trees, plants and an orchard. In the spring, over 12,000 flowers bloom in the tiered, three-level garden.')\n", - "('en', 'According to Frosch, the listing has received global attention and, despite being on the market for only 24 hours, buyers are already showing interest.')\n", - "('en', '“We just went to the market yesterday, it’s private through Sotheby’s and we’ve already started to get calls,” says Frosch.')\n", - "('en', '')\n", - "Execution time 0.55 secs\n" - ] - } - ], - "metadata": { - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 5, - "source": [ - "# cache\n", - "# This cell is very slow to run the first time as it downloads a dataset from a very slow server\n", - "next(iter(CC100(ROOT_DIR, 'ha', use_caching=True)))" - ], - "outputs": [ + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import os\n", + "\n", + "from functools import partial\n", + "from operator import itemgetter\n", + "from torchdata.datapipes.iter import (\n", + " FileOpener,\n", + " HttpReader,\n", + " IterableWrapper,\n", + " SampleMultiplexer,\n", + ")\n", + "\n", + "ROOT_DIR = os.path.expanduser('~/.torchdata/CC100') # This directory needs to be crated and set\n", + "\n", + "\n", + "def _path_fn(root, x):\n", + " return os.path.join(root, os.path.basename(x).rstrip(\".xz\"))\n", + "\n", + "def _process_tuple(language_code, t):\n", + " return language_code, t[1].decode()" + ] + }, { - "data": { - "text/plain": "('ha',\n 'Dangantaka tsakanin kasashen Masar da Turkiya ta yi tsami a cikin yan kwanakin nan, saboda sanin iyakokin da kowanne daga cikin yake mallaka a tekun Mediterranean .')" - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "metadata": { - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 5, - "source": [ - "# cache\n", - "# This cell is very slow to run the first time as it downloads a dataset from a very slow server\n", - "next(iter(CC100(ROOT_DIR, 'yi', use_caching=True)))" - ], - "outputs": [ + "cell_type": "code", + "execution_count": 2, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# CC100 support (http://data.statmt.org/cc-100/)\n", + "\n", + "URL=\"http://data.statmt.org/cc-100/%s.txt.xz\"\n", + "VALID_CODES = [\n", + " \"am\", \"ar\", \"as\", \"az\", \"be\", 
\"bg\", \"bn\", \"bn_rom\", \"br\", \"bs\", \"ca\", \"cs\", \"cy\", \"da\", \"de\",\n", + " \"el\", \"en\", \"eo\", \"es\", \"et\", \"eu\", \"fa\", \"ff\", \"fi\", \"fr\", \"fy\", \"ga\", \"gd\", \"gl\", \"gn\", \"gu\",\n", + " \"ha\", \"he\", \"hi\", \"hi_rom\", \"hr\", \"ht\", \"hu\", \"hy\", \"id\", \"ig\", \"is\", \"it\", \"ja\", \"jv\", \"ka\",\n", + " \"kk\", \"km\", \"kn\", \"ko\", \"ku\", \"ky\", \"la\", \"lg\", \"li\", \"ln\", \"lo\", \"lt\", \"lv\", \"mg\", \"mk\", \"ml\",\n", + " \"mn\", \"mr\", \"ms\", \"my\", \"my_zaw\", \"ne\", \"nl\", \"no\", \"ns\", \"om\", \"or\", \"pa\", \"pl\", \"ps\", \"pt\",\n", + " \"qu\", \"rm\", \"ro\", \"ru\", \"sa\", \"si\", \"sc\", \"sd\", \"sk\", \"sl\", \"so\", \"sq\", \"sr\", \"ss\", \"su\", \"sv\",\n", + " \"sw\", \"ta\", \"ta_rom\", \"te\", \"te_rom\", \"th\", \"tl\", \"tn\", \"tr\", \"ug\", \"uk\", \"ur\", \"ur_rom\", \"uz\",\n", + " \"vi\", \"wo\", \"xh\", \"yi\", \"yo\", \"zh-Hans\", \"zh-Hant\", \"zu\",\n", + "]\n", + "\n", + "def CC100(root, language_code, use_caching=True):\n", + " if language_code not in VALID_CODES:\n", + " raise ValueError(f\"Invalid language code {language_code}\")\n", + " url = URL % language_code\n", + " if use_caching:\n", + " cache_compressed_dp = HttpReader(cache_compressed_dp).map(itemgetter(0))\n", + " cache_compressed_dp = cache_compressed_dp.end_caching(mode=\"wb\", same_filepath_fn=True)\n", + " cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_path_fn, root))\n", + " cache_decompressed_dp = FileOpener(cache_decompressed_dp).read_from_xz()\n", + " cache_decompressed_dp = cache_decompressed_dp.end_caching(mode=\"wb\")\n", + " data_dp = FileOpener(cache_decompressed_dp)\n", + " else:\n", + " data_dp = HttpReader([url]).read_from_xz()\n", + " units_dp = data_dp.readlines().map(partial(_process_tuple, language_code))\n", + " return units_dp\n" + ] + }, { - "output_type": "execute_result", - "data": { - "text/plain": [ - "('yi', 'קאַטעגאָריע:cs-- – װיקיװערטערבוך')" + "cell_type": "code", + "execution_count": 3, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('en', 'Belmont Estate is on the market for $63 million and boasts roughly 22,000 square feet of luxurious finishes and elaborate architecture on 1.28 acres. Listed on Thursday, the home is being sold by high-end real estate firm Sotheby’s International Realty Canada.')\n", + "('en', '“Within the city we’ve had homes that have sold for $56 million, $33 million, $31 million but this will be the record of the offering price,” listing agent Christa Frosch of Sotheby’s tells BuzzBuzzNews.')\n", + "('en', 'The three-storey home has five bedrooms, twelve bathrooms and an elevator in the west wing. Built to entertain, two main gallery halls can seat up to 100 guests. The Italian-inspired kitchen includes a fireplace and walls and ceilings throughout the home feature murals and artwork. Lavish amenities include an indoor pool and sauna, a six-car garage and a private entrance in-law’s suite.')\n", + "('en', 'Surrounding the property is a Versailles-inspired garden with a variety of trees, plants and an orchard. 
In the spring, over 12,000 flowers bloom in the tiered, three-level garden.')\n", + "('en', 'According to Frosch, the listing has received global attention and, despite being on the market for only 24 hours, buyers are already showing interest.')\n", + "('en', '“We just went to the market yesterday, it’s private through Sotheby’s and we’ve already started to get calls,” says Frosch.')\n", + "('en', '')\n", + "Execution time 0.55 secs\n" + ] + } + ], + "source": [ + "# Sample from multi-gigabyte-size compressed dataset without downloading the whole thing\n", + "# This executes very fast\n", + "import time\n", + "start_time = time.time()\n", + "for i, x in enumerate(CC100(ROOT_DIR, 'en', use_caching=False)):\n", + " print(x)\n", + " if i > 5:\n", + " break\n", + "print(f\"Execution time {(time.time() - start_time):.2f} secs\")" ] - }, - "metadata": {}, - "execution_count": 5 - } - ], - "metadata": { - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 6, - "source": [ - "import itertools\n", - "# Cache two of the datasets. The backend rate-limits connections to 1 per ip, \n", - "# so you can't have more than one dataset running without caching\n", - "\n", - "# If you do \"run all\" this may fail because the previous http connections might still be alive\n", - "\n", - "z1 = CC100(ROOT_DIR, 'li', use_caching=False).cycle()\n", - "z2 = CC100(ROOT_DIR, 'ha', use_caching=True).cycle()\n", - "z3 = CC100(ROOT_DIR, 'yi', use_caching=True).cycle()\n", - "\n", - "z = SampleMultiplexer({z1: 0.7, z2: 0.2, z3: 0.1})\n", - "\n", - "l = list(itertools.islice(z, 0, 500000))\n", - "print(l[0:20])\n", - "\n", - "ratio = sum(1 for k,v in l if k == 'li') / len(l)\n", - "print(f\"Expected ratio: 0.7, actual {ratio}\")\n" - ], - "outputs": [ + }, { - "output_type": "stream", - "name": "stdout", - "text": [ - "[('li', \"Kop van 't Ende - Wikipedia\"), ('li', ''), ('li', \"Coos is 'n in 1853 gestiech graofsjap in Oregon, VS. Coos is verneump nao de Cook-koo-oose, 'n inheims Amerikaans stam, die allewijl neet mie besteit. 
De hoofplaots vaan 't graofsjap is Coquille.\"), ('ha', 'Dangantaka tsakanin kasashen Masar da Turkiya ta yi tsami a cikin yan kwanakin nan, saboda sanin iyakokin da kowanne daga cikin yake mallaka a tekun Mediterranean .'), ('yi', 'קאַטעגאָריע:cs-- – װיקיװערטערבוך'), ('li', \"'t Graofsjap heet 'n totaal oppervlak vaan 4.678 km² boevaan 4.145 km² land is en 533 km² water.\"), ('ha', \"Kamfanin dillancin labaran IRNA na kasar Iran ya nakalto Ahmad Abu-Zaid kakakin ma'aikatar harkokin wajen kasar Masar yarjejeniyar da kasar Masar ta cimma da kasar Cyprus kan iyakokin da kowanne daga cikinsu yake mallaka daga gabacin tekun Mediterranean ta zama doka ce, kuma duk wanda yayi kokarin taka ta Masar zata kalubalance shi.\"), ('ha', 'Abu-Zaid ya kara da cewa yarjejeniyar rabon kan iyaka a cikin tekun Mediterranean , yarjejjeniya ce ta kasa da kasa wacce Majalisar dinkin duniya ta amince da ita.'), ('li', \"Volgens de census vaan 2000 bedroog 't totaol bevolkingsaontal in Coos County 62.779.\"), ('ha', 'Amma ministan harkokin wajen kasar Turkiya Maulud Chavis-Uglu, a ranar litinin da ta gabata ce ya bada sanarwan cewa kasar Turkiya ba ta amince da yarjejeniyar da kasashen Masar ta Cyprus suka cimma kan rabon kan iyaka da kuma amfani da tekun Mediterranean a shekara ta 2013 ba.'), ('li', \"De twie belaankriekste plaotse vaan 't graofsjap zien:\"), ('ha', 'Wani Sabon Sabani Ya Kunno kai Tsakanin Kasashen Masar Da Turkiyya'), ('li', \"Gesjreve in 't Mestreechs\"), ('li', \"Dees pazjena is 't lèts verangerd op 9 mrt 2013, 04:24.\"), ('ha', 'Masar Ta Zargi Mahukuntan Turkiyya Da Kokarin Yin Zagon Kasa Ga Harkar Tattalin Arzikin Kasarta'), ('li', ''), ('li', \"'ne Centimeter (aofkorting: cm) is geliek aon 'ne hoonderdste meter, ofwel 0,01 meter. Dit is weer geliek aon 10 millimeter. 't Voorveugsel centi is aofkomsteg vaan 't Latiense centum, wat hoonderd beteikent. In 't dageleks leve weurt de maot dèks gebruuk: me gebruuk 't veur 't mete vaan liechaamslengde, meubelaofmetinge, kleiding, enz. 
In technische teikeninge gebruuk me liever de millimeter.\"), ('ha', ''), ('li', \"'n Meetlint weurt dèks ouch 'ne centimeter geneump.\"), ('li', \"Gesjreve in 't Mestreechs\")]\n", - "Expected ratio: 0.7, actual 0.699058\n" - ] - } - ], - "metadata": { - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 8, - "source": [ - "next(iter(CC100(ROOT_DIR, 'ha', use_caching=False).lines_to_paragraphs()))" - ], - "outputs": [ + "cell_type": "code", + "execution_count": 5, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "('ha',\n", + " 'Dangantaka tsakanin kasashen Masar da Turkiya ta yi tsami a cikin yan kwanakin nan, saboda sanin iyakokin da kowanne daga cikin yake mallaka a tekun Mediterranean .')" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# cache\n", + "# This cell is very slow to run the first time as it downloads a dataset from a very slow server\n", + "next(iter(CC100(ROOT_DIR, 'ha', use_caching=True)))" + ] + }, { - "output_type": "execute_result", - "data": { - "text/plain": [ - "('ha',\n", - " \"Dangantaka tsakanin kasashen Masar da Turkiya ta yi tsami a cikin yan kwanakin nan, saboda sanin iyakokin da kowanne daga cikin yake mallaka a tekun Mediterranean .\\nKamfanin dillancin labaran IRNA na kasar Iran ya nakalto Ahmad Abu-Zaid kakakin ma'aikatar harkokin wajen kasar Masar yarjejeniyar da kasar Masar ta cimma da kasar Cyprus kan iyakokin da kowanne daga cikinsu yake mallaka daga gabacin tekun Mediterranean ta zama doka ce, kuma duk wanda yayi kokarin taka ta Masar zata kalubalance shi.\\nAbu-Zaid ya kara da cewa yarjejeniyar rabon kan iyaka a cikin tekun Mediterranean , yarjejjeniya ce ta kasa da kasa wacce Majalisar dinkin duniya ta amince da ita.\\nAmma ministan harkokin wajen kasar Turkiya Maulud Chavis-Uglu, a ranar litinin da ta gabata ce ya bada sanarwan cewa kasar Turkiya ba ta amince da yarjejeniyar da kasashen Masar ta Cyprus suka cimma kan rabon kan iyaka da kuma amfani da tekun Mediterranean a shekara ta 2013 ba.\\nWani Sabon Sabani Ya Kunno kai Tsakanin Kasashen Masar Da Turkiyya\\nMasar Ta Zargi Mahukuntan Turkiyya Da Kokarin Yin Zagon Kasa Ga Harkar Tattalin Arzikin Kasarta\")" + "cell_type": "code", + "execution_count": 5, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "('yi', 'קאַטעגאָריע:cs-- – װיקיװערטערבוך')" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# cache\n", + "# This cell is very slow to run the first time as it downloads a dataset from a very slow server\n", + "next(iter(CC100(ROOT_DIR, 'yi', use_caching=True)))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[('li', \"Kop van 't Ende - Wikipedia\"), ('li', ''), ('li', \"Coos is 'n in 1853 gestiech graofsjap in Oregon, VS. Coos is verneump nao de Cook-koo-oose, 'n inheims Amerikaans stam, die allewijl neet mie besteit. 
De hoofplaots vaan 't graofsjap is Coquille.\"), ('ha', 'Dangantaka tsakanin kasashen Masar da Turkiya ta yi tsami a cikin yan kwanakin nan, saboda sanin iyakokin da kowanne daga cikin yake mallaka a tekun Mediterranean .'), ('yi', 'קאַטעגאָריע:cs-- – װיקיװערטערבוך'), ('li', \"'t Graofsjap heet 'n totaal oppervlak vaan 4.678 km² boevaan 4.145 km² land is en 533 km² water.\"), ('ha', \"Kamfanin dillancin labaran IRNA na kasar Iran ya nakalto Ahmad Abu-Zaid kakakin ma'aikatar harkokin wajen kasar Masar yarjejeniyar da kasar Masar ta cimma da kasar Cyprus kan iyakokin da kowanne daga cikinsu yake mallaka daga gabacin tekun Mediterranean ta zama doka ce, kuma duk wanda yayi kokarin taka ta Masar zata kalubalance shi.\"), ('ha', 'Abu-Zaid ya kara da cewa yarjejeniyar rabon kan iyaka a cikin tekun Mediterranean , yarjejjeniya ce ta kasa da kasa wacce Majalisar dinkin duniya ta amince da ita.'), ('li', \"Volgens de census vaan 2000 bedroog 't totaol bevolkingsaontal in Coos County 62.779.\"), ('ha', 'Amma ministan harkokin wajen kasar Turkiya Maulud Chavis-Uglu, a ranar litinin da ta gabata ce ya bada sanarwan cewa kasar Turkiya ba ta amince da yarjejeniyar da kasashen Masar ta Cyprus suka cimma kan rabon kan iyaka da kuma amfani da tekun Mediterranean a shekara ta 2013 ba.'), ('li', \"De twie belaankriekste plaotse vaan 't graofsjap zien:\"), ('ha', 'Wani Sabon Sabani Ya Kunno kai Tsakanin Kasashen Masar Da Turkiyya'), ('li', \"Gesjreve in 't Mestreechs\"), ('li', \"Dees pazjena is 't lèts verangerd op 9 mrt 2013, 04:24.\"), ('ha', 'Masar Ta Zargi Mahukuntan Turkiyya Da Kokarin Yin Zagon Kasa Ga Harkar Tattalin Arzikin Kasarta'), ('li', ''), ('li', \"'ne Centimeter (aofkorting: cm) is geliek aon 'ne hoonderdste meter, ofwel 0,01 meter. Dit is weer geliek aon 10 millimeter. 't Voorveugsel centi is aofkomsteg vaan 't Latiense centum, wat hoonderd beteikent. In 't dageleks leve weurt de maot dèks gebruuk: me gebruuk 't veur 't mete vaan liechaamslengde, meubelaofmetinge, kleiding, enz. In technische teikeninge gebruuk me liever de millimeter.\"), ('ha', ''), ('li', \"'n Meetlint weurt dèks ouch 'ne centimeter geneump.\"), ('li', \"Gesjreve in 't Mestreechs\")]\n", + "Expected ratio: 0.7, actual 0.699058\n" + ] + } + ], + "source": [ + "import itertools\n", + "# Cache two of the datasets. 
The backend rate-limits connections to 1 per ip,\n", + "# so you can't have more than one dataset running without caching\n", + "\n", + "# If you do \"run all\" this may fail because the previous http connections might still be alive\n", + "\n", + "z1 = CC100(ROOT_DIR, 'li', use_caching=False).cycle()\n", + "z2 = CC100(ROOT_DIR, 'ha', use_caching=True).cycle()\n", + "z3 = CC100(ROOT_DIR, 'yi', use_caching=True).cycle()\n", + "\n", + "z = SampleMultiplexer({z1: 0.7, z2: 0.2, z3: 0.1})\n", + "\n", + "l = list(itertools.islice(z, 0, 500000))\n", + "print(l[0:20])\n", + "\n", + "ratio = sum(1 for k,v in l if k == 'li') / len(l)\n", + "print(f\"Expected ratio: 0.7, actual {ratio}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "('ha',\n", + " \"Dangantaka tsakanin kasashen Masar da Turkiya ta yi tsami a cikin yan kwanakin nan, saboda sanin iyakokin da kowanne daga cikin yake mallaka a tekun Mediterranean .\\nKamfanin dillancin labaran IRNA na kasar Iran ya nakalto Ahmad Abu-Zaid kakakin ma'aikatar harkokin wajen kasar Masar yarjejeniyar da kasar Masar ta cimma da kasar Cyprus kan iyakokin da kowanne daga cikinsu yake mallaka daga gabacin tekun Mediterranean ta zama doka ce, kuma duk wanda yayi kokarin taka ta Masar zata kalubalance shi.\\nAbu-Zaid ya kara da cewa yarjejeniyar rabon kan iyaka a cikin tekun Mediterranean , yarjejjeniya ce ta kasa da kasa wacce Majalisar dinkin duniya ta amince da ita.\\nAmma ministan harkokin wajen kasar Turkiya Maulud Chavis-Uglu, a ranar litinin da ta gabata ce ya bada sanarwan cewa kasar Turkiya ba ta amince da yarjejeniyar da kasashen Masar ta Cyprus suka cimma kan rabon kan iyaka da kuma amfani da tekun Mediterranean a shekara ta 2013 ba.\\nWani Sabon Sabani Ya Kunno kai Tsakanin Kasashen Masar Da Turkiyya\\nMasar Ta Zargi Mahukuntan Turkiyya Da Kokarin Yin Zagon Kasa Ga Harkar Tattalin Arzikin Kasarta\")" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "next(iter(CC100(ROOT_DIR, 'ha', use_caching=False).lines_to_paragraphs()))" ] - }, - "metadata": {}, - "execution_count": 8 } - ], - "metadata": { - "pycharm": { - "name": "#%%\n" + ], + "metadata": { + "fileHeader": "", + "fileUid": "6a787fbf-7d19-47c8-ae45-d670856ce03b", + "isAdHoc": false, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "bento_kernel_default" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" } - } - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.5" } - }, - "nbformat": 4, - "nbformat_minor": 5 } diff --git a/packaging/build_wheel.sh b/packaging/build_wheel.sh index 7a965dd02..724612e90 100755 --- a/packaging/build_wheel.sh +++ b/packaging/build_wheel.sh @@ -19,7 +19,6 @@ setup_env pip_install future wheel setup_pip_pytorch_version -git submodule update --init --recursive pip_install -r requirements.txt python setup.py clean # TODO: Add windows support 
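A note on the CC100 notebook above: in the ``use_caching=True`` branch, ``CC100`` builds ``HttpReader(cache_compressed_dp)`` before ``cache_compressed_dp`` is ever assigned, so that path fails with an unbound-variable error as written. The following is a minimal sketch of the presumably intended on-disk caching setup, using only operations that already appear in the notebook (``on_disk_cache``, ``end_caching``, ``HttpReader``, ``FileOpener``, ``read_from_xz``, ``readlines``); the helper names ``_compressed_path_fn``, ``_decompressed_path_fn`` and ``cc100_cached`` are illustrative, not part of the notebook:

    import os
    from functools import partial

    from torchdata.datapipes.iter import FileOpener, HttpReader, IterableWrapper

    def _compressed_path_fn(root, url):
        # Hypothetical helper: cache the .xz archive under `root`, keeping its name.
        return os.path.join(root, os.path.basename(url))

    def _decompressed_path_fn(root, url):
        # Like the notebook's _path_fn, but removes the ".xz" suffix explicitly
        # instead of relying on str.rstrip.
        return os.path.join(root, os.path.basename(url).removesuffix(".xz"))

    def cc100_cached(root, url):
        # 1) Download the compressed archive once and cache it on disk.
        cache_compressed_dp = IterableWrapper([url]).on_disk_cache(
            filepath_fn=partial(_compressed_path_fn, root)
        )
        cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(
            mode="wb", same_filepath_fn=True
        )
        # 2) Decompress it once and cache the plain-text result next to it.
        cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
            filepath_fn=partial(_decompressed_path_fn, root)
        )
        cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_xz()
        cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")
        # 3) Stream lines from the cached text file.
        return FileOpener(cache_decompressed_dp, mode="b").readlines()

With a setup along these lines, the first iteration should download and decompress the archive once, and later iterations should read directly from the cached files under ``root``.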
diff --git a/packaging/env-var-script.txt b/packaging/env-var-script.txt index 25bb523ac..087ba842d 100644 --- a/packaging/env-var-script.txt +++ b/packaging/env-var-script.txt @@ -1,2 +1 @@ -export BUILD_S3="1" export MACOSX_DEPLOYMENT_TARGET="10.13" diff --git a/packaging/torchdata/meta.yaml b/packaging/torchdata/meta.yaml index 2c95d0e0a..3fb5aea27 100644 --- a/packaging/torchdata/meta.yaml +++ b/packaging/torchdata/meta.yaml @@ -10,9 +10,6 @@ requirements: build: - cmake - ninja - # TODO: Enable AWSSDK on windows - # - {{ compiler('c') }} # [win] - # - {{ compiler('cxx') }} # [win] - python - setuptools - cpuonly @@ -30,13 +27,11 @@ build: string: py{{py}} script_env: - BUILD_VERSION - - BUILD_S3 test: imports: - torchdata - - torchdata.dataloader2 - - torchdata.datapipes + - torchdata.stateful_dataloader source_files: - test requires: diff --git a/setup.py b/setup.py index 726b1da9d..b3afd9a16 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ from setuptools import find_packages, setup -from tools.setup_helpers.extension import CMakeBuild, get_ext_modules +from tools.setup_helpers.extension import get_ext_modules ROOT_DIR = Path(__file__).parent.resolve() @@ -29,49 +29,6 @@ RUN_BUILD_DEP = False -def _get_submodule_folders(): - git_modules_path = ROOT_DIR / ".gitmodules" - if not os.path.exists(git_modules_path): - return [] - with open(git_modules_path) as f: - return [ - os.path.join(ROOT_DIR, line.split("=", 1)[1].strip()) - for line in f.readlines() - if line.strip().startswith("path") - ] - - -def _check_submodules(): - def check_for_files(folder, files): - if not any(os.path.exists(os.path.join(folder, f)) for f in files): - print("Could not find any of {} in {}".format(", ".join(files), folder)) - print("Did you run 'git submodule update --init --recursive --jobs 0'?") - sys.exit(1) - - def not_exists_or_empty(folder): - return not os.path.exists(folder) or (os.path.isdir(folder) and len(os.listdir(folder)) == 0) - - if bool(os.getenv("USE_SYSTEM_LIBS", False)): - return - folders = _get_submodule_folders() - # If none of the submodule folders exists, try to initialize them - if all(not_exists_or_empty(folder) for folder in folders): - try: - import time - - print(" --- Trying to initialize submodules") - start = time.time() - subprocess.check_call(["git", "submodule", "update", "--init", "--recursive"], cwd=ROOT_DIR) - end = time.time() - print(f" --- Submodule initialization took {end - start:.2f} sec") - except Exception: - print(" --- Submodule initalization failed") - print("Please run:\n\tgit submodule update --init --recursive --jobs 0") - sys.exit(1) - for folder in folders: - check_for_files(folder, ["CMakeLists.txt", "Makefile", "setup.py", "LICENSE", "LICENSE.md", "LICENSE.txt"]) - - def _get_version(): with open(os.path.join(ROOT_DIR, "version.txt")) as f: version = f.readline().strip() @@ -145,13 +102,6 @@ def remove_extension(pattern): _export_version(VERSION, SHA) print("-- Building version " + VERSION) - - if RUN_BUILD_DEP: - from tools.gen_pyi import gen_pyi - - _check_submodules() - gen_pyi() - setup( # Metadata name="torchdata", @@ -178,19 +128,10 @@ def remove_extension(pattern): "Programming Language :: Python :: Implementation :: CPython", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], - package_data={ - "torchdata": [ - "datapipes/iter/*.pyi", - "datapipes/map/*.pyi", - ], - }, # Package Info - packages=find_packages(exclude=["test*", "examples*", "tools*", "torchdata.csrc*", "build*"]), + packages=find_packages(exclude=["test*", 
"examples*", "tools*", "build*"]), zip_safe=False, # C++ Extension Modules ext_modules=get_ext_modules(), - cmdclass={ - "build_ext": CMakeBuild, - "clean": clean, - }, + cmdclass={"clean": clean}, ) diff --git a/test/_utils/_common_utils_for_test.py b/test/_utils/_common_utils_for_test.py index fea9c82fc..f328e22b1 100644 --- a/test/_utils/_common_utils_for_test.py +++ b/test/_utils/_common_utils_for_test.py @@ -11,9 +11,6 @@ import tempfile from typing import List, Tuple, TypeVar -from torchdata.datapipes.iter import IterDataPipe - - T_co = TypeVar("T_co", covariant=True) @@ -24,31 +21,10 @@ IS_M1 = IS_MACOS and "arm" in platform.platform() -class IDP_NoLen(IterDataPipe): - def __init__(self, input_dp) -> None: - super().__init__() - self.input_dp = input_dp - - def __iter__(self): - yield from self.input_dp - - def get_name(path_and_stream): return os.path.basename(path_and_stream[0]), path_and_stream[1] -# Given a DataPipe and integer n, iterate the DataPipe for n elements and store the elements into a list -# Then, reset the DataPipe and return a tuple of two lists -# 1. A list of elements yielded before the reset -# 2. A list of all elements of the DataPipe after the reset -def reset_after_n_next_calls(datapipe: IterDataPipe[T_co], n: int) -> Tuple[List[T_co], List[T_co]]: - it = iter(datapipe) - res_before_reset = [] - for _ in range(n): - res_before_reset.append(next(it)) - return res_before_reset, list(datapipe) - - def create_temp_dir(dir=None): # The temp dir and files within it will be released and deleted in tearDown(). # Adding `noqa: P201` to avoid mypy's warning on not releasing the dir handle within this function. diff --git a/test/bin/elastic_training.py b/test/bin/elastic_training.py index b1f9569b0..536c1005b 100644 --- a/test/bin/elastic_training.py +++ b/test/bin/elastic_training.py @@ -12,24 +12,18 @@ from torch.distributed.elastic.multiprocessing.errors import record from torch.utils.data import DataLoader -from torchdata.dataloader2 import DataLoader2, DistributedReadingService from torchdata.datapipes.iter import IterableWrapper -def _get_dataloader(data_length: int, dl2: bool, shuffle: bool, rs=None): +def _get_dataloader(data_length: int, bool, shuffle: bool, rs=None): data_source = IterableWrapper(list(range(data_length))) dp = data_source.sharding_filter() if shuffle: dp = dp.shuffle() - if dl2: - if rs is None: - rs = DistributedReadingService() - dl = DataLoader2(dp, reading_service=rs) - else: - dp = dp.fullsync() - dl = DataLoader(dp) + dp = dp.fullsync() + dl = DataLoader(dp) return dl @@ -76,10 +70,6 @@ def main(backend, dl2): assert len(results[0]) == len(results[2]) assert results[0] != results[2] - # Properly shutdown the process group - if isinstance(dl, DataLoader2): - dl.shutdown() - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Elastic Training") @@ -99,8 +89,6 @@ def main(backend, dl2): elif args.mpi: backend = "mpi" - dl2 = True - if args.dl1: - dl2 = False + dl2 = False main(backend, dl2) diff --git a/test/dataloader2/test_dataloader2.py b/test/dataloader2/test_dataloader2.py deleted file mode 100644 index a438b1ead..000000000 --- a/test/dataloader2/test_dataloader2.py +++ /dev/null @@ -1,815 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import multiprocessing as mp -import os -import pickle -import queue -import random -import socket -import unittest - -from unittest import TestCase - -import numpy as np - -import torch -import torch.distributed as dist -from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize - -from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES - -from torchdata.dataloader2 import ( - DataLoader2, - DistributedReadingService, - InProcessReadingService, - MultiProcessingReadingService, - ReadingServiceInterface, - SequentialReadingService, -) -from torchdata.dataloader2.dataloader2 import READING_SERVICE_STATE_KEY_NAME, SERIALIZED_DATAPIPE_KEY_NAME - -from torchdata.dataloader2.graph import DataPipe, list_dps, replace_dp, set_datapipes_seed, traverse_dps -from torchdata.dataloader2.random import SeedGenerator -from torchdata.datapipes.iter import IterableWrapper, IterDataPipe, ShardingRoundRobinDispatcher - -try: - import dill - - # XXX: By default, dill writes the Pickler dispatch table to inject its - # own logic there. This globally affects the behavior of the standard library - # pickler for any user who transitively depends on this module! - # Undo this extension to avoid altering the behavior of the pickler globally. - dill.extend(use_dill=False) - HAS_DILL = True -except ImportError: - HAS_DILL = False - -skipIfNoDill = unittest.skipIf(not HAS_DILL, "no dill") - -if dist.is_available(): - HAS_DIST = True -else: - HAS_DIST = False - -skipIfNoDistributed = unittest.skipIf(not HAS_DIST, "no torch.distributed") - -TEST_WITH_TSAN = os.getenv("PYTORCH_TEST_WITH_TSAN", "0") == "1" - -mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods()) - -EXCEPTION_ITERATION_NUM = 7 - - -class _ReadingServiceWrapper: - def __init__(self, dp): - self.dp = dp - - def __iter__(self): - self.it = iter(self.dp) - return self - - def __next__(self): - return next(self.it) - - @staticmethod - def return_one(): - return 1 - - -class TestReadingService(ReadingServiceInterface): - def initialize(self, dp: DataPipe) -> DataPipe: - return _ReadingServiceWrapper(dp) # type: ignore[return-value] - - -class DataLoader2Test(TestCase): - def test_dataloader2(self) -> None: - test_data_pipe = IterableWrapper(range(3)) - data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe) - - expected_batch = 0 - for batch in iter(data_loader): - self.assertEqual(batch, expected_batch) - expected_batch += 1 - - def test_dataloader2_shutdown(self) -> None: - test_data_pipe = IterableWrapper(range(3)) - data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe) - data_loader.shutdown() - - def test_dataloader2_state_dict(self) -> None: - test_data_pipe = IterableWrapper(range(3)) - data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe) - - state = data_loader.state_dict() - self.assertIsNotNone(state) - self.assertIsNotNone(state[SERIALIZED_DATAPIPE_KEY_NAME]) - self.assertIsNone(state[READING_SERVICE_STATE_KEY_NAME]) - data_loader.shutdown() - - def test_dataloader2_reading_service(self) -> None: - test_data_pipe = IterableWrapper(range(3)) - reading_service = TestReadingService() - data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service) - - expected_batch = 0 - for batch in iter(data_loader): - self.assertEqual(batch, expected_batch) - expected_batch += 1 - - def test_dataloader2_load_state_dict(self) -> None: - test_data_pipe = IterableWrapper(range(3)) - reading_service = TestReadingService() - 
data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service) - - batch = next(iter(data_loader)) - self.assertEqual(batch, 0) - - state = data_loader.state_dict() - self.assertIsNotNone(state) - self.assertIsNotNone(state[SERIALIZED_DATAPIPE_KEY_NAME]) - self.assertIsNone(state[READING_SERVICE_STATE_KEY_NAME]) - data_loader.shutdown() - - restored_data_loader: DataLoader2 = DataLoader2(datapipe=None, reading_service=reading_service) - restored_data_loader.load_state_dict(state) - new_state = restored_data_loader.state_dict() - self.assertDictEqual(state, new_state) - - restored_data_loader_datapipe = restored_data_loader.datapipe - deserialized_datapipe = pickle.loads(state[SERIALIZED_DATAPIPE_KEY_NAME]) - for batch_1, batch_2 in zip(restored_data_loader_datapipe, deserialized_datapipe): - self.assertEqual(batch_1, batch_2) - - self.assertEqual( - restored_data_loader.reading_service_state, - state[READING_SERVICE_STATE_KEY_NAME], - ) - - restored_data_loader.shutdown() - - def test_dataloader2_iterates_correctly(self) -> None: - test_data_pipe = IterableWrapper(range(10)).sharding_filter() - reading_services = [ - None, - TestReadingService(), - MultiProcessingReadingService(num_workers=4), - MultiProcessingReadingService(num_workers=4, worker_prefetch_cnt=0), - ] - for reading_service in reading_services: - data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service) - self.assertEqual(list(range(10)), list(data_loader)) - self.assertEqual(list(range(10)), list(data_loader)) - self.assertEqual(list(range(10)), list(data_loader)) - actual = [] - for i in data_loader: - actual.append(i) - self.assertEqual(list(range(10)), actual) - actual = [] - for i in data_loader: - actual.append(i) - self.assertEqual(list(range(10)), actual) - - def test_dataloader2_reset(self) -> None: - test_data_pipe = IterableWrapper(range(10)) - reading_services = [None, TestReadingService(), MultiProcessingReadingService(num_workers=1)] - - for reading_service in reading_services: - data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=reading_service) - - # Functional Test: Ensure multiple sequential reads of DL2 is possible - self.assertEqual(list(range(10)), list(data_loader)) - self.assertEqual(list(range(10)), list(data_loader)) - self.assertEqual(list(range(10)), list(data_loader)) - - # Functional Test: Ensure that the creation of a new iterator invalidates the old one - it1 = iter(data_loader) - self.assertEqual(0, next(it1)) - self.assertEqual(1, next(it1)) - it2 = iter(data_loader) - self.assertEqual(0, next(it2)) - self.assertEqual(1, next(it2)) - with self.assertRaisesRegex(RuntimeError, "iterator has been invalidated"): - next(it1) - self.assertEqual(list(range(2, 10)), list(it2)) - - def test_dataloader2_delegate_attribute(self) -> None: - test_data_pipe = IterableWrapper(range(10)) - data_loader: DataLoader2 = DataLoader2(datapipe=test_data_pipe, reading_service=TestReadingService()) - - # Functional Test: Ensure multiple sequential reads of DL2 is possible - self.assertEqual(list(range(10)), list(data_loader)) - self.assertEqual(list(range(10)), list(data_loader)) - - # Functional Test: Ensure that attribute/method of `dataloader._datapipe_iter` can be used - it = iter(data_loader) - self.assertEqual(1, it.return_one()) # type: ignore[attr-defined] - - -class DataLoader2ConsistencyTest(TestCase): - r""" - These tests ensure that the behaviors of `DataLoader2` are consistent across `ReadingServices` 
and potentially - with `DataLoaderV1`. - """ - - @staticmethod - def _get_no_reading_service(): - return None - - @staticmethod - def _get_mp_reading_service(): - return MultiProcessingReadingService(num_workers=2) - - @staticmethod - def _get_in_process_reading_service(): - return InProcessReadingService() - - def _collect_data(self, datapipe, reading_service_gen): - dl: DataLoader2 = DataLoader2(datapipe, reading_service=reading_service_gen()) - result = [] - # Testing how RS handles partial reading and reiterations - for row, _ in zip(dl, range(10)): - result.append(row) - for row in dl: - result.append(row) - dl.shutdown() - return result - - @staticmethod - def _no_op(x): - return x - - def test_dataloader2_batch_collate(self) -> None: - dp: IterDataPipe = IterableWrapper(range(100)).batch(2).sharding_filter().collate(self._no_op) # type: ignore[assignment] - expected = self._collect_data(dp, reading_service_gen=self._get_no_reading_service) - - reading_service_generators = ( - self._get_mp_reading_service, - self._get_in_process_reading_service, - ) - for reading_service_gen in reading_service_generators: - actual = self._collect_data(dp, reading_service_gen=reading_service_gen) - # TODO(588): This comparison only indicates that somethings is broken and not helping with debug - self.assertEqual(expected, actual, reading_service_gen) - - def test_dataloader2_shuffle(self) -> None: - # TODO(589): Add shuffle test - pass - - -def _x_mult_2(d): - return d * 2 - - -class NonReplicableDataPipe(IterDataPipe): - def __init__(self, datapipe): - self.datapipe = datapipe - - def __iter__(self): - yield from self.datapipe - - def is_replicable(self): - return False - - -class _CustomException(Exception): - pass - - -class MakeMistakeDataPipe(IterDataPipe): - def __init__(self, source_datapipe, exc_iteration=EXCEPTION_ITERATION_NUM): - self.source_datapipe = source_datapipe - self.exc_iteration = exc_iteration - - def __iter__(self): - for i, x in enumerate(self.source_datapipe): - if i == self.exc_iteration: - raise _CustomException("oops") - yield x - - -class MultiProcessingReadingServiceTest(TestCase): - @staticmethod - def _worker_init_fn(datapipe, worker_info): - datapipe = datapipe.sharding_filter() - torch.utils.data.graph_settings.apply_sharding( - datapipe, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING - ) - return datapipe - - @staticmethod - def _worker_reset_fn(datapipe, worker_info, worker_seed_generator: SeedGenerator): - graph = traverse_dps(datapipe) - dps = list_dps(graph) - worker_seed_generator.seed(123) - set_datapipes_seed(dps, seed_generator=worker_seed_generator, distributed_shared=True) - return datapipe - - @mp_ctx_parametrize - def test_worker_fns(self, ctx): - dp: IterDataPipe = IterableWrapper(range(100)).batch(2).shuffle() - - rs = MultiProcessingReadingService( - num_workers=2, - multiprocessing_context=ctx, - worker_init_fn=self._worker_init_fn, - worker_reset_fn=self._worker_reset_fn, - ) - dl = DataLoader2(dp, reading_service=rs) - - res1 = list(dl) - res2 = list(dl) - - # Test worker_init_fn to set sharding - def _expand_fn(res): - result = [] - for batch in res: - result.extend(batch) - return result - - exp = list(range(100)) - self.assertEqual(sorted(_expand_fn(res1)), exp) - self.assertEqual(sorted(_expand_fn(res2)), exp) - - # Test worker_reset_fn to set the same random seed across epoches - self.assertEqual(res1, res2) - - @mp_ctx_parametrize - def test_single_branch_non_replicable(self, ctx): - r""" - For single 
branch pipeline with a non-replicable DataPipe, all ``sharding_filters`` - in the pipeline become non-replicable. - """ - - def _make_dp(): - single_br_dp = IterableWrapper(list(range(10))).shuffle() - map_dp = single_br_dp.map(_x_mult_2) - end_dp = map_dp.map(_x_mult_2).shuffle() - return single_br_dp, map_dp, end_dp - - def _assert_deterministic_dl_res(dl, exp): - torch.manual_seed(123) - res = list(dl) - self.assertEqual(sorted(res), exp) - # Second epoch - torch.manual_seed(123) - self.assertEqual(list(dl), res) - # Different seed - torch.manual_seed(321) - self.assertNotEqual(list(dl), res) - # Properly shutdown - dl.shutdown() - - # By-default, all replicable - single_br_dp, _, end_dp = _make_dp() - graph = traverse_dps(end_dp) - sf_dp = single_br_dp.sharding_filter() - replace_dp(graph, single_br_dp, sf_dp) - dl = DataLoader2( - end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) - ) - # Determinism and dynamic sharding - # _assert_deterministic_dl_res(dl, [i * 4 for i in range(10)]) - - # Non-replicable before sharding_filter - # shuffle in dispatch process - single_br_dp, map_dp, end_dp = _make_dp() - graph = traverse_dps(end_dp) - round_robin_dispatcher = ShardingRoundRobinDispatcher(single_br_dp, SHARDING_PRIORITIES.MULTIPROCESSING) - replace_dp(graph, single_br_dp, round_robin_dispatcher) - sf_dp = map_dp.sharding_filter() - replace_dp(graph, map_dp, sf_dp) - dl = DataLoader2( - end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) - ) - # Determinism for non-replicable pipeline - _assert_deterministic_dl_res(dl, [i * 4 for i in range(10)]) - - # Non-replicable after sharding_filter - # shuffle in dispatch process - single_br_dp, map_dp, end_dp = _make_dp() - graph = traverse_dps(end_dp) - sf_dp = single_br_dp.sharding_filter() - replace_dp(graph, single_br_dp, sf_dp) - round_robin_dispatcher = ShardingRoundRobinDispatcher(map_dp, SHARDING_PRIORITIES.MULTIPROCESSING) - replace_dp(graph, map_dp, round_robin_dispatcher) - dl = DataLoader2( - end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) - ) - # Determinism for non-replicable pipeline - _assert_deterministic_dl_res(dl, [i * 4 for i in range(10)]) - - @mp_ctx_parametrize - def test_multi_branch_non_replicable(self, ctx) -> None: - r""" - For multi-branch pipeline with a non-replicable DataPipe on one branch, - all ``sharding_filter`` on the other branches should remain replicable. 
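Both the single-branch test above and the multi-branch test here lean on the same graph-surgery idiom from `torchdata.dataloader2.graph`, which this diff removes together with the rest of the DataLoader2 prototype. A minimal sketch of that idiom follows, using only calls that appear in the deleted tests; `source`, `mapped`, `end_dp` and `_double` are illustrative names, and the pipeline is never actually executed here:

```python
# Sketch of the traverse/replace idiom used by the deleted branch tests.
# It targets the torchdata.dataloader2 prototype that this diff removes.
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2.graph import replace_dp, traverse_dps
from torchdata.datapipes.iter import IterableWrapper, ShardingRoundRobinDispatcher


def _double(x):
    return x * 2


source = IterableWrapper(range(10)).shuffle()  # branch to be rewritten
mapped = source.map(_double)                   # downstream of that branch
end_dp = mapped.shuffle()                      # terminal DataPipe handed to DataLoader2

graph = traverse_dps(end_dp)                   # capture the pipeline as a graph

# Replicable variant: every worker receives a sharded copy of `source`.
replace_dp(graph, source, source.sharding_filter())

# Non-replicable variant, as exercised above: keep `mapped` in the dispatching
# process and round-robin its output across workers instead.
# replace_dp(graph, mapped,
#            ShardingRoundRobinDispatcher(mapped, SHARDING_PRIORITIES.MULTIPROCESSING))
```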
- """ - - def _make_dp(): - branch1_dp = IterableWrapper(list(range(10))).shuffle() - branch2_dp = IterableWrapper(list(range(10))).shuffle() - map_dp = branch1_dp.map(_x_mult_2) - end_dp = map_dp.zip(branch2_dp) - return branch1_dp, map_dp, branch2_dp, end_dp - - def _assert_deterministic_dl_res(dl, exp1, exp2): - torch.manual_seed(123) - res = list(dl) - res1, res2 = list(zip(*res)) - self.assertEqual(sorted(res1), exp1) - self.assertEqual(sorted(res2), exp2) - # Second epoch - torch.manual_seed(123) - self.assertEqual(list(dl), res) - # Different seed - torch.manual_seed(321) - self.assertNotEqual(list(dl), res) - # Properly shutdown - dl.shutdown() - - # By-default, all replicable - branch1_dp, _, branch2_dp, end_dp = _make_dp() - graph = traverse_dps(end_dp) - sf1_dp = branch1_dp.sharding_filter() - sf2_dp = branch2_dp.sharding_filter() - replace_dp(graph, branch1_dp, sf1_dp) - replace_dp(graph, branch2_dp, sf2_dp) - dl = DataLoader2( - end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) - ) - # Determinism and dynamic sharding - _assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10))) - - # Non-replicable on one branch - # shuffle in dispatch process - branch1_dp, _, branch2_dp, end_dp = _make_dp() - graph = traverse_dps(end_dp) - non_replicable_dp = ShardingRoundRobinDispatcher(branch1_dp, SHARDING_PRIORITIES.MULTIPROCESSING) - replace_dp(graph, branch1_dp, non_replicable_dp) - # The other branch should has a sharding_filter to make data even - sf_dp = branch2_dp.sharding_filter() - replace_dp(graph, branch2_dp, sf_dp) - dl = DataLoader2( - end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) - ) - # Determinism for non-replicable pipeline - _assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10))) - - # Non-replicable on both branches - # shuffle in dispatch process - branch1_dp, _, branch2_dp, end_dp = _make_dp() - graph = traverse_dps(end_dp) - non_replicable_dp1 = ShardingRoundRobinDispatcher(branch1_dp, SHARDING_PRIORITIES.MULTIPROCESSING) - replace_dp(graph, branch1_dp, non_replicable_dp1) - non_replicable_dp2 = ShardingRoundRobinDispatcher(branch2_dp, SHARDING_PRIORITIES.MULTIPROCESSING) - replace_dp(graph, branch2_dp, non_replicable_dp2) - dl = DataLoader2( - end_dp, reading_service=MultiProcessingReadingService(num_workers=2, multiprocessing_context=ctx) - ) - # Determinism for non-replicable pipeline - _assert_deterministic_dl_res(dl, [i * 2 for i in range(10)], list(range(10))) - - @mp_ctx_parametrize - def test_multi_worker_determinism(self, ctx): - dp: IterDataPipe = IterableWrapper(range(100)) - dp = dp.shuffle().sharding_filter() - dp = dp.batch(2) - - rs = MultiProcessingReadingService( - num_workers=2, - multiprocessing_context=ctx, - ) - dl = DataLoader2(dp, reading_service=rs) - - torch.manual_seed(123) - res = list(dl) + list(dl) - - torch.manual_seed(123) - self.assertEqual(res, list(dl) + list(dl)) - - torch.manual_seed(321) - self.assertNotEqual(res, list(dl) + list(dl)) - - # Using seed API for DataLoader2 - dl.seed(123) - res = list(dl) + list(dl) - - dl.seed(123) - self.assertEqual(res, list(dl) + list(dl)) - - dl.seed(321) - self.assertNotEqual(res, list(dl) + list(dl)) - - @mp_ctx_parametrize - def test_dispatching_worker_determinism(self, ctx): - dp: IterDataPipe = IterableWrapper(range(101)) - dp = dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING) - dp = dp.batch(2) - - rs = 
MultiProcessingReadingService( - num_workers=2, - multiprocessing_context=ctx, - ) - dl = DataLoader2(dp, reading_service=rs) - - torch.manual_seed(123) - res = list(dl) + list(dl) - - torch.manual_seed(123) - self.assertEqual(res, list(dl) + list(dl)) - - torch.manual_seed(321) - self.assertNotEqual(res, list(dl) + list(dl)) - - # Using seed API for DataLoader2 - dl.seed(123) - res = list(dl) + list(dl) - - dl.seed(123) - self.assertEqual(res, list(dl) + list(dl)) - - dl.seed(321) - self.assertNotEqual(res, list(dl) + list(dl)) - - @mp_ctx_parametrize - def test_non_replicable_datapipe(self, ctx) -> None: - r""" - For the pipeline with non-replicable DataPipe, make sure - the DataPipe remains in the main process. - """ - dp: IterDataPipe = IterableWrapper(range(100)) - dp = dp.shuffle().sharding_filter() - dp = dp.batch(2) - non_rep_dp = NonReplicableDataPipe(dp) - - rs = MultiProcessingReadingService( - num_workers=2, - multiprocessing_context=ctx, - ) - dl = DataLoader2(non_rep_dp, reading_service=rs) - - torch.manual_seed(123) - it = iter(dl) - # Validate NonReplicableDataPipe still in the main process - non_rep_dp = dl.reading_service._end_datapipe - self.assertEqual(type(non_rep_dp), NonReplicableDataPipe) - - res = list(it) + list(dl) - - torch.manual_seed(123) - self.assertEqual(res, list(dl) + list(dl)) - - torch.manual_seed(321) - self.assertNotEqual(res, list(dl) + list(dl)) - - @parametrize("num_workers", [1, 3]) - @parametrize("worker_prefetch_cnt", [0, 5, 10]) - def test_worker_exception_raised(self, num_workers, worker_prefetch_cnt): - dp = IterableWrapper(range(100)).sharding_filter() - dp = MakeMistakeDataPipe(dp) - rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt) - dl = DataLoader2(dp, reading_service=rs) - it = iter(dl) - for _ in range(EXCEPTION_ITERATION_NUM * num_workers): - next(it) - with self.assertRaises(_CustomException) as cm: - next(it) - exc_msg = str(cm.exception) - self.assertTrue("Caught _CustomException in worker process 0" in exc_msg) - self.assertTrue("Original Traceback" in exc_msg) - self.assertTrue("_CustomException: oops" in exc_msg) - - @parametrize("num_workers", [1, 3]) - @parametrize("worker_prefetch_cnt", [0, 5, 10]) - def test_dispatching_exception_raised(self, num_workers, worker_prefetch_cnt): - dp = IterableWrapper(range(100)) - dp = MakeMistakeDataPipe(dp) - dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING) - dp = dp.map(_x_mult_2) - rs = MultiProcessingReadingService(num_workers=num_workers, worker_prefetch_cnt=worker_prefetch_cnt) - dl = DataLoader2(dp, reading_service=rs) - it = iter(dl) - for _ in range(EXCEPTION_ITERATION_NUM): - next(it) - with self.assertRaises(_CustomException) as cm: - next(it) - exc_msg = str(cm.exception) - self.assertTrue("Caught _CustomException in dispatching process" in exc_msg) - self.assertTrue("Original Traceback" in exc_msg) - self.assertTrue("_CustomException: oops" in exc_msg) - - -TEST_MASTER_ADDR = "127.0.0.1" -DEFAULT_WORLD_SIZE = 2 - - -def _get_open_port(): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - port = s.getsockname()[1] - s.close() - return str(port) - - -class TerminateSignal: - pass - - -def _launch_distributed_training(world_size, *args, fn): - os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR - os.environ["MASTER_PORT"] = _get_open_port() - ctx = mp.get_context("spawn") - q = ctx.Queue() - ps = [] - for rank in range(world_size): - p = ctx.Process( - target=fn, - args=( - rank, - 
world_size, - q, - *args, - ), - ) - p.start() - ps.append(p) - res = [] - while True: - try: - d = q.get() - if isinstance(d, TerminateSignal): - break - res.append(d) - except queue.Empty: - continue - for p in ps: - p.join() - return res - - -def _dist_one_epoch(dl): - res = [] - for d in dl: - res.append(d) - # Simulate training synchronization - dist.barrier() - return res - - -def _finalize_distributed_queue(rank, q): - r""" - Synchronize all distributed processes to guarantee all data have been put into - the Multiprocessing Queue. - """ - pg = dist.new_group(backend="gloo") - end_tensor = torch.tensor([rank], dtype=torch.int64) - dist.all_reduce(end_tensor, group=pg) - if rank == 0: - q.put(TerminateSignal()) - - dist.destroy_process_group(pg) - - -def _random_fn(data): - r""" - Used to validate the randomness of subprocess-local RNGs are set deterministically. - """ - py_random_num = random.randint(0, 2 ** 32) - np_random_num = np.random.randint(0, 2 ** 32) - torch_random_num = torch.randint(0, 2 ** 32, size=[]).item() - return (data, py_random_num, np_random_num, torch_random_num) - - -def _dist_training_fn(rank, world_size, q, dp_fn, rs_fn, num_workers, ctx): - # Use gloo - dist.init_process_group("gloo", rank=rank, world_size=world_size) - - # Uneven shards - data_length = world_size * num_workers * 10 + 1 - dp = dp_fn(data_length) - rs = rs_fn(num_workers, ctx) - dl = DataLoader2(dp, reading_service=rs) - - # No seed - res = _dist_one_epoch(dl) - q.put((0, rank, res)) - - # Shuffle with seed - for epoch in range(2): - dl.seed(123) - res = _dist_one_epoch(dl) - q.put((epoch + 1, rank, res)) - - # Different seed - dl.seed(321) - res = _dist_one_epoch(dl) - q.put((3, rank, res)) - - _finalize_distributed_queue(rank, q) - - dl.shutdown() - - -@skipIfNoDistributed -@unittest.skipIf(IS_WINDOWS, "Remove when https://github.com/pytorch/data/issues/857 is fixed") -class SequentialReadingServiceTest(TestCase): - @staticmethod - def _make_dp(data_length): - data_source = IterableWrapper(list(range(data_length))) - dp = data_source.shuffle().sharding_filter().map(_random_fn) - return dp - - @staticmethod - def _make_dispatching_dp(data_length): - data_source = IterableWrapper(list(range(data_length))) - dp = data_source.shuffle().sharding_filter() - dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING).map(_random_fn) - return dp - - @staticmethod - def _make_rs(num_workers, ctx): - mp_rs = MultiProcessingReadingService( - num_workers=num_workers, - multiprocessing_context=ctx, - ) - dist_rs = DistributedReadingService() - rs = SequentialReadingService(dist_rs, mp_rs) - return rs - - @mp_ctx_parametrize - def test_sequential_reading_service_normal_dp(self, ctx): - world_size = DEFAULT_WORLD_SIZE - num_workers = 2 - res = _launch_distributed_training( - world_size, - SequentialReadingServiceTest._make_dp, - SequentialReadingServiceTest._make_rs, - num_workers, - ctx, - fn=_dist_training_fn, - ) - result = ({}, {}, {}, {}) - for epoch, rank, r in res: - d, *ran_nums = list(zip(*r)) - result[epoch][rank] = (d, ran_nums) - - # Guarantee the same length per rank - for rr in result: - exp_len = num_workers * 10 - for _, (d, _) in rr.items(): - self.assertEqual(len(d), exp_len) - - # Same seed generate the same order of data and the same random state - self.assertEqual(result[1], result[2]) - - # Different seeds - for rank in range(world_size): - # Different shuffle order - self.assertNotEqual(result[1][rank][0], result[3][rank][0]) - # Different subprocess-local random 
state - self.assertNotEqual(result[1][rank][1], result[3][rank][1]) - - @mp_ctx_parametrize - def test_sequential_reading_service_dispatching_dp(self, ctx): - world_size = DEFAULT_WORLD_SIZE - num_workers = 2 - res = _launch_distributed_training( - world_size, - SequentialReadingServiceTest._make_dispatching_dp, - SequentialReadingServiceTest._make_rs, - num_workers, - ctx, - fn=_dist_training_fn, - ) - result = ({}, {}, {}, {}) - for epoch, rank, r in res: - d, *ran_nums = list(zip(*r)) - result[epoch][rank] = (d, ran_nums) - - # Guarantee the same length per rank - for rr in result: - exp_len = num_workers * 10 - for _, (d, _) in rr.items(): - self.assertEqual(len(d), exp_len) - - # Same seed generate the same order of data and the same random state - self.assertEqual(result[1], result[2]) - - # Different seeds - for rank in range(world_size): - # Different shuffle order - self.assertNotEqual(result[1][rank][0], result[3][rank][0]) - # Different subprocess-local random state - self.assertNotEqual(result[1][rank][1], result[3][rank][1]) - - -instantiate_parametrized_tests(MultiProcessingReadingServiceTest) -instantiate_parametrized_tests(SequentialReadingServiceTest) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/dataloader2/test_mprs.py b/test/dataloader2/test_mprs.py deleted file mode 100644 index cbd0f4b27..000000000 --- a/test/dataloader2/test_mprs.py +++ /dev/null @@ -1,546 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import multiprocessing as mp -import unittest -from unittest import TestCase - -from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize, subtest - -from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES - -from torchdata.dataloader2 import ( - DataLoader2, - DataLoader2Iterator, - InProcessReadingService, - MultiProcessingReadingService, -) -from torchdata.datapipes.iter import IterableWrapper, IterDataPipe - - -def _add_one(x: int) -> int: - return x + 1 - - -# Test DataPipes -n_elements = 10 -dp1 = IterableWrapper(range(n_elements)).shuffle().sharding_filter() -double_pause_dp = dp1.prefetch().prefetch() -test_dps = [dp1, double_pause_dp] - - -mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods()) -dp_parametrize = parametrize("dp", test_dps) - - -class TestInProcessReadingService(TestCase): - r""" - This tests specific functionalities of InProcessReadingService, notably - `pause`, `resume`, `snapshot`. 
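The `limit`/`resume` tests in this deleted file all exercise one mini-epoch pattern on `DataLoader2Iterator`. A condensed sketch of that pattern, written against the prototype `DataLoader2` API this diff retires, so it documents the removed behaviour rather than a supported interface; the counts mirror `n_elements = 10` and a limit of 3:

```python
# Mini-epoch pattern from the deleted limit/resume tests (DataLoader2 prototype).
from torchdata.dataloader2 import DataLoader2, InProcessReadingService
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(10)).shuffle().sharding_filter()
dl = DataLoader2(dp, reading_service=InProcessReadingService())

it = iter(dl)
it.limit(3)              # at most 3 elements may be yielded before StopIteration
first_chunk = list(it)   # 3 elements

it.resume()              # the limit persists into the next mini-epoch
second_chunk = list(it)  # 3 more elements

it.limit(None)           # clear the limit ...
it.resume()
rest = list(it)          # ... and drain the remaining 4 elements

dl.shutdown()
```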
- """ - - @dp_parametrize - def test_reading_service_pause_resume(self, dp) -> None: - - # Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline - # properly pauses and resumes - rs1 = InProcessReadingService() - dl1: DataLoader2 = DataLoader2(dp, reading_service=rs1) - res = [] - for i, x in enumerate(dl1): - res.append(x) - if i in {2, n_elements - 2}: - dl1._pause() - dl1._resume() - - self.assertEqual(list(range(n_elements)), sorted(res)) - dl1.shutdown() - - rs2 = InProcessReadingService(5) - dl2: DataLoader2 = DataLoader2(dp, reading_service=rs2) - res = [] - for i, x in enumerate(dl2): - res.append(x) - if i in {2, n_elements - 2}: - dl2._pause() - dl2._resume() - - self.assertEqual(list(range(n_elements)), sorted(res)) - dl2.shutdown() - - @dp_parametrize - def test_reading_service_pause_stop_yield(self, dp) -> None: - - # Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called - rs = InProcessReadingService(5) - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) - res = [] - for i, x in enumerate(dl): - res.append(x) - if i in {2}: - dl._pause() - self.assertEqual(3, len(res)) - dl.shutdown() - - @dp_parametrize - def test_reading_service_limit(self, dp) -> None: - - rs = InProcessReadingService(5) - - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) - res = [] - cumulative_res = [] - n_limit = 3 - - it: DataLoader2Iterator = iter(dl) - it.limit(n_limit) - for x in it: - res.append(x) - # Functional Test: Verify that the number of elements yielded equals to the specified limit - self.assertEqual(n_limit, len(res)) # 3 - cumulative_res.extend(res) - - # Functional Test: Calling `next` after `limit` will trigger `StopIteration` - with self.assertRaises(StopIteration): - next(it) - - # Functional Test: Verify that `limit` persists without the need to set it again - it.resume() - res = [] - for x in it: - res.append(x) - self.assertEqual(n_limit, len(res)) # 3 - cumulative_res.extend(res) - - # Functional Test: Clear the `limit` and yield the rest of the elements - it.limit(None) - it.resume() - res = [] - for x in it: - res.append(x) - self.assertEqual(n_elements - 2 * n_limit, len(res)) # 4 - - cumulative_res.extend(res) - self.assertEqual(list(range(n_elements)), sorted(cumulative_res)) - - # Functional Test: Setting `limit` to a different value during after each mini-epoch - dl2: DataLoader2 = DataLoader2(double_pause_dp, reading_service=rs) - res = [] - it2: DataLoader2Iterator = iter(dl2) - it2.limit(3) - for x in it2: - res.append(x) - - # Limit can be set before `resume` - it2.limit(4) - it2.resume() - for x in it2: - res.append(x) - self.assertEqual(7, len(res)) - - # Limit can also be set after `resume`, but before the next `for` loop - it2.resume() - it2.limit(2) - for x in it2: - res.append(x) - self.assertEqual(9, len(res)) - - def test_initial_epoch_checkpointing(self): - dp = IterableWrapper(range(20)).shuffle() - rs = InProcessReadingService(5) - - # Functional Test: Saving state before iterator is created - dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs) - dl.seed(1) - initial_state = dl.state_dict() - it1 = iter(dl) - - restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - self.assertEqual(list(it1), list(restored_dl)) - - dl.shutdown() - restored_dl.shutdown() - - # Functional Test: Saving state after iterator is created - dl = DataLoader2(datapipe=dp, reading_service=rs) - 
dl.seed(1) - it1 = iter(dl) - initial_state = dl.state_dict() - - restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - self.assertEqual(list(it1), list(restored_dl)) - - dl.shutdown() - restored_dl.shutdown() - - # Functional Test: Saving state after iterator is created and began iterating - dl = DataLoader2(datapipe=dp, reading_service=rs) - dl.seed(1) - it1 = iter(dl) - temp = next(it1) # Starts iterating - initial_state = dl.state_dict() - - restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - - self.assertEqual([temp] + list(it1), list(restored_dl)) # Note skipping over 1st element from actual result - - dl.shutdown() - restored_dl.shutdown() - - -def _non_dispatching_dp(n_elements=1000): - dp = IterableWrapper(list(range(n_elements))).shuffle() - dp = dp.sharding_filter() - dp = dp.map(_add_one).batch(8) - return dp - - -def _dispatching_dp(n_elements=1000): - dp = IterableWrapper(list(range(n_elements))).shuffle() - dp = dp.prefetch(20) - dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING) - dp = dp.map(_add_one).batch(16) - return dp - - -class NonShardableDataPipe(IterDataPipe): - def __init__(self, dp: IterDataPipe): - self.dp = dp - - def is_replicable(self): - return False - - def __iter__(self): - yield from self.dp - - -class TestMultiProcessingReadingService(TestCase): - r""" - This tests specific functionalities of MultiProcessingReadingService, notably - `pause`, `resume`, `snapshot`. - """ - - @mp_ctx_parametrize - @parametrize("dp_fn", [subtest(_non_dispatching_dp, "non_dispatch"), subtest(_dispatching_dp, "dispatch")]) - @parametrize("main_prefetch", [0, 10]) - @parametrize("worker_prefetch", [0, 10]) - def test_early_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None: - dp = dp_fn(1000) - rs = MultiProcessingReadingService( - num_workers=2, - main_prefetch_cnt=main_prefetch, - worker_prefetch_cnt=worker_prefetch, - multiprocessing_context=ctx, - ) - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) - it = iter(dl) - for _ in range(10): - _ = next(it) - dl.shutdown() - - @mp_ctx_parametrize - @parametrize("dp_fn", [subtest(_non_dispatching_dp, "non_dispatch"), subtest(_dispatching_dp, "dispatch")]) - @parametrize("main_prefetch", [0, 10]) - @parametrize("worker_prefetch", [0, 10]) - def test_exit(self, ctx, dp_fn, main_prefetch, worker_prefetch) -> None: - dp = dp_fn(1000) - rs = MultiProcessingReadingService( - num_workers=2, - main_prefetch_cnt=main_prefetch, - worker_prefetch_cnt=worker_prefetch, - multiprocessing_context=ctx, - ) - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) - _ = list(dl) - dl.shutdown() - - @mp_ctx_parametrize - @dp_parametrize - @parametrize( - "n_workers,worker_prefetch_cnt,main_prefetch_cnt", - [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 0), (2, 0, 2), (2, 2, 2)], - ) - def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None: - - # Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline - # properly pauses and resumes - rs = MultiProcessingReadingService( - num_workers=n_workers, - worker_prefetch_cnt=worker_prefetch_cnt, - main_prefetch_cnt=main_prefetch_cnt, - multiprocessing_context=ctx, - ) - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) - res = [] - for i, x in enumerate(dl): - res.append(x) - if i in {2, n_elements - 2}: - 
dl._pause() - dl._resume() - - self.assertEqual( - list(range(n_elements)), - sorted(res), - msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, " - f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, " - f"main_prefetch_cnt = {rs.main_prefetch_cnt}", - ) - dl.shutdown() - - @mp_ctx_parametrize - @dp_parametrize - @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(2, 0, 1), (2, 1, 0), (2, 0, 0)]) - def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None: - - # Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called - rs = MultiProcessingReadingService( - num_workers=n_workers, - worker_prefetch_cnt=worker_prefetch_cnt, - main_prefetch_cnt=main_prefetch_cnt, - multiprocessing_context=ctx, - ) - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) - res = [] - for i, x in enumerate(dl): - res.append(x) - if i in {2}: - dl._pause() - self.assertEqual( - 3, - len(res), - msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, " - f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", - ) - dl.shutdown() - - @dp_parametrize - @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 2)]) - def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None: - - rs = MultiProcessingReadingService( - num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt - ) - - dl: DataLoader2 = DataLoader2(dp, reading_service=rs) - res = [] - cumulative_res = [] - n_limit = 3 - - it: DataLoader2Iterator = iter(dl) - it.limit(n_limit) - for x in it: - res.append(x) - # Functional Test: Verify that the number of elements yielded equals to the specified limit - self.assertEqual( - n_limit, - len(res), # 3 - msg=f"The test is failing with default multiprocessing method, " - f"num_workers = {rs.num_workers}, " - f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", - ) - cumulative_res.extend(res) - - # Functional Test: Calling `next` after `limit` will trigger `StopIteration` - with self.assertRaises(StopIteration): - next(it) - - # Functional Test: Verify that `limit` persists without the need to set it again - it.resume() - res = [] - for x in it: - res.append(x) - self.assertEqual( - n_limit, - len(res), # 3 - msg=f"The test is failing with default multiprocessing method, " - f"num_workers = {rs.num_workers}, " - f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", - ) - cumulative_res.extend(res) - - # Functional Test: Clear the `limit` and yield the rest of the elements - it.limit(None) - it.resume() - res = [] - for x in it: - res.append(x) - self.assertEqual( - n_elements - 2 * n_limit, - len(res), # 4 - msg=f"The test is failing with default multiprocessing method, " - f"num_workers = {rs.num_workers}, " - f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", - ) - - cumulative_res.extend(res) - self.assertEqual(list(range(n_elements)), sorted(cumulative_res)) - - # Functional Test: Setting `limit` to a different value during after each mini-epoch - dl2: DataLoader2 = DataLoader2(double_pause_dp, reading_service=rs) - res = [] - it2: DataLoader2Iterator = iter(dl2) - it2.limit(3) - for x in it2: - res.append(x) - - # Limit can be set before `resume` - it2.limit(4) - it2.resume() - for 
x in it2: - res.append(x) - self.assertEqual(7, len(res)) - - # Limit can also be set after `resume`, but before the next `for` loop - it2.resume() - it2.limit(2) - for x in it2: - res.append(x) - self.assertEqual(9, len(res)) - - def test_initial_epoch_checkpointing(self): - dp = IterableWrapper(range(20)).shuffle().sharding_filter() - # Note that the second `shuffle` occurs in the main process, which uses a different RNG from - # the `shuffle` done in the worker processes - dp = NonShardableDataPipe(dp).shuffle() # type: ignore[assignment, arg-type] - rs = MultiProcessingReadingService(num_workers=2) - - # Functional Test: Saving state before iterator is created - dl: DataLoader2 = DataLoader2(datapipe=dp, reading_service=rs) - dl.seed(1) - initial_state = dl.state_dict() - it1 = iter(dl) - - restored_dl: DataLoader2 = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - self.assertEqual(list(it1), list(restored_dl)) - - dl.shutdown() - restored_dl.shutdown() - - # Functional Test: Saving state after iterator is created - dl = DataLoader2(datapipe=dp, reading_service=rs) - dl.seed(1) - it1 = iter(dl) - initial_state = dl.state_dict() - - restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - self.assertEqual(list(it1), list(restored_dl)) - - dl.shutdown() - restored_dl.shutdown() - - # Functional Test: Saving state after iterator is created and began iterating - dl = DataLoader2(datapipe=dp, reading_service=rs) - dl.seed(1) - it1 = iter(dl) - temp = next(it1) # Starts iterating - initial_state = dl.state_dict() - - restored_dl = DataLoader2.from_state(initial_state, rs) # type: ignore[arg-type] - restored_dl._restore_checkpoint_beginning_of_epoch() - - self.assertEqual([temp] + list(it1), list(restored_dl)) # Note skipping over 1st element from actual result - - dl.shutdown() - restored_dl.shutdown() - - # TODO: Test cases when there is official support of `pause` and `resume` with round-robin sharding - # Currently, using sharding_round_robin raises a warning - # def test_round_robin_dispatching_pause_limit(self): - # source_dp = IterableWrapper(range(20)) - # dp = source_dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING) - # dp = dp.map(_add_one) - - # TODO: This doesn't work with `num_workers > 1` - # TODO: Try checking if `dp_list`'s elements are _IterateQueueDP or QueueWrapper, we can safely assume - # those DPs belong to a dispatching process and only do pause if worker_id == 0 - # There might still be a race condition, need to look into the messages - - # rs1 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=0) - # rs2 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=0, main_prefetch_cnt=2) - # rs3 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=0) - # rs4 = MultiProcessingReadingService(num_workers=2, worker_prefetch_cnt=2, main_prefetch_cnt=2) - # rss = [rs1, rs2, rs3, rs4] - - # for n, rs in enumerate(rss): - # dl = DataLoader2(dp, reading_service=rs) - # res = [] - # # cumulative_res = [] - # n_limit = 3 - # - # it: DataLoader2Iterator = iter(dl) - # it.limit(n_limit) # The `pause` call here doesn't stop - # for x in it: - # res.append(x) - # - # print() - # print(res) - # - # dl.shutdown() - - # # Functional Test: Verify that the number of elements yielded equals to the specified limit - # # 
self.assertEqual( - # # n_limit, - # # len(res), # 3 - # # msg=f"The test is failing for rs{n + 1} with default multiprocessing method, " - # # f"num_workers = {rs.num_workers}, " - # # f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", - # # ) - # cumulative_res.extend(res) - # - # # Functional Test: Calling `next` after `limit` will trigger `StopIteration` - # with self.assertRaisesRegex(StopIteration, "pause"): - # next(it) - # - # # Functional Test: Verify that `limit` persists without the need to set it again - # it.resume() - # res = [] - # for x in it: - # res.append(x) - # # self.assertEqual( - # # n_limit, - # # len(res), # 3 - # # msg=f"The test is failing for rs{n + 1} with default multiprocessing method, " - # # f"num_workers = {rs.num_workers}, " - # # f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", - # # ) - # cumulative_res.extend(res) - # - # # Functional Test: Clear the `limit` and yield the rest of the elements - # it.limit(None) - # it.resume() - # res = [] - # for x in it: - # res.append(x) - # # self.assertEqual( - # # n_elements - 2 * n_limit, - # # len(res), # 4 - # # msg=f"The test is failing for rs{n + 1} with default multiprocessing method, " - # # f"num_workers = {rs.num_workers}, " - # # f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}", - # # ) - # - # cumulative_res.extend(res) - # self.assertEqual(list(range(n_elements)), sorted(cumulative_res)) - - # TODO: Implemented in an upcoming PR - # def test_reading_service_snapshot(self) -> None: - # pass - # - # def test_dataloader2_snapshot(self) -> None: - # pass - - -instantiate_parametrized_tests(TestInProcessReadingService) -instantiate_parametrized_tests(TestMultiProcessingReadingService) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/dataloader2/test_random.py b/test/dataloader2/test_random.py deleted file mode 100644 index 90f06d964..000000000 --- a/test/dataloader2/test_random.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import random -import unittest - -from unittest import TestCase - -import numpy as np - -import torch - -from torch.testing._internal.common_utils import instantiate_parametrized_tests, IS_WINDOWS, parametrize -from torchdata.dataloader2 import DataLoader2, InProcessReadingService, MultiProcessingReadingService -from torchdata.dataloader2.graph.settings import set_graph_random_seed -from torchdata.dataloader2.random import SeedGenerator -from torchdata.datapipes.iter import IterableWrapper - - -def _random_fn(data): - r""" - Used to validate the randomness of subprocess-local RNGs are set deterministically. 
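`_random_fn` samples from the three RNG sources that a reading service is expected to seed in each worker. The standalone check below illustrates the property the determinism tests assert, using only the standard library, NumPy and torch; the helper names `seed_all` and `draw` are illustrative and not part of torchdata:

```python
# Identical seeds must reproduce identical draws from all three RNG sources.
import random

import numpy as np
import torch


def seed_all(seed: int) -> None:  # illustrative helper
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def draw():
    return (
        random.randint(0, 2 ** 32),
        int(np.random.randint(0, 2 ** 32, dtype=np.uint32)),
        torch.randint(0, 2 ** 32, size=[]).item(),
    )


seed_all(123)
first = draw()
seed_all(123)
assert draw() == first
```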
- """ - py_random_num = random.randint(0, 2 ** 32) - np_random_num = np.random.randint(0, 2 ** 32, dtype=np.uint32) - torch_random_num = torch.randint(0, 2 ** 32, size=[]).item() - return (data, py_random_num, np_random_num, torch_random_num) - - -class DeterminismTest(TestCase): - @unittest.skipIf(IS_WINDOWS, "Remove when https://github.com/pytorch/data/issues/857 is fixed") - @parametrize("num_workers", [1, 8]) - def test_mprs_determinism(self, num_workers): - data_length = 64 - exp = list(range(data_length)) - - data_source = IterableWrapper(exp) - dp = data_source.shuffle().sharding_filter().map(_random_fn) - rs = MultiProcessingReadingService(num_workers=num_workers) - dl = DataLoader2(dp, reading_service=rs) - - # No seed - res = [] - for d, *_ in dl: - res.append(d) - self.assertEqual(sorted(res), exp) - - # Shuffle with seed - results = [] - for _ in range(2): - res = [] - ran_res = [] - torch.manual_seed(123) - random.seed(123) - np.random.seed(123) - for d, *ran_nums in dl: - res.append(d) - ran_res.append(ran_nums) - self.assertEqual(sorted(res), exp) - results.append((res, ran_res)) - # Same seed generate the same order of data and the same random state - self.assertEqual(results[0], results[1]) - - # Different seed - res = [] - ran_res = [] - torch.manual_seed(321) - random.seed(321) - np.random.seed(321) - for d, *ran_nums in dl: - res.append(d) - ran_res.append(ran_nums) - self.assertEqual(sorted(res), exp) - # Different shuffle order - self.assertNotEqual(results[0][0], res) - # Different subprocess-local random state - self.assertNotEqual(results[0][1], ran_res) - - def test_graph_random_settings(self): - def _get_dp_seeds_after_setting(worker_id, seed=123): - data_source = IterableWrapper(list(range(100))) - dp0 = data_source.shuffle() - dp1, dp2, dp3 = dp0.fork(3) - dp1 = dp1.sharding_filter() - dp2 = dp2.shuffle() - dp3 = dp3.shuffle() - dp3_ = dp3.sharding_filter() - dp4 = dp1.zip(dp2, dp3_).shuffle() - - sg = SeedGenerator(seed).spawn(worker_id) - set_graph_random_seed(dp4, sg) - - # same seeds, different seeds - return (dp0._seed, dp3._seed), (dp2._seed, dp4._seed) - - ss_0_123, ds_0_123 = _get_dp_seeds_after_setting(worker_id=0, seed=123) - ss_1_123, ds_1_123 = _get_dp_seeds_after_setting(worker_id=1, seed=123) - self.assertEqual(ss_0_123, ss_1_123) - self.assertNotEqual(ds_0_123, ds_1_123) - - ss_0_123_, ds_0_123_ = _get_dp_seeds_after_setting(worker_id=0, seed=123) - self.assertEqual(ss_0_123, ss_0_123_) - self.assertEqual(ds_0_123, ds_0_123_) - - ss_0_321, ds_0_321 = _get_dp_seeds_after_setting(worker_id=0, seed=321) - self.assertNotEqual(ss_0_123, ss_0_321) - self.assertNotEqual(ds_0_123, ds_0_321) - - def test_sprs_determinism(self): - data_length = 64 - exp = list(range(data_length)) - - data_source = IterableWrapper(exp) - dp = data_source.shuffle().sharding_filter().map(_random_fn) - rs = InProcessReadingService() - dl = DataLoader2(dp, reading_service=rs) - - # No seed - res = [] - for d, *_ in dl: - res.append(d) - self.assertEqual(sorted(res), exp) - - # Shuffle with seed - results = [] - for _ in range(2): - res = [] - ran_res = [] - torch.manual_seed(123) - random.seed(123) - np.random.seed(123) - for d, *ran_nums in dl: - res.append(d) - ran_res.append(ran_nums) - self.assertEqual(sorted(res), exp) - results.append((res, ran_res)) - # Same seed generate the same order of data and the same random state - self.assertEqual(results[0], results[1]) - - # Different seed - res = [] - ran_res = [] - torch.manual_seed(321) - random.seed(321) - 
np.random.seed(321) - for d, *ran_nums in dl: - res.append(d) - ran_res.append(ran_nums) - self.assertEqual(sorted(res), exp) - # Different shuffle order - self.assertNotEqual(results[0][0], res) - # Different subprocess-local random state - self.assertNotEqual(results[0][1], ran_res) - - -instantiate_parametrized_tests(DeterminismTest) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/requirements.txt b/test/requirements.txt index 6daa28c5b..169e812cb 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -1,16 +1,7 @@ pytest expecttest fsspec -s3fs -iopath == 0.1.9 numpy<2 -rarfile -portalocker >= 2.0.0 -# Protobuf 4.0 is binary incompatible with what C++ TF uses. -# See: https://github.com/tensorflow/tensorflow/blob/8dcaf6b98a6a49c85eb470140ba8506e71a3b5af/tensorflow/tools/pip_package/setup.py#L88-L94 -# Protobuf 3.20.2 is also broken on MacOS Python 3.10 -# See: https://github.com/protocolbuffers/protobuf/issues/10571 -protobuf >= 3.9.2, < 3.20 datasets @ git+https://github.com/huggingface/datasets@main graphviz adlfs diff --git a/test/requirements_aistore.txt b/test/requirements_aistore.txt deleted file mode 100644 index c491e3c12..000000000 --- a/test/requirements_aistore.txt +++ /dev/null @@ -1,2 +0,0 @@ -aistore >= 1.0.2 -pytest diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index 6b4319843..4e55db28f 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -5,17 +5,6 @@ # LICENSE file in the root directory of this source tree. -import argparse - -import torchdata -import torchdata.dataloader2 -import torchdata.datapipes - - -def s3_test(): - from torchdata._torchdata import S3Handler - - def stateful_dataloader_test(): from torchdata.stateful_dataloader import StatefulDataLoader @@ -24,11 +13,5 @@ def stateful_dataloader_test(): r""" TorchData Smoke Test """ - parser = argparse.ArgumentParser() - parser.add_argument("--no-s3", dest="s3", action="store_false") - - options = parser.parse_args() - if options.s3: - s3_test() stateful_dataloader_test() diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 7233369e1..5bd1e6161 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -918,18 +918,6 @@ def test_num_workers_mismatch(self): self.assertTrue(False, "Error should be of type AssertionError") -class TestTorchDataLazyImport_shard3(TestCase): - def test_lazy_imports(self) -> None: - import torchdata - - self.assertFalse("datapipes" in torchdata.__dict__) - - from torchdata import datapipes as dp, janitor # noqa - - self.assertTrue("datapipes" in torchdata.__dict__) - dp.iter.IterableWrapper([1, 2]) - - class TestConcurrentDataLoaders_shard3(TestCase): def test_two_dataloaders(self) -> None: dataset = DummyMapDataset(100, shuffle=False) diff --git a/test/test_adapter.py b/test/test_adapter.py deleted file mode 100644 index dfa2f1261..000000000 --- a/test/test_adapter.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
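With this cleanup the smoke test above only imports `StatefulDataLoader`, the component that survives the removal of the DataPipe/DataLoader2 stack. A slightly richer sanity check could also exercise the mid-epoch checkpoint round trip it exists for; the sketch below assumes the `StatefulDataLoader` constructor mirrors `torch.utils.data.DataLoader` and uses a plain list as a map-style dataset:

```python
# Hedged sketch of a StatefulDataLoader state_dict round trip (num_workers=0).
from torchdata.stateful_dataloader import StatefulDataLoader

dl = StatefulDataLoader(list(range(10)), batch_size=2, num_workers=0)

it = iter(dl)
_ = next(it)                    # consume one batch
state = dl.state_dict()         # snapshot progress mid-epoch

resumed = StatefulDataLoader(list(range(10)), batch_size=2, num_workers=0)
resumed.load_state_dict(state)  # restore before iterating
remaining = list(resumed)       # yields the batches not consumed above
```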
- -from unittest import TestCase - -from torchdata.dataloader2 import DataLoader2 -from torchdata.dataloader2.adapter import Shuffle -from torchdata.datapipes.iter import IterableWrapper - - -class AdapterTest(TestCase): - def test_shuffle(self) -> None: - size = 500 - dp = IterableWrapper(range(size)) - - dl = DataLoader2(datapipe=dp) - self.assertEqual(list(range(size)), list(dl)) - - with self.assertWarns(Warning, msg="`shuffle=True` was set, but the datapipe does not contain a `Shuffler`."): - dl = DataLoader2(datapipe=dp, datapipe_adapter_fn=Shuffle(True)) - self.assertNotEqual(list(range(size)), list(dl)) - - dp = IterableWrapper(range(size)).shuffle() - - dl = DataLoader2(datapipe=dp) - self.assertNotEqual(list(range(size)), list(dl)) - - dl = DataLoader2(dp, Shuffle(True)) - self.assertNotEqual(list(range(size)), list(dl)) - - dl = DataLoader2(dp, [Shuffle(None)]) - self.assertNotEqual(list(range(size)), list(dl)) - - dl = DataLoader2(dp, [Shuffle(False)]) - self.assertEqual(list(range(size)), list(dl)) diff --git a/test/test_aistore.py b/test/test_aistore.py deleted file mode 100644 index 49b37a8a5..000000000 --- a/test/test_aistore.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import random -import string -import tempfile -import unittest - -from torchdata.datapipes.iter import AISFileLister, AISFileLoader - -try: - from aistore.client.api import Client - from aistore.client.errors import AISError, ErrBckNotFound - - AIS_CLUSTER_ENDPT = "http://localhost:8080" - - HAS_AIS = Client(AIS_CLUSTER_ENDPT).cluster().is_aistore_running() -except (ImportError, ConnectionError): - HAS_AIS = False -skipIfNoAIS = unittest.skipIf(not HAS_AIS, "AIS not running or library not installed") - - -@skipIfNoAIS -class TestAIStoreIODataPipe(unittest.TestCase): - def setUp(self): - # initialize client and create new bucket - self.client = Client(AIS_CLUSTER_ENDPT) - letters = string.ascii_lowercase - self.bck_name = "".join(random.choice(letters) for _ in range(10)) - self.client.bucket(self.bck_name).create() - # create temp files - num_objs = 10 - - # create 10 objects in the `/temp` dir - for i in range(num_objs): - object_body = "test string" * random.randrange(1, 10) - content = object_body.encode("utf-8") - obj_name = f"temp/obj{ i }" - with tempfile.NamedTemporaryFile() as file: - file.write(content) - file.flush() - self.client.bucket(self.bck_name).object(obj_name).put(file.name) - - # create 10 objects in the `/`dir - for i in range(num_objs): - object_body = "test string" * random.randrange(1, 10) - content = object_body.encode("utf-8") - obj_name = f"obj{ i }" - with tempfile.NamedTemporaryFile() as file: - file.write(content) - file.flush() - self.client.bucket(self.bck_name).object(obj_name).put(file.name) - - def tearDown(self): - # Try to destroy bucket and its items - try: - self.client.bucket(self.bck_name).delete() - except ErrBckNotFound: - pass - - def test_ais_io_iterdatapipe(self): - - prefixes = [ - ["ais://" + self.bck_name], - ["ais://" + self.bck_name + "/"], - ["ais://" + self.bck_name + "/temp/", "ais://" + self.bck_name + "/obj"], - ] - - # check if the created files exist - for prefix in prefixes: - urls = AISFileLister(url=AIS_CLUSTER_ENDPT, source_datapipe=prefix) - ais_loader = AISFileLoader(url=AIS_CLUSTER_ENDPT, source_datapipe=urls) - with 
self.assertRaises(TypeError): - len(urls) - self.assertEqual(len(list(urls)), 20) - self.assertEqual(sum(1 for _ in ais_loader), 20) - - # check for incorrect prefixes - prefixes = ["ais://asdasd"] - - # AISFileLister: Bucket not found - try: - list(AISFileLister(url=AIS_CLUSTER_ENDPT, source_datapipe=prefixes)) - except ErrBckNotFound as err: - self.assertEqual(err.status_code, 404) - - # AISFileLoader: incorrect inputs - url_list = [[""], ["ais:"], ["ais://"], ["s3:///unkown-bucket"]] - - for url in url_list: - with self.assertRaises(AISError): - file_loader = AISFileLoader(url=AIS_CLUSTER_ENDPT, source_datapipe=url) - for _ in file_loader: - pass - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_audio_examples.py b/test/test_audio_examples.py deleted file mode 100644 index 945258b8e..000000000 --- a/test/test_audio_examples.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import sys -import tempfile -import unittest - -import torch.multiprocessing as mp - -from torch.testing._internal.common_utils import slowTest -from torch.utils.data import DataLoader - -current = os.path.dirname(os.path.abspath(__file__)) -ROOT = os.path.dirname(current) -sys.path.insert(0, ROOT) - -from examples.audio.librispeech import LibriSpeech - - -class TestAudioExamples(unittest.TestCase): - def setUp(self): - self.temp_dir = tempfile.TemporaryDirectory() - - def tearDown(self): - self.temp_dir.cleanup() - - def _test_helper(self, fn, *args, **kwargs): - dp = fn(*args, **kwargs) - _ = list(dp) - - @staticmethod - def _collate_fn(batch): - return batch - - def _test_DL_helper(self, fn, *args, **kwargs): - dp = fn(*args, **kwargs) - mp.set_sharing_strategy("file_system") - dl = DataLoader( - dp, - batch_size=8, - num_workers=4, - collate_fn=TestAudioExamples._collate_fn, - multiprocessing_context="fork", # Using Fork her because `torchaudio.load` doesn't work well with spawn - ) - for _ in dl: - pass - - @slowTest - def test_LibriSpeech_dev(self) -> None: - root = self.temp_dir.name - self._test_helper(LibriSpeech, root, "dev-other") - # With cache and DataLoader - self._test_DL_helper(LibriSpeech, root, "dev-other") - - @unittest.skipIf(True, "Dataset is too large to run on CI") - def test_LibriSpeech_train(self) -> None: - root = self.temp_dir.name - self._test_helper(LibriSpeech, root, "train-clean-100") - # With cache and DataLoader - self._test_DL_helper(LibriSpeech, root, "train-clean-100") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_dataframe.py b/test/test_dataframe.py deleted file mode 100644 index 0104076bb..000000000 --- a/test/test_dataframe.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import os -import unittest -import warnings -from itertools import chain - -import expecttest -from _utils._common_utils_for_test import create_temp_dir, reset_after_n_next_calls -from torchdata.datapipes.iter import DataFrameMaker, FileLister, FileOpener, IterableWrapper, ParquetDataFrameLoader - -try: - import torcharrow - import torcharrow.dtypes as dt - - HAS_TORCHARROW = True -except ImportError: - HAS_TORCHARROW = False - -try: - import pyarrow - import pyarrow.parquet as parquet - - HAS_PYARROW = True -except ImportError: - HAS_PYARROW = False -skipIfNoPyArrow = unittest.skipIf(not HAS_PYARROW, "no PyArrow.") -skipIfNoTorchArrow = unittest.skipIf(not HAS_TORCHARROW, "no TorchArrow.") - - -@skipIfNoTorchArrow -class TestDataFrame(expecttest.TestCase): - def setUp(self) -> None: - self.temp_dir = create_temp_dir() - if HAS_PYARROW: - self._write_parquet_files() - - def tearDown(self) -> None: - try: - self.temp_dir.cleanup() - except Exception as e: - warnings.warn(f"TestDataFrame was not able to cleanup temp dir due to {e}") - - def _write_parquet_files(self): - # Create TorchArrow DataFrames - DTYPE = dt.Struct([dt.Field("Values", dt.int32)]) - df1 = torcharrow.dataframe([(i,) for i in range(10)], dtype=DTYPE) - df2 = torcharrow.dataframe([(i,) for i in range(100)], dtype=DTYPE) - # Write them as parquet files - for i, df in enumerate([df1, df2]): - fname = f"df{i}.parquet" - self._write_df_as_parquet(df, fname) - self._write_multiple_dfs_as_parquest([df1, df2], fname="merged.parquet") - - def _custom_files_set_up(self, files): - for fname, content in files.items(): - temp_file_path = os.path.join(self.temp_dir.name, fname) - with open(temp_file_path, "w") as f: - f.write(content) - - def _compare_dataframes(self, expected_df, actual_df): - self.assertEqual(len(expected_df), len(actual_df)) - for exp, act in zip(expected_df, actual_df): - self.assertEqual(exp, act) - - def _write_df_as_parquet(self, df, fname: str) -> None: - table = df.to_arrow() - parquet.write_table(table, os.path.join(self.temp_dir.name, fname)) - - def _write_multiple_dfs_as_parquest(self, dfs, fname: str) -> None: - tables = [df.to_arrow() for df in dfs] - merged_table = pyarrow.concat_tables(tables) - parquet.write_table(merged_table, os.path.join(self.temp_dir.name, fname)) - - def test_dataframe_maker_iterdatapipe(self): - source_data = [(i,) for i in range(10)] - source_dp = IterableWrapper(source_data) - DTYPE = dt.Struct([dt.Field("Values", dt.int32)]) - - # Functional Test: DataPipe correctly converts into a single TorchArrow DataFrame - df_dp = source_dp.dataframe(dtype=DTYPE) - df = list(df_dp)[0] - expected_df = torcharrow.dataframe([(i,) for i in range(10)], dtype=DTYPE) - self._compare_dataframes(expected_df, df) - - # Functional Test: DataPipe correctly converts into multiple TorchArrow DataFrames, based on size argument - df_dp = DataFrameMaker(source_dp, dataframe_size=5, dtype=DTYPE) - dfs = list(df_dp) - expected_dfs = [ - torcharrow.dataframe([(i,) for i in range(5)], dtype=DTYPE), - torcharrow.dataframe([(i,) for i in range(5, 10)], dtype=DTYPE), - ] - for exp_df, act_df in zip(expected_dfs, dfs): - self._compare_dataframes(exp_df, act_df) - - # __len__ Test: - df_dp = source_dp.dataframe(dtype=DTYPE) - self.assertEqual(1, len(df_dp)) - self.assertEqual(10, len(list(df_dp)[0])) - df_dp = source_dp.dataframe(dataframe_size=5, dtype=DTYPE) - self.assertEqual(2, len(df_dp)) - self.assertEqual(5, len(list(df_dp)[0])) - - # Reset Test: - n_elements_before_reset = 1 - res_before_reset, 
res_after_reset = reset_after_n_next_calls(df_dp, n_elements_before_reset) - for exp_df, act_df in zip(expected_dfs[:1], res_before_reset): - self._compare_dataframes(exp_df, act_df) - for exp_df, act_df in zip(expected_dfs, res_after_reset): - self._compare_dataframes(exp_df, act_df) - - def test_dataframe_maker_with_csv(self): - def get_name(path_and_stream): - return os.path.basename(path_and_stream[0]), path_and_stream[1] - - csv_files = {"1.csv": "key,item\na,1\nb,2"} - self._custom_files_set_up(csv_files) - datapipe1 = FileLister(self.temp_dir.name, "*.csv") - datapipe2 = FileOpener(datapipe1, mode="b") - datapipe3 = datapipe2.map(get_name) - csv_dict_parser_dp = datapipe3.parse_csv_as_dict() - - # Functional Test: Correctly generate TorchArrow DataFrame from CSV - DTYPE = dt.Struct([dt.Field("key", dt.string), dt.Field("item", dt.string)]) - df_dp = csv_dict_parser_dp.dataframe(dtype=DTYPE, columns=["key", "item"]) - expected_dfs = [torcharrow.dataframe([{"key": "a", "item": "1"}, {"key": "b", "item": "2"}], dtype=DTYPE)] - for exp_df, act_df in zip(expected_dfs, list(df_dp)): - self._compare_dataframes(exp_df, act_df) - - # Functional: making sure DataPipe works even without `columns` input - df_dp = csv_dict_parser_dp.dataframe(dtype=DTYPE) - for exp_df, act_df in zip(expected_dfs, list(df_dp)): - self._compare_dataframes(exp_df, act_df) - - @skipIfNoPyArrow - def test_parquet_dataframe_reader_iterdatapipe(self): - DTYPE = dt.Struct([dt.Field("Values", dt.int32)]) - - # Functional Test: read from Parquet files and output TorchArrow DataFrames - source_dp = FileLister(self.temp_dir.name, masks="df*.parquet") - parquet_df_dp = ParquetDataFrameLoader(source_dp, dtype=DTYPE) - expected_dfs = [ - torcharrow.dataframe([(i,) for i in range(10)], dtype=DTYPE), - torcharrow.dataframe([(i,) for i in range(100)], dtype=DTYPE), - ] - for exp_df, act_df in zip(expected_dfs, list(parquet_df_dp)): - self._compare_dataframes(exp_df, act_df) - - # Functional Test: correctly read from a Parquet file that was a merged DataFrame - merged_source_dp = FileLister(self.temp_dir.name, masks="merged.parquet") - merged_parquet_df_dp = ParquetDataFrameLoader(merged_source_dp, dtype=DTYPE) - expected_merged_dfs = [torcharrow.dataframe([(i,) for i in chain(range(10), range(100))], dtype=DTYPE)] - for exp_df, act_df in zip(expected_merged_dfs, list(merged_parquet_df_dp)): - self._compare_dataframes(exp_df, act_df) - - # __len__ Test: no valid length because we do not know the number of row groups in advance - with self.assertRaisesRegex(TypeError, "has no len"): - len(parquet_df_dp) - - # Reset Test: - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(parquet_df_dp, n_elements_before_reset) - for exp_df, act_df in zip(expected_dfs[:1], res_before_reset): - self._compare_dataframes(exp_df, act_df) - for exp_df, act_df in zip(expected_dfs, res_after_reset): - self._compare_dataframes(exp_df, act_df) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_distributed.py b/test/test_distributed.py deleted file mode 100644 index 84cfff046..000000000 --- a/test/test_distributed.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
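The DataFrame tests above build their fixtures by round-tripping TorchArrow frames through `pyarrow.parquet`. For reference, equivalent fixture files can be produced with plain pyarrow; the sketch below mirrors the column name and row counts from the deleted tests, and the file names are illustrative:

```python
# Plain-pyarrow version of the parquet fixtures used by the deleted TestDataFrame.
import os
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

with tempfile.TemporaryDirectory() as tmp:
    df1 = pa.table({"Values": list(range(10))})
    df2 = pa.table({"Values": list(range(100))})

    pq.write_table(df1, os.path.join(tmp, "df0.parquet"))
    pq.write_table(df2, os.path.join(tmp, "df1.parquet"))

    # The "merged" fixture concatenates both tables into a single file.
    pq.write_table(pa.concat_tables([df1, df2]), os.path.join(tmp, "merged.parquet"))

    # Reading the merged file back yields a single 110-row table.
    assert pq.read_table(os.path.join(tmp, "merged.parquet")).num_rows == 110
```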
- - -import os -import queue -import random -import socket -import sys -import unittest - -from functools import partial -from unittest import TestCase - -import numpy as np - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize -from torch.utils.data import DataLoader - -from torchdata.dataloader2 import DataLoader2, DistributedReadingService -from torchdata.datapipes.iter import IterableWrapper -from torchdata.datapipes.iter.util.distributed import PrefetchTimeoutError - -TEST_MASTER_ADDR = "127.0.0.1" -DEFAULT_WORLD_SIZE = 2 - - -if not dist.is_available(): - print("Distributed not available, skipping tests", file=sys.stderr) - sys.exit(0) - - -_backends = ["gloo"] -if dist.is_mpi_available(): - _backends.append("mpi") -if dist.is_nccl_available() and torch.cuda.device_count() > 0: - _backends.append("nccl") - - -world_size_parametrize = parametrize("world_size", [1, DEFAULT_WORLD_SIZE]) -backend_parametrize = parametrize("backend", _backends) - - -def abs_path(path): - return os.path.join(os.path.dirname(__file__), os.path.normpath(path)) - - -def _get_open_port(): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(("", 0)) - port = s.getsockname()[1] - s.close() - return str(port) - - -class TerminateSignal: - pass - - -# TODO(ejguan): Use queue for all distributed tests -def launch_distributed_training(backend, world_size, *args, fn): - os.environ["MASTER_ADDR"] = TEST_MASTER_ADDR - os.environ["MASTER_PORT"] = _get_open_port() - ctx = mp.get_context("spawn") - q = ctx.Queue() - ps = [] - for rank in range(world_size): - p = ctx.Process( - target=fn, - args=( - rank, - world_size, - backend, - q, - *args, - ), - ) - p.start() - ps.append(p) - res = [] - while True: - try: - d = q.get() - if isinstance(d, TerminateSignal): - break - res.append(d) - except queue.Empty: - continue - for p in ps: - p.join() - return res - - -def _dist_iterate_one_epoch(dl, seed=None): - r""" - Iterate a full epoch of DataLoader and set seeds for global RNGs if provided. - """ - if seed is not None: - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - res = [] - for d in dl: - res.append(d) - # Simulate training synchronization - dist.barrier() - return res - - -def _finalize_distributed_queue(rank, q): - r""" - Synchronize all distributed processes to guarantee all data have been put into - the Multiprocessing Queue. 
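The launcher above follows a generic spawn-and-collect pattern: start one process per rank, funnel results through a multiprocessing queue, and stop collecting on a sentinel. A stripped-down, torch-free sketch of the same flow; the worker payload is a placeholder, and a per-worker sentinel stands in for the `all_reduce` barrier plus the single `TerminateSignal` used above:

```python
# Minimal spawn-and-collect harness in the spirit of launch_distributed_training.
import multiprocessing as mp


class _Done:
    """Per-worker sentinel; the real harness barriers, then rank 0 sends TerminateSignal."""


def _worker(rank, q):
    q.put((rank, [rank * 10 + i for i in range(3)]))  # placeholder payload
    q.put(_Done())                                    # signal this worker is finished


def launch(world_size=2):
    ctx = mp.get_context("spawn")
    q = ctx.Queue()
    procs = [ctx.Process(target=_worker, args=(rank, q)) for rank in range(world_size)]
    for p in procs:
        p.start()

    results, done = [], 0
    while done < world_size:  # drain the queue until every worker has signalled
        item = q.get()
        if isinstance(item, _Done):
            done += 1
        else:
            results.append(item)

    for p in procs:
        p.join()
    return results


if __name__ == "__main__":
    print(sorted(launch()))
```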
- """ - pg = dist.new_group(backend="gloo") - end_tensor = torch.tensor([rank], dtype=torch.int64) - dist.all_reduce(end_tensor, group=pg) - if rank == 0: - q.put(TerminateSignal()) - - dist.destroy_process_group(pg) - - -class DistributedTest(TestCase): - @staticmethod - def _test_fullsync(rank, world_size, backend, q): - dist.init_process_group(backend, rank=rank, world_size=world_size) - # Use a prime number to make sure uneven data sharding - data_length = 23 - dp = IterableWrapper(list(range(data_length))).sharding_filter() - torch.utils.data.graph_settings.apply_sharding(dp, world_size, rank) - - dp1 = dp.fullsync() - for _ in range(2): - res = _dist_iterate_one_epoch(dp1) - assert res == list(range(rank, data_length // world_size * world_size, world_size)) - - # Timeout Test - dp2 = dp.fullsync(timeout=0.01) - try: - for _ in range(2): - _ = list(dp2) - except Exception as e: - assert isinstance(e, PrefetchTimeoutError) - - # Test that reset/shutdown does not hang while paused - dp3 = dp.fullsync() - it = iter(dp3) - next(it) - dp3.pause() - it2 = iter(dp3) # Reset - next(it2) - - dp4 = dp.prefetch(2) - it = iter(dp4) - next(it) - dp4.pause() - it2 = iter(dp4) # Reset - next(it2) - - _finalize_distributed_queue(rank, q) - - @world_size_parametrize - @backend_parametrize - def test_fullsync(self, world_size, backend) -> None: - world_size = world_size if backend != "nccl" else torch.cuda.device_count() - launch_distributed_training(backend, world_size, fn=DistributedTest._test_fullsync) - - @staticmethod - def _get_dataloader(data_length: int, dl2: bool, shuffle: bool, rs=None): - data_source = IterableWrapper(list(range(data_length))) - - dp = data_source.sharding_filter() - if shuffle: - dp = dp.shuffle() - - if dl2: - if rs is None: - rs = DistributedReadingService() - dl = DataLoader2(dp, reading_service=rs) - else: - dp = dp.fullsync() - dl = DataLoader(dp) - - return dl - - @staticmethod - def _test_distributed_training(dl2, rank, world_size, backend, q): - dist.init_process_group(backend, rank=rank, world_size=world_size) - # Use a prime number to make sure uneven data sharding - data_length = 23 - - # No shuffle - dl = DistributedTest._get_dataloader(data_length, dl2=dl2, shuffle=False) - res = _dist_iterate_one_epoch(dl) - assert sorted(res) == list(range(rank, data_length // world_size * world_size, world_size)) - - # Shuffle - dl = DistributedTest._get_dataloader(data_length, dl2=dl2, shuffle=True) - results = [] - for _ in range(2): - res = _dist_iterate_one_epoch(dl, seed=123) - results.append(res) - assert results[0] == results[1] - - # Different seed - res = _dist_iterate_one_epoch(dl, seed=321) - results.append(res) - assert len(results[0]) == len(results[2]) - assert results[0] != results[2] - - _finalize_distributed_queue(rank, q) - if dl2: - dl.shutdown() - - @backend_parametrize - def test_distributed_dl2(self, backend) -> None: - world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count() - launch_distributed_training(backend, world_size, fn=partial(DistributedTest._test_distributed_training, True)) - - @backend_parametrize - def test_elastic_training_dl2(self, backend) -> None: - world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count() - nnodes = 1 - from torch.distributed import run - - run.main( - [ - "--run_path", - f"--nnodes={nnodes}", - f"--nproc_per_node={world_size}", - abs_path("elastic_training.py"), - "--" + backend, - "--dl2", - ], - ) - - @backend_parametrize - def test_distributed_dl1(self, 
backend) -> None: - world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count() - launch_distributed_training(backend, world_size, fn=partial(DistributedTest._test_distributed_training, False)) - - @unittest.skipIf(sys.version_info < (3, 8), "Torch Elastic requires Python >= 3.8") - @backend_parametrize - def test_elastic_training_dl1(self, backend) -> None: - world_size = DEFAULT_WORLD_SIZE if backend != "nccl" else torch.cuda.device_count() - nnodes = 1 - from torch.distributed import run - - run.main( - [ - "--run_path", - f"--nnodes={nnodes}", - f"--nproc_per_node={world_size}", - abs_path("elastic_training.py"), - "--" + backend, - "--dl1", - ], - ) - - -instantiate_parametrized_tests(DistributedTest) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_fsspec.py b/test/test_fsspec.py deleted file mode 100644 index a752b7714..000000000 --- a/test/test_fsspec.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import unittest -import warnings - -import expecttest - -from _utils._common_utils_for_test import create_temp_dir, create_temp_files, reset_after_n_next_calls - -from torchdata.datapipes.iter import ( - FileLister, - FSSpecFileLister, - FSSpecFileOpener, - FSSpecSaver, - IterableWrapper, - IterDataPipe, -) - -try: - import fsspec - - HAS_FSSPEC = True -except ImportError: - HAS_FSSPEC = False -skipIfNoFSSpec = unittest.skipIf(not HAS_FSSPEC, "no fsspec") - - -class TestDataPipeFSSpec(expecttest.TestCase): - def setUp(self): - self.temp_dir = create_temp_dir() - self.temp_files = create_temp_files(self.temp_dir) - self.temp_sub_dir = create_temp_dir(self.temp_dir.name) - self.temp_sub_files = create_temp_files(self.temp_sub_dir, 4, False) - - self.temp_dir_2 = create_temp_dir() - self.temp_files_2 = create_temp_files(self.temp_dir_2) - self.temp_sub_dir_2 = create_temp_dir(self.temp_dir_2.name) - self.temp_sub_files_2 = create_temp_files(self.temp_sub_dir_2, 4, False) - - def tearDown(self): - try: - self.temp_sub_dir.cleanup() - self.temp_dir.cleanup() - self.temp_sub_dir_2.cleanup() - self.temp_dir_2.cleanup() - except Exception as e: - warnings.warn(f"TestDataPipeFSSpec was not able to cleanup temp dir due to {e}") - - def _write_text_files(self): - def filepath_fn(name: str) -> str: - return os.path.join(self.temp_dir.name, os.path.basename(name)) - - name_to_data = {"1.text": b"DATA", "2.text": b"DATA", "3.text": b"DATA"} - source_dp = IterableWrapper(sorted(name_to_data.items())) - saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb") - list(saver_dp) - - @skipIfNoFSSpec - def test_fsspec_file_lister_iterdatapipe(self): - datapipe: IterDataPipe = FSSpecFileLister(root="file://" + self.temp_sub_dir.name) - - # check all file paths within sub_folder are listed - for path in datapipe: - self.assertIn( - path.split("://")[1], - {fsspec.implementations.local.make_path_posix(file) for file in self.temp_sub_files}, - ) - - # checks for functional API - datapipe = IterableWrapper(["file://" + self.temp_sub_dir.name]) - datapipe = datapipe.list_files_by_fsspec() - for path in datapipe: - self.assertIn( - path.split("://")[1], - {fsspec.implementations.local.make_path_posix(file) for file in self.temp_sub_files}, - ) - - @skipIfNoFSSpec - def test_fsspec_file_lister_iterdatapipe_with_list(self): - datapipe: 
IterDataPipe = FSSpecFileLister( - root=["file://" + self.temp_sub_dir.name, "file://" + self.temp_sub_dir_2.name] - ) - - # check all file paths within sub_folder are listed - file_lister = list(map(lambda path: path.split("://")[1], datapipe)) - file_lister.sort() - temp_files = list( - map( - lambda file: fsspec.implementations.local.make_path_posix(file), - self.temp_sub_files + self.temp_sub_files_2, - ) - ) - temp_files.sort() - - # check all file paths within sub_folder are listed - self.assertEqual(file_lister, temp_files) - - # checks for functional API - datapipe = IterableWrapper(["file://" + self.temp_sub_dir.name, "file://" + self.temp_sub_dir_2.name]) - datapipe = datapipe.list_files_by_fsspec() - res = list(map(lambda path: path.split("://")[1], datapipe)) - res.sort() - temp_files = list( - map( - lambda file: fsspec.implementations.local.make_path_posix(file), - self.temp_sub_files + self.temp_sub_files_2, - ) - ) - temp_files.sort() - self.assertEqual(res, temp_files) - - @skipIfNoFSSpec - def test_fsspec_file_loader_iterdatapipe(self): - datapipe1 = FSSpecFileLister(root="file://" + self.temp_sub_dir.name) - datapipe2 = FSSpecFileOpener(datapipe1) - datapipe3 = FSSpecFileOpener(datapipe1, kwargs_for_open={"encoding": "cp037"}) - - # check contents of file match - for _, f in datapipe2: - self.assertEqual(f.read(), "0123456789abcdef") - - # Opened with a different encoding, hence NotEqual - for _, f in datapipe3: - self.assertNotEqual(f.read(), "0123456789abcdef") - - # Reset Test: Ensure the resulting streams are still readable after the DataPipe is reset/exhausted - self._write_text_files() - lister_dp = FileLister(self.temp_dir.name, "*.text") - fsspec_file_opener_dp = lister_dp.open_files_by_fsspec(mode="rb") - - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(fsspec_file_opener_dp, n_elements_before_reset) - self.assertEqual(2, len(res_before_reset)) - self.assertEqual(3, len(res_after_reset)) - for _name, stream in res_before_reset: - self.assertEqual(b"DATA", stream.read()) - for _name, stream in res_after_reset: - self.assertEqual(b"DATA", stream.read()) - - @skipIfNoFSSpec - def test_fsspec_saver_iterdatapipe(self): - def filepath_fn(name: str) -> str: - return "file://" + os.path.join(self.temp_dir.name, os.path.basename(name)) - - # Functional Test: Saving some data - name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"} - source_dp = IterableWrapper(sorted(name_to_data.items())) - saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb") - res_file_paths = list(saver_dp) - expected_paths = [filepath_fn(name) for name in name_to_data.keys()] - self.assertEqual(expected_paths, res_file_paths) - for name in name_to_data.keys(): - p = filepath_fn(name).split("://")[1] - with open(p) as f: - self.assertEqual(name_to_data[name], f.read().encode()) - - # Reset Test: - saver_dp = FSSpecSaver(source_dp, filepath_fn=filepath_fn, mode="wb") - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(saver_dp, n_elements_before_reset) - self.assertEqual([filepath_fn("1.txt"), filepath_fn("2.txt")], res_before_reset) - self.assertEqual(expected_paths, res_after_reset) - for name in name_to_data.keys(): - p = filepath_fn(name).split("://")[1] - with open(p) as f: - self.assertEqual(name_to_data[name], f.read().encode()) - - # __len__ Test: returns the length of source DataPipe - self.assertEqual(3, len(saver_dp)) - - @skipIfNoFSSpec - def 
test_fsspec_memory_list(self): - fs = fsspec.filesystem("memory") - fs.mkdir("foo") - fs.touch("foo/bar1") - fs.touch("foo/bar2") - - datapipe = FSSpecFileLister(root="memory://foo") - self.assertEqual(set(datapipe), {"memory:///foo/bar1", "memory:///foo/bar2"}) - - datapipe = FSSpecFileLister(root="memory://foo/bar1") - self.assertEqual(set(datapipe), {"memory://foo/bar1"}) - - @skipIfNoFSSpec - def test_fsspec_memory_load(self): - fs = fsspec.filesystem("memory") - with fs.open("file", "w") as f: - f.write("hello") - with fs.open("file2", "w") as f: - f.write("hello2") - - files = ["memory://file", "memory://file2"] - datapipe = FSSpecFileOpener(files) - self.assertEqual([f.read() for _, f in datapipe], ["hello", "hello2"]) - - @skipIfNoFSSpec - def test_fsspec_memory_save(self): - def filepath_fn(name: str) -> str: - return "memory://" + name - - name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2"} - source_dp = IterableWrapper(sorted(name_to_data.items())) - saver_dp = FSSpecSaver(source_dp, filepath_fn=filepath_fn, mode="wb") - - self.assertEqual(set(saver_dp), {"memory://1.txt", "memory://2.txt"}) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_graph.py b/test/test_graph.py deleted file mode 100644 index c0c346554..000000000 --- a/test/test_graph.py +++ /dev/null @@ -1,518 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import types -import unittest - -from typing import Dict, Iterator, List, Tuple, TypeVar - -import expecttest - -from _utils._common_utils_for_test import IS_WINDOWS - -from torch.utils.data import IterDataPipe -from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES - -from torchdata.dataloader2 import DataLoader2, ReadingServiceInterface -from torchdata.dataloader2.graph import find_dps, list_dps, remove_dp, replace_dp, traverse_dps -from torchdata.dataloader2.graph.utils import _find_replicable_branches -from torchdata.dataloader2.random import SeedGenerator -from torchdata.dataloader2.utils.dispatch import ( - _DummyIterDataPipe, - find_lca_round_robin_sharding_dp, - find_non_dispatching_branches, -) -from torchdata.datapipes.iter import IterableWrapper, Mapper, ShardingRoundRobinDispatcher -from torchdata.datapipes.utils import to_graph - -T_co = TypeVar("T_co", covariant=True) - -try: - import graphviz - - HAS_GRAPHVIZ = True -except ImportError: - HAS_GRAPHVIZ = False - - -class Adaptor(IterDataPipe[T_co]): - def __init__(self, datapipe: IterDataPipe) -> None: - self.datapipe = datapipe - self.started = False - - def __iter__(self) -> Iterator[T_co]: - yield from self.datapipe - - -class DummyIterDataPipe(IterDataPipe[T_co]): - def __iter__(self) -> Iterator[T_co]: - yield from range(10) - - -class TempReadingService(ReadingServiceInterface): - adaptors: List[IterDataPipe] = [] - - def initialize(self, datapipe: IterDataPipe) -> IterDataPipe: - graph = traverse_dps(datapipe) - dps = find_dps(graph, Mapper) - - for dp in reversed(dps): - new_dp = Adaptor(dp) - self.adaptors.append(new_dp) - graph = replace_dp(graph, dp, new_dp) - - return list(graph.values())[0][0] - - def initialize_iteration(self, seed_generator: SeedGenerator) -> None: - seed_generator.seed(123) - for dp in self.adaptors: - dp.started = True - - def finalize_iteration(self) -> None: - for dp in self.adaptors: - dp.started = False - - -def _x_and_x_plus_5(x): - return [x, x + 5] 
- - -def _x_mod_2(x): - return x % 2 - - -def _x_mult_2(x): - return x * 2 - - -class TestGraph(expecttest.TestCase): - def _get_datapipes(self) -> Tuple[IterDataPipe, IterDataPipe, IterDataPipe]: - src_dp = IterableWrapper(range(20)) - m1 = src_dp.map(_x_and_x_plus_5) - ub = m1.unbatch() - c1, c2 = ub.demux(2, _x_mod_2) - dm = c1.main_datapipe - m2 = c1.map(_x_mult_2) - dp = m2.zip(c2) - - return traverse_dps(dp), (src_dp, m1, ub, dm, c1, c2, m2, dp) - - def test_find_dps(self) -> None: - graph, (_, m1, *_, m2, _) = self._get_datapipes() # pyre-ignore - - dps = find_dps(graph, Mapper) - - expected_dps = {m1, m2} - for dp in dps: - self.assertTrue(dp in expected_dps) - - def test_list_dps(self) -> None: - def _validate_fn(dps, exp_dps): - self.assertEqual(len(dps), len(exp_dps)) - # Validate BFS Order - for dp, exp_dp in zip(dps, exp_dps): - self.assertEqual(dp, exp_dp) - - graph, ( - src_dp, - m1, - ub, - dm, - c1, - c2, - m2, - dp, - ) = self._get_datapipes() - exp_all_dps = [dp, m2, c2, c1, dm, ub, m1, src_dp] - - # List all DataPipes - dps = list_dps(graph) - _validate_fn(dps, exp_all_dps) - - # List all DataPipes excluding a single DataPipe - dps = list_dps(graph, exclude_dps=m1) - *exp_dps, _, _ = exp_all_dps - _validate_fn(dps, exp_dps) - - # Exclude a DataPipe on one branch - dps = list_dps(graph, exclude_dps=m2) - exp_dps = [dp, c2] - _validate_fn(dps, exp_dps) - - # List all DataPipes excluding multiple DataPipes - dps = list_dps(graph, exclude_dps=[m1, m2]) - exp_dps = [dp, c2] - _validate_fn(dps, exp_dps) - - def _validate_graph(self, graph, nested_dp): - self.assertEqual(len(graph), len(nested_dp)) - for dp_id, sub_nested_dp in zip(graph, nested_dp): - self.assertEqual(graph[dp_id][0], sub_nested_dp[0]) - if len(graph[dp_id][1]) > 0: - self._validate_graph(graph[dp_id][1], sub_nested_dp[1]) - - def test_replace_dps(self) -> None: - # pyre-fixme[23]: Unable to unpack 3 values, 2 were expected. - graph, ( - src_dp, - m1, - ub, - dm, - c1, - c2, - m2, - dp, - ) = self._get_datapipes() - - new_dp1 = Adaptor(m1) - new_dp2 = Adaptor(m2) - new_dp3 = DummyIterDataPipe() - - graph = replace_dp(graph, m1, new_dp1) - exp_g1 = [ - [ - dp, - [ - [m2, [[c1, [[dm, [[ub, [[new_dp1, [[m1, [[src_dp, []]]]]]]]]]]]]], - [c2, [[dm, [[ub, [[new_dp1, [[m1, [[src_dp, []]]]]]]]]]]], - ], - ] - ] - self._validate_graph(traverse_dps(dp), exp_g1) - - graph = replace_dp(graph, m2, new_dp2) - exp_g2 = [ - [ - dp, - [ - [new_dp2, [[m2, [[c1, [[dm, [[ub, [[new_dp1, [[m1, [[src_dp, []]]]]]]]]]]]]]]], - [c2, [[dm, [[ub, [[new_dp1, [[m1, [[src_dp, []]]]]]]]]]]], - ], - ] - ] - self._validate_graph(traverse_dps(dp), exp_g2) - - graph = replace_dp(graph, m1, new_dp3) - exp_g3 = [ - [ - dp, - [ - [new_dp2, [[m2, [[c1, [[dm, [[ub, [[new_dp1, [[new_dp3, []]]]]]]]]]]]]], - [c2, [[dm, [[ub, [[new_dp1, [[new_dp3, []]]]]]]]]], - ], - ] - ] - self._validate_graph(traverse_dps(dp), exp_g3) - - def test_remove_dps(self) -> None: - # pyre-fixme[23]: Unable to unpack 3 values, 2 were expected. 
- graph, ( - src_dp, - m1, - ub, - dm, - c1, - c2, - m2, - dp, - ) = self._get_datapipes() - - graph = remove_dp(graph, m1) - exp_g1 = [[dp, [[m2, [[c1, [[dm, [[ub, [[src_dp, []]]]]]]]]], [c2, [[dm, [[ub, [[src_dp, []]]]]]]]]]] - self._validate_graph(traverse_dps(dp), exp_g1) - - graph = remove_dp(graph, m2) - exp_g2 = [[dp, [[c1, [[dm, [[ub, [[src_dp, []]]]]]]], [c2, [[dm, [[ub, [[src_dp, []]]]]]]]]]] - self._validate_graph(traverse_dps(dp), exp_g2) - - with self.assertRaisesRegex(RuntimeError, "Cannot remove the source DataPipe"): - remove_dp(graph, src_dp) - - with self.assertRaisesRegex(RuntimeError, "Cannot remove the receiving DataPipe"): - remove_dp(graph, dp) - - def test_reading_service(self) -> None: - _, (*_, dp) = self._get_datapipes() # pyre-ignore - - rs = TempReadingService() - dl = DataLoader2(dp, reading_service=rs) - - self.assertTrue(len(rs.adaptors) == 0) - - it = iter(dl) - for new_dp in rs.adaptors: - self.assertTrue(new_dp.started) - - res = list(it) - self.assertEqual(len(res), 20) - - for new_dp in rs.adaptors: - self.assertFalse(new_dp.started) - - self.assertEqual(res, list(dl)) - - -def insert_round_robin_sharding(graph, datapipe): - dispatch_dp = ShardingRoundRobinDispatcher(datapipe, SHARDING_PRIORITIES.MULTIPROCESSING) - return replace_dp(graph, datapipe, dispatch_dp), dispatch_dp - - -def replace_by_dummy(graph, datapipe): - return replace_dp(graph, datapipe, _DummyIterDataPipe()) - - -def make_non_replicable_dp(datapipe): - datapipe.is_replicable = types.MethodType(lambda self: False, datapipe) - return datapipe - - -class TestNonReplicableDataPipe(expecttest.TestCase): - def _make_dp(self): - r""" - Create a DataPipe that contains the most of cases including: - - single-branch pipeline - - multi-branch pipeline - - pipeline that has circurlar references - - single_br_dp ------------------------------------- - ch1 \ - / \ \ - multi_br_dp -->forker_dp--> -> fork_zip_dp -> end_dp -> - \ / / - <------- ch2 / - / \ / - cir_br_dp -> cir_map_dp -------------------------- - """ - # Single-branch - single_br_dp = IterableWrapper(list(range(10))) - - # Multi-branch - multi_br_dp = IterableWrapper(list(range(10))) - ch1, ch2 = multi_br_dp.fork(2) - forker_dp = ch1.main_datapipe - fork_zip_dp = ch1.zip(ch2) - - # Circular-branch - cir_br_dp = IterableWrapper(list(range(10))) - cir_map_dp = cir_br_dp.map(_x_mult_2) - # Force to circular reference - cir_br_dp.cir_dep = cir_map_dp - - end_dp = single_br_dp.zip(fork_zip_dp, cir_map_dp) - graph = traverse_dps(end_dp) - return single_br_dp, multi_br_dp, forker_dp, ch1, ch2, fork_zip_dp, cir_br_dp, cir_map_dp, end_dp, graph - - def test_single_round_robin_sharding_dp(self): - single_br_dp, *_, graph = self._make_dp() - graph, single_br_dp = insert_round_robin_sharding(graph, single_br_dp) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), single_br_dp) - - # The same non-shardable DataPipe on both branches - _, multi_br_dp, *_, graph = self._make_dp() - graph, multi_br_dp = insert_round_robin_sharding(graph, multi_br_dp) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), multi_br_dp) - - _, _, _, ch1, _, fork_zip_dp, *_, graph = self._make_dp() - graph, ch1 = insert_round_robin_sharding(graph, ch1) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), fork_zip_dp) - - # Circular reference - *_, cir_br_dp, cir_map_dp, _, graph = self._make_dp() - graph, cir_br_dp = insert_round_robin_sharding(graph, cir_br_dp) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), cir_map_dp) - - *_, cir_map_dp, 
_, graph = self._make_dp() - graph, cir_map_dp = insert_round_robin_sharding(graph, cir_map_dp) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), cir_map_dp) - - def test_multi_round_robin_sharding_dps(self): - single_br_dp, multi_br_dp, *_, end_dp, graph = self._make_dp() - graph, single_br_dp = insert_round_robin_sharding(graph, single_br_dp) - graph, multi_br_dp = insert_round_robin_sharding(graph, multi_br_dp) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), end_dp) - - single_br_dp, _, _, ch1, *_, end_dp, graph = self._make_dp() - graph, single_br_dp = insert_round_robin_sharding(graph, single_br_dp) - graph, ch1 = insert_round_robin_sharding(graph, ch1) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), end_dp) - - _, multi_br_dp, _, ch1, _, fork_zip_dp, *_, graph = self._make_dp() - graph, multi_br_dp = insert_round_robin_sharding(graph, multi_br_dp) - graph, ch1 = insert_round_robin_sharding(graph, ch1) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), fork_zip_dp) - - single_br_dp, *_, cir_br_dp, _, end_dp, graph = self._make_dp() - graph, single_br_dp = insert_round_robin_sharding(graph, single_br_dp) - graph, cir_br_dp = insert_round_robin_sharding(graph, cir_br_dp) - self.assertEqual(find_lca_round_robin_sharding_dp(graph), end_dp) - - def test_non_dispatching_branches(self): - r""" - There should be a single DataPipe as the lowest common ancestor of all - non-dispatching DataPipes that is replaced by ``DummyIterDataPipe``. - """ - single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp() - graph = replace_by_dummy(graph, single_br_dp) - dps = find_non_dispatching_branches(graph) - self.assertEqual(len(dps), 2) - self.assertTrue(all(dp in (fork_zip_dp, cir_map_dp) for dp in dps)) - - single_br_dp, multi_br_dp, *_, cir_map_dp, _, graph = self._make_dp() - graph = replace_by_dummy(graph, multi_br_dp) - dps = find_non_dispatching_branches(graph) - self.assertEqual(len(dps), 2) - self.assertTrue(all(dp in (single_br_dp, cir_map_dp) for dp in dps)) - - # In theory, this case should never happen because LCA (fork_zip_dp) should be - # replaced by _DummpyIterDataPipe if any of child is non-replicable - single_br_dp, _, _, ch1, ch2, *_, cir_map_dp, _, graph = self._make_dp() - graph = replace_by_dummy(graph, ch1) - dps = find_non_dispatching_branches(graph) - self.assertEqual(len(dps), 3) - self.assertTrue(all(dp in (single_br_dp, ch2, cir_map_dp) for dp in dps)) - - single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp() - graph = replace_by_dummy(graph, cir_map_dp) - dps = find_non_dispatching_branches(graph) - self.assertTrue(all(dp in (single_br_dp, fork_zip_dp) for dp in dps)) - - *_, end_dp, graph = self._make_dp() - graph = replace_by_dummy(graph, end_dp) - dps = find_non_dispatching_branches(graph) - self.assertEqual(len(dps), 0) - - single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp() - graph = replace_by_dummy(graph, fork_zip_dp) - dps = find_non_dispatching_branches(graph) - self.assertEqual(len(dps), 2) - self.assertTrue(all(dp in (single_br_dp, cir_map_dp) for dp in dps)) - - def test_single_non_replicable_dp(self): - # All replicable - *_, end_dp, graph = self._make_dp() - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 1) - self.assertEqual(dps[0], end_dp) - - # Test the production use case where the last DataPipe is fullsync - *_, end_dp, _ = self._make_dp() - dp = end_dp.fullsync() - graph = traverse_dps(dp) - dps = _find_replicable_branches(graph) - 
self.assertEqual(len(dps), 1) - self.assertEqual(dps[0], end_dp) - - single_br_dp, *_, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(single_br_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 2) - self.assertTrue(all(dp in (fork_zip_dp, cir_map_dp) for dp in dps)) - - single_br_dp, *_, ch1, ch2, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(fork_zip_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 4) - self.assertTrue(all(dp in (single_br_dp, ch1, ch2, cir_map_dp) for dp in dps)) - - single_br_dp, _, forker_dp, ch1, *_, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(ch1) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 3) - self.assertTrue(all(dp in (single_br_dp, forker_dp, cir_map_dp) for dp in dps)) - - single_br_dp, *_, fork_zip_dp, cir_br_dp, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(cir_map_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 3) - self.assertTrue(all(dp in (single_br_dp, fork_zip_dp, cir_br_dp) for dp in dps)) - - single_br_dp, *_, fork_zip_dp, _, cir_map_dp, end_dp, graph = self._make_dp() - make_non_replicable_dp(end_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 3) - self.assertTrue(all(dp in (single_br_dp, fork_zip_dp, cir_map_dp) for dp in dps)) - - def test_multi_non_replicable_dps(self): - single_br_dp, multi_br_dp, *_, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(single_br_dp) - make_non_replicable_dp(multi_br_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 1) - self.assertEqual(dps[0], cir_map_dp) - - single_br_dp, _, forker_dp, ch1, *_, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(single_br_dp) - make_non_replicable_dp(ch1) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 2) - self.assertTrue(all(dp in (forker_dp, cir_map_dp) for dp in dps)) - - single_br_dp, *_, ch1, ch2, fork_zip_dp, _, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(single_br_dp) - make_non_replicable_dp(fork_zip_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 3) - self.assertTrue(all(dp in (ch1, ch2, cir_map_dp) for dp in dps)) - - single_br_dp, *_, fork_zip_dp, cir_br_dp, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(single_br_dp) - make_non_replicable_dp(cir_map_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 2) - self.assertTrue(all(dp in (fork_zip_dp, cir_br_dp) for dp in dps)) - - single_br_dp, multi_br_dp, forker_dp, ch1, *_, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(forker_dp) - make_non_replicable_dp(ch1) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 3) - self.assertTrue(all(dp in (single_br_dp, multi_br_dp, cir_map_dp) for dp in dps)) - - single_br_dp, multi_br_dp, forker_dp, *_, cir_br_dp, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(forker_dp) - make_non_replicable_dp(cir_map_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 3) - self.assertTrue(all(dp in (single_br_dp, multi_br_dp, cir_br_dp) for dp in dps)) - - single_br_dp, *_, ch1, ch2, fork_zip_dp, cir_br_dp, cir_map_dp, _, graph = self._make_dp() - make_non_replicable_dp(fork_zip_dp) - make_non_replicable_dp(cir_map_dp) - dps = _find_replicable_branches(graph) - self.assertEqual(len(dps), 4) - self.assertTrue(all(dp in (single_br_dp, ch1, ch2, 
cir_br_dp) for dp in dps)) - - -class TestGraphVisualization(expecttest.TestCase): - @unittest.skipIf(not HAS_GRAPHVIZ, "Package `graphviz` is required to test graph visualization functionalities.") - def test_to_graph(self): - dp1 = IterableWrapper(range(10)) - dp2 = dp1.map(lambda x: x + 1) - dp3 = dp2.filter(lambda x: x > 5) - cdp1, cdp2 = dp3.fork(num_instances=2) - dp4 = cdp1.zip(cdp2) - cdp3, cdp4 = dp4.demux(num_instances=2, classifier_fn=lambda x: x % 2) - dp5 = cdp3.concat(cdp4) - - # Test to ensure that we can create these graphs with runtime errors - kwargs_list: List[Dict] = [ - {"dp": dp1}, - {"dp": dp2}, - {"dp": dp3}, - {"dp": cdp1, "debug": True}, - {"dp": dp4}, - {"dp": dp4, "debug": True}, - {"dp": cdp3, "debug": True}, - {"dp": dp5}, - {"dp": dp5, "debug": True}, - ] - for kwargs in kwargs_list: - g = to_graph(**kwargs) - self.assertTrue(isinstance(g, graphviz.Digraph)) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_huggingface_datasets.py b/test/test_huggingface_datasets.py deleted file mode 100644 index bc325ce74..000000000 --- a/test/test_huggingface_datasets.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -from unittest.mock import patch - -import expecttest - -from torchdata.datapipes.iter import HuggingFaceHubReader - -try: - import datasets - - HAS_DATASETS = True - -except ImportError: - HAS_DATASETS = False -skipIfNoDatasets = unittest.skipIf(not HAS_DATASETS, "no datasets") - - -class TestHuggingFaceHubReader(expecttest.TestCase): - @skipIfNoDatasets - @patch("datasets.load_dataset") - def test_huggingface_hubreader(self, mock_load_dataset): - mock_load_dataset.return_value = datasets.Dataset.from_dict( - { - "id": ["7bd227d9-afc9-11e6-aba1-c4b301cdf627", "7bd22905-afc9-11e6-a5dc-c4b301cdf627"], - "package_name": ["com.mantz_it.rfanalyzer"] * 2, - } - ) - - datapipe = HuggingFaceHubReader("lhoestq/demo1", revision="branch", streaming=False, use_auth_token=True) - - iterator = iter(datapipe) - elem = next(iterator) - assert type(elem) is dict - assert elem["id"] == "7bd227d9-afc9-11e6-aba1-c4b301cdf627" - assert elem["package_name"] == "com.mantz_it.rfanalyzer" - mock_load_dataset.assert_called_with( - path="lhoestq/demo1", streaming=False, revision="branch", use_auth_token=True - ) - with self.assertRaises(StopIteration): - next(iterator) - next(iterator) - with self.assertRaises(TypeError): - len(datapipe) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py deleted file mode 100644 index d9c89fc23..000000000 --- a/test/test_iterdatapipe.py +++ /dev/null @@ -1,2054 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import asyncio -import io -import itertools -import pickle -import unittest -import warnings - -from collections import defaultdict -from functools import partial -from typing import Dict, NamedTuple - -import expecttest -import torch - -import torchdata - -from _utils._common_utils_for_test import IDP_NoLen, reset_after_n_next_calls -from torch.testing._internal.common_utils import suppress_warnings - -from torch.utils.data.datapipes.utils.snapshot import _simple_graph_snapshot_restoration -from torchdata.datapipes.iter import ( - BucketBatcher, - Cycler, - Header, - IndexAdder, - InMemoryCacheHolder, - IterableWrapper, - IterDataPipe, - IterKeyZipper, - LineReader, - MapKeyZipper, - MaxTokenBucketizer, - ParagraphAggregator, - Repeater, - Rows2Columnar, - SampleMultiplexer, - ShardExpander, - UnZipper, -) -from torchdata.datapipes.map import MapDataPipe, SequenceWrapper - -skipIfNoCUDA = unittest.skipIf(not torch.cuda.is_available(), "CUDA is not available") - - -def test_torchdata_pytorch_consistency() -> None: - def extract_datapipe_names(module): - return { - name - for name, dp_type in module.__dict__.items() - if not name.startswith("_") and isinstance(dp_type, type) and issubclass(dp_type, IterDataPipe) - } - - pytorch_datapipes = extract_datapipe_names(torch.utils.data.datapipes.iter) - torchdata_datapipes = extract_datapipe_names(torchdata.datapipes.iter) - - missing_datapipes = pytorch_datapipes - torchdata_datapipes - deprecated_datapipes = {"FileLoader"} - for dp in deprecated_datapipes: - if dp in missing_datapipes: - missing_datapipes.remove("FileLoader") - - if any(missing_datapipes): - msg = ( - "The following datapipes are exposed under `torch.utils.data.datapipes.iter`, " - "but not under `torchdata.datapipes.iter`:\n" - ) - raise AssertionError(msg + "\n".join(sorted(missing_datapipes))) - - -def _convert_to_tensor(data): - if isinstance(data, dict): - return {k: _convert_to_tensor(v) for k, v in data.items()} - elif isinstance(data, list): - return [_convert_to_tensor(v) for v in data] - return torch.tensor(data) - - -async def _async_mul_ten(x): - await asyncio.sleep(0.1) - return x * 10 - - -async def _async_x_mul_y(x, y): - await asyncio.sleep(0.1) - return x * y - - -class NamedTensors(NamedTuple): - x: torch.Tensor - y: torch.Tensor - - -class TestIterDataPipe(expecttest.TestCase): - def test_in_memory_cache_holder_iterdatapipe(self) -> None: - source_dp = IterableWrapper(range(10)) - cache_dp = source_dp.in_memory_cache(size=5) - - # Functional Test: Cache DP should just return the data without changing the values - res1 = list(cache_dp) - self.assertEqual(list(range(10)), res1) - - # Functional Test: Ensure the objects are the same ones from source DataPipe - res1 = list(cache_dp) - res2 = list(cache_dp) - self.assertTrue(id(source) == id(cache) for source, cache in zip(source_dp, res1)) - self.assertTrue(id(source) == id(cache) for source, cache in zip(source_dp, res2)) - - # TODO(122): Figure out a way to consistently test caching when size is in megabytes - - # Reset Test: reset the DataPipe after reading part of it - cache_dp = InMemoryCacheHolder(source_dp, size=5) - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(cache_dp, n_elements_before_reset) - self.assertEqual(list(range(5)), res_before_reset) - self.assertEqual(list(range(10)), res_after_reset) - - # __len__ Test: inherits length from source_dp - self.assertEqual(10, len(cache_dp)) - - # __len__ Test: source_dp has no len and cache is not yet loaded - 
source_dp_no_len = IDP_NoLen(range(10)) - cache_dp = InMemoryCacheHolder(source_dp_no_len, size=5) - with self.assertRaisesRegex(TypeError, "doesn't have valid length until the cache is loaded"): - len(cache_dp) - - # __len__ Test: source_dp has no len but we still can calculate after cache is loaded - list(cache_dp) - self.assertEqual(10, len(cache_dp)) - - def test_iter_key_zipper_iterdatapipe(self) -> None: - - source_dp = IterableWrapper(range(10)) - ref_dp = IterableWrapper(range(20)) - ref_dp2 = IterableWrapper(range(20)) - - # Functional Test: Output should be a zip list of tuple - zip_dp = source_dp.zip_with_iter( - ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=False, buffer_size=100 - ) - self.assertEqual([(i, i) for i in range(10)], list(zip_dp)) - - # Functional Test: keep_key=True, and key should show up as the first element - zip_dp_w_key = source_dp.zip_with_iter( - ref_datapipe=ref_dp2, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=True, buffer_size=10 - ) - self.assertEqual([(i, (i, i)) for i in range(10)], list(zip_dp_w_key)) - - # Functional Test: using a different merge function - def merge_to_string(item1, item2): - return f"{item1},{item2}" - - zip_dp_w_str_merge = source_dp.zip_with_iter( - ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, buffer_size=10, merge_fn=merge_to_string - ) - self.assertEqual([f"{i},{i}" for i in range(10)], list(zip_dp_w_str_merge)) - - # Functional Test: using a different merge function and keep_key=True - zip_dp_w_key_str_merge = source_dp.zip_with_iter( - ref_datapipe=ref_dp, - key_fn=lambda x: x, - ref_key_fn=lambda x: x, - keep_key=True, - buffer_size=10, - merge_fn=merge_to_string, - ) - self.assertEqual([(i, f"{i},{i}") for i in range(10)], list(zip_dp_w_key_str_merge)) - - # Functional Test: testing nested zipping - zip_dp = source_dp.zip_with_iter( - ref_datapipe=ref_dp, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=False, buffer_size=100 - ) - - # Without a custom merge function, there will be nested tuples - zip_dp2 = zip_dp.zip_with_iter( - ref_datapipe=ref_dp2, key_fn=lambda x: x[0], ref_key_fn=lambda x: x, keep_key=False, buffer_size=100 - ) - self.assertEqual([((i, i), i) for i in range(10)], list(zip_dp2)) - - # With a custom merge function, nesting can be prevented - zip_dp2_w_merge = zip_dp.zip_with_iter( - ref_datapipe=ref_dp2, - key_fn=lambda x: x[0], - ref_key_fn=lambda x: x, - keep_key=False, - buffer_size=100, - merge_fn=lambda x, y: list(x) + [y], - ) - self.assertEqual([[i, i, i] for i in range(10)], list(zip_dp2_w_merge)) - - # Functional Test: element is in source but missing in reference - ref_dp_missing = IterableWrapper(range(1, 10)) - zip_dp = source_dp.zip_with_iter( - ref_datapipe=ref_dp_missing, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=False, buffer_size=100 - ) - with self.assertRaisesRegex(BufferError, r"No matching key can be found"): - list(zip_dp) - - # Functional Test: Buffer is not large enough, hence, element can't be found and raises error - ref_dp_end = IterableWrapper(list(range(1, 10)) + [0]) - zip_dp = source_dp.zip_with_iter( - ref_datapipe=ref_dp_end, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=False, buffer_size=5 - ) - it = iter(zip_dp) - with warnings.catch_warnings(record=True) as wa: - # In order to find '0' at the end, the buffer is filled, hence the warning - # and ref_dp is fully traversed - self.assertEqual( - ( - 0, - 0, - ), - next(it), - ) - self.assertEqual(len(wa), 1) - 
self.assertRegex(str(wa[0].message), r"Buffer reaches the upper limit") - with self.assertRaisesRegex(BufferError, r"No matching key can be found"): - # '1' cannot be find because the value was thrown out when buffer was filled - next(it) - - # Functional Test: Buffer is just big enough - zip_dp = source_dp.zip_with_iter( - ref_datapipe=ref_dp_end, key_fn=lambda x: x, ref_key_fn=lambda x: x, keep_key=False, buffer_size=10 - ) - self.assertEqual([(i, i) for i in range(10)], list(zip_dp)) - - # Reset Test: reset the DataPipe after reading part of it - zip_dp = IterKeyZipper( - source_datapipe=source_dp, - ref_datapipe=ref_dp, - key_fn=lambda x: x, - ref_key_fn=lambda x: x, - keep_key=False, - buffer_size=10, - ) - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(zip_dp, n_elements_before_reset) - self.assertEqual([(i, i) for i in range(5)], res_before_reset) - self.assertEqual([(i, i) for i in range(10)], res_after_reset) - - # __len__ Test: inherits length from source_dp - self.assertEqual(10, len(zip_dp)) - - def test_map_key_zipper_datapipe(self) -> None: - source_dp = IterableWrapper(range(10)) - map_dp = SequenceWrapper(["even", "odd"]) - - # Functional Test: ensure the hash join is working and return tuple by default - def odd_even(i: int) -> int: - return i % 2 - - result_dp = source_dp.zip_with_map(map_dp, odd_even) - - def odd_even_string(i: int) -> str: - return "odd" if i % 2 else "even" - - expected_res = [(i, odd_even_string(i)) for i in range(10)] - self.assertEqual(expected_res, list(result_dp)) - - # Functional Test: ensure that a custom merge function works - def custom_merge(a, b): - return f"{a} is a {b} number." - - result_dp = source_dp.zip_with_map(map_dp, odd_even, custom_merge) - expected_res2 = [f"{i} is a {odd_even_string(i)} number." 
for i in range(10)] - self.assertEqual(expected_res2, list(result_dp)) - - # Functional Test: raises error when key is invalid - def odd_even_bug(i: int) -> int: - return 2 if i == 0 else i % 2 - - result_dp = MapKeyZipper(source_dp, map_dp, odd_even_bug) - it = iter(result_dp) - with self.assertRaisesRegex(KeyError, "is not a valid key in the given MapDataPipe"): - next(it) - - # Functional test: ensure that keep_key option works - result_dp = source_dp.zip_with_map(map_dp, odd_even, keep_key=True) - expected_res_keep_key = [(key, (i, odd_even_string(i))) for i, key in zip(range(10), [0, 1] * 5)] - self.assertEqual(expected_res_keep_key, list(result_dp)) - - # Reset Test: - n_elements_before_reset = 4 - result_dp = source_dp.zip_with_map(map_dp, odd_even) - res_before_reset, res_after_reset = reset_after_n_next_calls(result_dp, n_elements_before_reset) - self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_res, res_after_reset) - - # __len__ Test: returns the length of source DataPipe - result_dp = source_dp.zip_with_map(map_dp, odd_even) - self.assertEqual(len(source_dp), len(result_dp)) - - def test_prefetcher_iterdatapipe(self) -> None: - source_dp = IterableWrapper(range(5000)) - prefetched_dp = source_dp.prefetch(10) - # check if early termination resets child thread properly - for _, _ in zip(range(100), prefetched_dp): - pass - expected = list(source_dp) - actual = list(prefetched_dp) - self.assertEqual(expected, actual) - - # __len__ Test: returns the same length as source - self.assertEqual(len(source_dp), len(prefetched_dp)) - - def test_repeater_iterdatapipe(self) -> None: - import itertools - - source_dp = IterableWrapper(range(5)) - - # Functional Test: repeat for correct number of times - repeater_dp = source_dp.repeat(3) - self.assertEqual( - list(itertools.chain.from_iterable(itertools.repeat(x, 3) for x in range(5))), list(repeater_dp) - ) - - # Functional Test: `times` must be > 1 - with self.assertRaisesRegex(ValueError, "The number of repetition must be > 1"): - source_dp.repeat(1) - - # Reset Test: - repeater_dp = Repeater(source_dp, times=2) - n_elements_before_reset = 4 - res_before_reset, res_after_reset = reset_after_n_next_calls(repeater_dp, n_elements_before_reset) - self.assertEqual([0, 0, 1, 1], res_before_reset) - self.assertEqual(list(itertools.chain.from_iterable(itertools.repeat(x, 2) for x in range(5))), res_after_reset) - - # __len__ Test: returns correct length - self.assertEqual(10, len(repeater_dp)) - - def test_cycler_iterdatapipe(self) -> None: - source_dp = IterableWrapper(range(5)) - - # Functional Test: cycle for finite number of times and ends - cycler_dp = source_dp.cycle(3) - self.assertEqual(list(range(5)) * 3, list(cycler_dp)) - - # Functional Test: cycle for indefinitely - cycler_dp = source_dp.cycle() - it = iter(cycler_dp) - for expected_val in list(range(5)) * 10: - self.assertEqual(expected_val, next(it)) - - # Functional Test: zero is allowed but immediately triggers StopIteration - cycler_dp = source_dp.cycle(0) - self.assertEqual([], list(cycler_dp)) - - # Functional Test: negative value is not allowed - with self.assertRaisesRegex(ValueError, "Expected non-negative count"): - source_dp.cycle(-1) - - # Reset Test: - cycler_dp = Cycler(source_dp, count=2) - n_elements_before_reset = 4 - res_before_reset, res_after_reset = reset_after_n_next_calls(cycler_dp, n_elements_before_reset) - self.assertEqual(list(range(4)), res_before_reset) - self.assertEqual(list(range(5)) * 2, 
res_after_reset) - - # __len__ Test: returns length when count is not None - self.assertEqual(10, len(cycler_dp)) - - # __len__ Test: inherits length from source_dp - cycler_dp = Cycler(source_dp) - with self.assertRaisesRegex(TypeError, "instance cycles forever, and therefore doesn't have valid length"): - len(cycler_dp) - - def test_header_iterdatapipe(self) -> None: - # Functional Test: ensure the limit is enforced - source_dp = IterableWrapper(range(20)) - header_dp = source_dp.header(5) - self.assertEqual(list(range(5)), list(header_dp)) - - # Functional Test: ensure it works when the source has less elements than the limit - source_dp = IterableWrapper(range(5)) - header_dp = source_dp.header(100) - self.assertEqual(list(range(5)), list(header_dp)) - - # Functional Test: ensure the source is not modified if limit is set to None - source_dp = IterableWrapper(range(5)) - header_dp = source_dp.header(None) - self.assertEqual(list(range(5)), list(header_dp)) - - # Reset Test: - source_dp = IterableWrapper(range(20)) - header_dp = Header(source_dp, 5) - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(header_dp, n_elements_before_reset) - self.assertEqual(list(range(2)), res_before_reset) - self.assertEqual(list(range(5)), res_after_reset) - self.assertEqual(list(range(5)), list(header_dp)) - - # __len__ Test: returns the limit when it is less than the length of source - self.assertEqual(5, len(header_dp)) - - # __len__ Test: returns the length of source when it is less than the limit - header_dp = source_dp.header(30) - self.assertEqual(20, len(header_dp)) - - # __len__ Test: returns the length of source when limit is set to None - header_dp = source_dp.header(None) - self.assertEqual(20, len(header_dp)) - - # __len__ Test: returns limit if source doesn't have length - source_dp_NoLen = IDP_NoLen(list(range(20))) - header_dp = source_dp_NoLen.header(30) - with warnings.catch_warnings(record=True) as wa: - self.assertEqual(30, len(header_dp)) - self.assertEqual(len(wa), 1) - self.assertRegex( - str(wa[0].message), r"length of this HeaderIterDataPipe is inferred to be equal to its limit" - ) - - # __len__ Test: raises TypeError if source doesn't have length and limit is set to None - header_dp = source_dp_NoLen.header(None) - with self.assertRaisesRegex(TypeError, "The length of this HeaderIterDataPipe cannot be determined."): - len(header_dp) - - # __len__ Test: returns limit if source doesn't have length, even when it has been iterated through once - header_dp = source_dp_NoLen.header(30) - for _ in header_dp: - pass - self.assertEqual(30, len(header_dp)) - - def test_enumerator_iterdatapipe(self) -> None: - letters = "abcde" - source_dp = IterableWrapper(letters) - enum_dp = source_dp.enumerate() - - # Functional Test: ensure that the correct index value is added to each element (tuple) - self.assertEqual([(0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e")], list(enum_dp)) - - # Functional Test: start index from non-zero - enum_dp = source_dp.enumerate(starting_index=10) - self.assertEqual([(10, "a"), (11, "b"), (12, "c"), (13, "d"), (14, "e")], list(enum_dp)) - - # Reset Test: - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(enum_dp, n_elements_before_reset) - self.assertEqual([(10, "a"), (11, "b")], res_before_reset) - self.assertEqual([(10, "a"), (11, "b"), (12, "c"), (13, "d"), (14, "e")], res_after_reset) - - # __len__ Test: returns length of source DataPipe - self.assertEqual(5, len(enum_dp)) - - 
def test_index_adder_iterdatapipe(self) -> None: - letters = "abcdefg" - source_dp = IterableWrapper([{i: i} for i in letters]) - index_adder_dp = source_dp.add_index() - it = iter(index_adder_dp) - - def dict_content_test_helper(iterator): - for i, curr_dict in enumerate(iterator): - self.assertEqual(i, curr_dict["index"]) - self.assertTrue(letters[i] in curr_dict) - - # Functional Test: ensure that the correct index value is added to each element (dict) - dict_content_test_helper(it) - - # Functional Test: raises error when the elements of source_dp is not of type Dict - source_dp = IterableWrapper(range(10)) - index_adder_dp = source_dp.add_index() - it = iter(index_adder_dp) - with self.assertRaisesRegex(NotImplementedError, "We only support adding index to row or batch in dict type"): - next(it) - - # Reset Test - source_dp = IterableWrapper([{i: i} for i in "abcdefg"]) - index_adder_dp = IndexAdder(source_dp) - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(index_adder_dp, n_elements_before_reset) - dict_content_test_helper(iter(res_before_reset)) - dict_content_test_helper(iter(res_after_reset)) - - # __len__ Test: returns length of source DataPipe - self.assertEqual(7, len(index_adder_dp)) - - def test_line_reader_iterdatapipe(self) -> None: - text1 = "Line1\nLine2" - text2 = "Line2,1\r\nLine2,2\r\nLine2,3" - - # Functional Test: read lines correctly - source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))]) - line_reader_dp = source_dp.readlines() - expected_result = [("file1", line) for line in text1.splitlines()] + [ - ("file2", line) for line in text2.splitlines() - ] - self.assertEqual(expected_result, list(line_reader_dp)) - - # Functional Test: strip new lines for bytes - source_dp = IterableWrapper( - [("file1", io.BytesIO(text1.encode("utf-8"))), ("file2", io.BytesIO(text2.encode("utf-8")))] - ) - line_reader_dp = source_dp.readlines() - expected_result_bytes = [("file1", line.encode("utf-8")) for line in text1.splitlines()] + [ - ("file2", line.encode("utf-8")) for line in text2.splitlines() - ] - self.assertEqual(expected_result_bytes, list(line_reader_dp)) - - # Functional Test: do not strip new lines - source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))]) - line_reader_dp = source_dp.readlines(strip_newline=False) - expected_result = [ - ("file1", "Line1\n"), - ("file1", "Line2"), - ("file2", "Line2,1\r\n"), - ("file2", "Line2,2\r\n"), - ("file2", "Line2,3"), - ] - self.assertEqual(expected_result, list(line_reader_dp)) - - # Reset Test: - source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))]) - line_reader_dp = LineReader(source_dp, strip_newline=False) - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(line_reader_dp, n_elements_before_reset) - self.assertEqual(expected_result[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_result, res_after_reset) - - # __len__ Test: length isn't implemented since it cannot be known ahead of time - with self.assertRaisesRegex(TypeError, "has no len"): - len(line_reader_dp) - - def test_paragraph_aggregator_iterdatapipe(self) -> None: - # Functional Test: aggregate lines correctly - source_dp = IterableWrapper( - [("file1", "Line1"), ("file1", "Line2"), ("file2", "Line2,1"), ("file2", "Line2,2"), ("file2", "Line2,3")] - ) - para_agg_dp = source_dp.lines_to_paragraphs() - self.assertEqual([("file1", "Line1\nLine2"), 
("file2", "Line2,1\nLine2,2\nLine2,3")], list(para_agg_dp)) - - # Functional Test: aggregate lines correctly with different joiner - para_agg_dp = source_dp.lines_to_paragraphs(joiner=lambda ls: " ".join(ls)) - self.assertEqual([("file1", "Line1 Line2"), ("file2", "Line2,1 Line2,2 Line2,3")], list(para_agg_dp)) - - # Reset Test: each yield is for a single file - para_agg_dp = ParagraphAggregator(source_dp) - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(para_agg_dp, n_elements_before_reset) - self.assertEqual([("file1", "Line1\nLine2")], res_before_reset) - self.assertEqual([("file1", "Line1\nLine2"), ("file2", "Line2,1\nLine2,2\nLine2,3")], res_after_reset) - - # __len__ Test: length isn't implemented since it cannot be known ahead of time - with self.assertRaisesRegex(TypeError, "has no len"): - len(para_agg_dp) - - def test_rows_to_columnar_iterdatapipe(self) -> None: - # Functional Test: working with DataPipe with dict - column_names_dict = {"a", "b", "c"} - source_dp = IterableWrapper( - [ - [{l: i for i, l in enumerate("abc")}, {l: i * 10 for i, l in enumerate("abc")}], - [{l: i + 100 for i, l in enumerate("abc")}, {l: (i + 100) * 10 for i, l in enumerate("abc")}], - ] - ) - result_dp = source_dp.rows2columnar(column_names_dict) - batch1 = defaultdict(list, {"a": [0, 0], "b": [1, 10], "c": [2, 20]}) - batch2 = defaultdict(list, {"a": [100, 1000], "b": [101, 1010], "c": [102, 1020]}) - expected_output = [batch1, batch2] - self.assertEqual(expected_output, list(result_dp)) - - # Functional Test: working with DataPipe with list - column_names_list = ["a", "b", "c"] - source_dp = IterableWrapper( - [ - [[i for i, _ in enumerate("abc")], [i * 10 for i, _ in enumerate("abc")]], - [[i + 100 for i, _ in enumerate("abc")], [(i + 100) * 10 for i, _ in enumerate("abc")]], - ] - ) - result_dp = source_dp.rows2columnar(column_names_list) - self.assertEqual(expected_output, list(result_dp)) - - # Reset Test: - result_dp = Rows2Columnar(source_dp, column_names_list) - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(result_dp, n_elements_before_reset) - self.assertEqual([expected_output[0]], res_before_reset) - self.assertEqual(expected_output, res_after_reset) - - # __len__ Test: returns length of source DataPipe - self.assertEqual(2, len(result_dp)) - - def test_sample_multiplexer_iterdatapipe(self) -> None: - # Functional Test: yields all values from the sources - source_dp1 = IterableWrapper([0] * 10) - source_dp2 = IterableWrapper([1] * 10) - d: Dict[IterDataPipe, float] = {source_dp1: 99999999, source_dp2: 0.0000001} - sample_mul_dp = SampleMultiplexer(pipes_to_weights_dict=d, seed=0) - result = list(sample_mul_dp) - self.assertEqual([0] * 10 + [1] * 10, result) - - # Functional Test: raises error for empty dict - with self.assertRaisesRegex(ValueError, "Empty dictionary"): - SampleMultiplexer(pipes_to_weights_dict={}, seed=0) # type: ignore[arg-type] - - # Functional Test: raises error for negative or zero weight - d = {source_dp1: 99999999, source_dp2: 0} - with self.assertRaisesRegex(ValueError, "Expecting a positive and non-zero weight"): - SampleMultiplexer(pipes_to_weights_dict=d, seed=0) - - # Reset Test - d = {source_dp1: 99999999, source_dp2: 0.0000001} - sample_mul_dp = SampleMultiplexer(pipes_to_weights_dict=d, seed=0) - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(sample_mul_dp, n_elements_before_reset) - self.assertEqual([0] * 
n_elements_before_reset, res_before_reset) - self.assertEqual([0] * 10 + [1] * 10, res_after_reset) - - # __len__ Test: returns the sum of the lengths of the sources - self.assertEqual(20, len(sample_mul_dp)) - - def test_in_batch_shuffler_iterdatapipe(self): - input_dp = IterableWrapper(list(range(23))).batch(3) - expected = list(input_dp) - - # Functional Test: No seed - shuffler_dp = input_dp.in_batch_shuffle() - for exp, res in zip(expected, shuffler_dp): - self.assertEqual(sorted(res), exp) - - # Functional Test: With global seed - torch.manual_seed(123) - res = list(shuffler_dp) - torch.manual_seed(123) - self.assertEqual(list(shuffler_dp), res) - - # Functional Test: Set seed - shuffler_dp = input_dp.in_batch_shuffle().set_seed(123) - res = list(shuffler_dp) - shuffler_dp.set_seed(123) - self.assertEqual(list(shuffler_dp), res) - - # Functional Test: deactivate shuffling via set_shuffle - unshuffled_dp = shuffler_dp.set_shuffle(False) - self.assertEqual(list(unshuffled_dp), expected) - - # Reset Test: - shuffler_dp = input_dp.in_batch_shuffle() - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(shuffler_dp, n_elements_before_reset) - self.assertEqual(5, len(res_before_reset)) - for exp, res in zip(expected, res_before_reset): - self.assertEqual(sorted(res), exp) - for exp, res in zip(expected, res_after_reset): - self.assertEqual(sorted(res), exp) - - # __len__ Test: returns the length of the input DataPipe - shuffler_dp = input_dp.in_batch_shuffle() - self.assertEqual(8, len(shuffler_dp)) - - # Serialization Test - from torch.utils.data.datapipes._hook_iterator import _SnapshotState - - shuffler_dp = input_dp.in_batch_shuffle() - it = iter(shuffler_dp) - for _ in range(2): - next(it) - shuffler_dp_copy = pickle.loads(pickle.dumps(shuffler_dp)) - _simple_graph_snapshot_restoration(shuffler_dp_copy.datapipe, shuffler_dp.datapipe._number_of_samples_yielded) - - exp = list(it) - shuffler_dp_copy._snapshot_state = _SnapshotState.Restored - self.assertEqual(exp, list(shuffler_dp_copy)) - - def test_bucket_batcher_iterdatapipe(self) -> None: - source_dp = IterableWrapper(range(10)) - - # Functional Test: drop last reduces length - batch_dp = source_dp.bucketbatch( - batch_size=3, drop_last=True, batch_num=100, bucket_num=1, use_in_batch_shuffle=True - ) - self.assertEqual(9, len(list(batch_dp.unbatch()))) - - # Functional Test: drop last is False preserves length - batch_dp = source_dp.bucketbatch( - batch_size=3, drop_last=False, batch_num=100, bucket_num=1, use_in_batch_shuffle=False - ) - self.assertEqual(10, len(list(batch_dp.unbatch()))) - - def _return_self(x): - return x - - # Functional Test: using sort_key, with in_batch_shuffle - batch_dp = source_dp.bucketbatch( - batch_size=3, drop_last=True, batch_num=100, bucket_num=1, use_in_batch_shuffle=True, sort_key=_return_self - ) - # bucket_num = 1 means there will be no shuffling if a sort key is given - self.assertEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], list(batch_dp)) - self.assertEqual(9, len(list(batch_dp.unbatch()))) - - # Functional Test: using sort_key, without use_in_batch_shuffle - batch_dp = source_dp.bucketbatch( - batch_size=3, drop_last=True, batch_num=100, bucket_num=2, use_in_batch_shuffle=False, sort_key=_return_self - ) - self.assertEqual(9, len(list(batch_dp.unbatch()))) - - # Reset Test: - batch_dp = BucketBatcher( - source_dp, - batch_size=3, - drop_last=True, - batch_num=100, - bucket_num=2, - use_in_batch_shuffle=False, - sort_key=_return_self, - ) - 
n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(batch_dp, n_elements_before_reset) - self.assertEqual(n_elements_before_reset, len(res_before_reset)) - self.assertEqual(6, len([item for batch in res_before_reset for item in batch])) - self.assertEqual(3, len(res_after_reset)) - self.assertEqual(9, len([item for batch in res_after_reset for item in batch])) - - # __len__ Test: returns the number of batches - with self.assertRaises(TypeError): - len(batch_dp) - - def test_max_token_bucketizer_iterdatapipe(self) -> None: - source_data = ["1" * d for d in range(1, 6)] + ["2" * d for d in range(1, 6)] - source_dp = IterableWrapper(source_data) - - # Functional Test: Invalid arguments - with self.assertRaisesRegex(ValueError, "``min_len`` should be larger than 0"): - source_dp.max_token_bucketize(max_token_count=2, min_len=-1) - - with self.assertRaisesRegex(ValueError, "``min_len`` should be larger than 0"): - source_dp.max_token_bucketize(max_token_count=2, min_len=3, max_len=2) - - with self.assertRaises(ValueError, msg="``max_token_count`` must be equal to or greater than ``max_len``."): - source_dp.max_token_bucketize(max_token_count=2, max_len=3) - - def _validate_batch_size(res, exp_batch_len, len_fn=lambda d: len(d)): - self.assertEqual(len(res), len(exp_batch_len)) - - for batch, exp_token_lens in zip(res, exp_batch_len): - self.assertEqual(len(batch), len(exp_token_lens)) - for token, exp_token_len in zip(batch, exp_token_lens): - self.assertEqual(len_fn(token), exp_token_len) - - # Functional Test: Filter out min_len - batch_dp = source_dp.max_token_bucketize(max_token_count=5, min_len=2, buffer_size=10) - exp_batch_len = [(2, 2), (3,), (3,), (4,), (4,), (5,), (5,)] - _validate_batch_size(list(batch_dp), exp_batch_len) - - # Functional Test: Filter out max_len - batch_dp = source_dp.max_token_bucketize(max_token_count=5, max_len=4, buffer_size=10) - exp_batch_len = [(1, 1, 2), (2, 3), (3,), (4,), (4,)] - _validate_batch_size(list(batch_dp), exp_batch_len) - - def _custom_len_fn(token): - return len(token) + 1 - - # Functional Test: Custom length function - batch_dp = source_dp.max_token_bucketize(max_token_count=7, len_fn=_custom_len_fn, buffer_size=10) - exp_batch_len = [(1, 1, 2), (2, 3), (3,), (4,), (4,), (5,), (5,)] - _validate_batch_size(list(batch_dp), exp_batch_len) - - # Functional Test: Small buffer - batch_dp = source_dp.max_token_bucketize(max_token_count=10, buffer_size=4) - exp_batch_len = [(1, 2, 1, 2, 3), (3, 4), (4, 5), (5,)] - _validate_batch_size(list(batch_dp), exp_batch_len) - - # Reset Test: - batch_dp = MaxTokenBucketizer(source_dp, max_token_count=5, buffer_size=10) - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(batch_dp, n_elements_before_reset) - exp_batch_len_before_reset = [(1, 1, 2), (2, 3)] - exp_batch_len_after_reset = [(1, 1, 2), (2, 3), (3,), (4,), (4,), (5,), (5,)] - _validate_batch_size(res_before_reset, exp_batch_len_before_reset) - _validate_batch_size(res_after_reset, exp_batch_len_after_reset) - - # Functional test: Padded tokens exceeding max_token_count - source_data = ["111", "1111", "11111"] # 3, 4, 5 - source_dp = IterableWrapper(source_data) - batch_dp = source_dp.max_token_bucketize(max_token_count=7) - exp_batch_len = [(3, 4), (5,)] - _validate_batch_size(list(batch_dp), exp_batch_len) - - # Functional test: Padded tokens not exceeding max_token_count - source_data = ["111", "111", "111", "1111"] # 3, 3, 3, 4 - source_dp = IterableWrapper(source_data) 
- batch_dp = source_dp.max_token_bucketize(max_token_count=7, include_padding=True) - exp_batch_len = [(3, 3), (3,), (4,)] - _validate_batch_size(list(batch_dp), exp_batch_len) - - # Functional test: sample length exceeding max_token_count - source_data = ["111"] - source_dp = IterableWrapper(source_data) - batch_dp = source_dp.max_token_bucketize(max_token_count=2) - exp_batch = [] - self.assertEqual(list(batch_dp), exp_batch) - - # Functional test: incomparable data for heapq - def _custom_len_fn(data): - return data["len"] - - source_data = [{"len": 1}, {"len": 2}, {"len": 1}, {"len": 3}, {"len": 1}] - source_dp = IterableWrapper(source_data) - batch_dp = source_dp.max_token_bucketize(max_token_count=3, len_fn=_custom_len_fn) - exp_batch_len = [(1, 1, 1), (2,), (3,)] - _validate_batch_size(list(batch_dp), exp_batch_len, len_fn=_custom_len_fn) - - # __len__ Test: returns the number of batches - with self.assertRaises(TypeError): - len(batch_dp) - - def test_map_batches_iterdatapipe(self): - source_dp = IterableWrapper(list(range(20))) - - def fn(batch): - return [d + 1 for d in batch] - - batch_mapped_dp = source_dp.map_batches(fn, batch_size=9) - expected_list = list(range(1, 21)) - self.assertEqual(expected_list, list(batch_mapped_dp)) - - # Reset Test: reset the DataPipe after reading part of it - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(batch_mapped_dp, n_elements_before_reset) - - self.assertEqual(expected_list[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_list, res_after_reset) - - # Functional Test: Different sizes between input and output - def fn_less(batch): - return [batch[idx] // 2 for idx in range(0, len(batch), 2)] - - less_batch_mapped_dp = source_dp.map_batches(fn_less, batch_size=8) - self.assertEqual(list(range(10)), list(less_batch_mapped_dp)) - - # Functional Test: Specify input_col - source_dp = IterableWrapper([(d - 1, d, d + 1) for d in range(20)]) - - batch_mapped_input_1_dp = source_dp.map_batches(fn, batch_size=9, input_col=0) - self.assertEqual(list(range(20)), list(batch_mapped_input_1_dp)) - - def fn_2_cols(batch): - return [(d1, d2 - 1) for d1, d2 in batch] - - batch_mapped_input_2_dp = source_dp.map_batches(fn_2_cols, batch_size=9, input_col=[1, 2]) - self.assertEqual([(d, d) for d in range(20)], list(batch_mapped_input_2_dp)) - - # __len__ Test: length should be determined by ``fn`` which we can't know - with self.assertRaisesRegex(TypeError, "length relies on the output of its function."): - len(batch_mapped_dp) - - def test_flatmap_iterdatapipe(self): - source_dp = IterableWrapper(list(range(20))) - - def fn(e): - return [e, e * 10] - - flatmapped_dp = source_dp.flatmap(fn) - expected_list = list(itertools.chain(*[(e, e * 10) for e in source_dp])) - - self.assertEqual(expected_list, list(flatmapped_dp)) - - # Functional Test: Specify input_col - tuple_source_dp = IterableWrapper([(d - 1, d, d + 1) for d in range(20)]) - - # Single input_col - input_col_1_dp = tuple_source_dp.flatmap(fn, input_col=1) - self.assertEqual(expected_list, list(input_col_1_dp)) - - # Multiple input_col - def mul_fn(a, b): - return [a - b, b - a] - - input_col_2_dp = tuple_source_dp.flatmap(mul_fn, input_col=(0, 2)) - self.assertEqual(list(itertools.chain(*[(-2, 2) for _ in range(20)])), list(input_col_2_dp)) - - # flatmap with no fn specified - default_dp = tuple_source_dp.flatmap() - self.assertEqual(list(itertools.chain(*[(n - 1, n, n + 1) for n in range(20)])), list(default_dp)) - - # flatmap with no 
fn specified, multiple input_col - default_dp = tuple_source_dp.flatmap(input_col=(0, 2)) - self.assertEqual(list(itertools.chain(*[(n - 1, n + 1) for n in range(20)])), list(default_dp)) - - # flatmap with no fn specified, some special input - tuple_source_dp = IterableWrapper([[1, 2, [3, 4]], [5, 6, [7, 8]]]) - default_dp = tuple_source_dp.flatmap(input_col=(0, 2)) - self.assertEqual([1, [3, 4], 5, [7, 8]], list(default_dp)) - - # Reset Test: reset the DataPipe after reading part of it - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(flatmapped_dp, n_elements_before_reset) - - self.assertEqual(expected_list[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_list, res_after_reset) - - # __len__ Test: length should be len(source_dp)*len(fn->out_shape) which we can't know - with self.assertRaisesRegex(TypeError, "length relies on the output of its function."): - len(flatmapped_dp) - - def test_shuffled_flatmap_iterdatapipe(self): - source_dp = IterableWrapper(list(range(20))) - - def fn(e): - return [e, e * 10] - - # Tests with buffer_size=1 - # In this case, the expected behavior is similar to flatmap - - shuffled_flatmapped_dp = source_dp.shuffled_flatmap(fn, buffer_size=1) - expected_list = list(itertools.chain(*[(e, e * 10) for e in source_dp])) - - self.assertEqual(expected_list, list(shuffled_flatmapped_dp)) - - # Functional Test: Specify input_col - tuple_source_dp = IterableWrapper([(d - 1, d, d + 1) for d in range(20)]) - - # Single input_col - input_col_1_dp = tuple_source_dp.shuffled_flatmap(fn, input_col=1, buffer_size=1) - self.assertEqual(expected_list, list(input_col_1_dp)) - - # With generator as fn - def gen_fn(e): - yield e - yield e * 10 - - shuffled_flatmapped_dp = source_dp.shuffled_flatmap(gen_fn, buffer_size=1) - expected_list = list(itertools.chain(*[(e, e * 10) for e in source_dp])) - - self.assertEqual(expected_list, list(shuffled_flatmapped_dp)) - - # Multiple input_col - def mul_fn(a, b): - return [a - b, b - a] - - input_col_2_dp = tuple_source_dp.shuffled_flatmap(mul_fn, input_col=(0, 2), buffer_size=1) - self.assertEqual(list(itertools.chain(*[(-2, 2) for _ in range(20)])), list(input_col_2_dp)) - - # shuffled_flatmap with no fn specified - default_dp = tuple_source_dp.shuffled_flatmap(buffer_size=1) - self.assertEqual(list(itertools.chain(*[(n - 1, n, n + 1) for n in range(20)])), list(default_dp)) - - # shuffled_flatmap with no fn specified, multiple input_col - default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2), buffer_size=1) - self.assertEqual(list(itertools.chain(*[(n - 1, n + 1) for n in range(20)])), list(default_dp)) - - # shuffled_flatmap with no fn specified, some special input - tuple_source_dp = IterableWrapper([[1, 2, [3, 4]], [5, 6, [7, 8]]]) - default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2), buffer_size=1) - self.assertEqual([1, [3, 4], 5, [7, 8]], list(default_dp)) - - # Reset Test: reset the DataPipe after reading part of it - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset) - - self.assertEqual(expected_list[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_list, res_after_reset) - - # __len__ Test: length should be len(source_dp)*len(fn->out_shape) which we can't know - with self.assertRaisesRegex(TypeError, "length relies on the output of its function."): - len(shuffled_flatmapped_dp) - - # __len__ when no fn specified: - dp = 
IterableWrapper([[1, 2], [], [3], [4, 5, 6, [7, 8]]]) - dp = dp.shuffled_flatmap() - self.assertEqual(len(dp), 7) - - # Tests with .set_shuffle(False) - # In this case, the expected behavior is similar to flatmap - - shuffled_flatmapped_dp = source_dp.shuffled_flatmap(fn).set_shuffle(False) - expected_list = list(itertools.chain(*[(e, e * 10) for e in source_dp])) - - self.assertEqual(expected_list, list(shuffled_flatmapped_dp)) - - # Functional Test: Specify input_col - tuple_source_dp = IterableWrapper([(d - 1, d, d + 1) for d in range(20)]) - - # Single input_col - input_col_1_dp = tuple_source_dp.shuffled_flatmap(fn, input_col=1, buffer_size=1) - self.assertEqual(expected_list, list(input_col_1_dp)) - - # Multiple input_col - input_col_2_dp = tuple_source_dp.shuffled_flatmap(mul_fn, input_col=(0, 2)).set_shuffle(False) - self.assertEqual(list(itertools.chain(*[(-2, 2) for _ in range(20)])), list(input_col_2_dp)) - - # shuffled_flatmap with no fn specified - default_dp = tuple_source_dp.shuffled_flatmap().set_shuffle(False) - self.assertEqual(list(itertools.chain(*[(n - 1, n, n + 1) for n in range(20)])), list(default_dp)) - - # shuffled_flatmap with no fn specified, multiple input_col - default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2)).set_shuffle(False) - self.assertEqual(list(itertools.chain(*[(n - 1, n + 1) for n in range(20)])), list(default_dp)) - - # shuffled_flatmap with no fn specified, some special input - tuple_source_dp = IterableWrapper([[1, 2, [3, 4]], [5, 6, [7, 8]]]) - default_dp = tuple_source_dp.shuffled_flatmap(input_col=(0, 2)).set_shuffle(False) - self.assertEqual([1, [3, 4], 5, [7, 8]], list(default_dp)) - - # Reset Test: reset the DataPipe after reading part of it - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset) - - self.assertEqual(expected_list[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_list, res_after_reset) - - # Other tests - - # Test no empty buffers: - with self.assertRaises(AssertionError): - _ = source_dp.shuffled_flatmap(buffer_size=0) - - # Functional Test: No seed - consecutive_tuple_source_dp = IterableWrapper([(d, d + 1, d + 2) for d in range(0, 21, 3)]) - shuffled_flatmapped_dp = consecutive_tuple_source_dp.shuffled_flatmap() - self.assertEqual(set(range(21)), set(shuffled_flatmapped_dp)) - - # Functional Test: With global seed - torch.manual_seed(123) - shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap() - res = list(shuffled_flatmapped_dp) - torch.manual_seed(123) - self.assertEqual(list(shuffled_flatmapped_dp), res) - - # Functional Test: Set seed - shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap().set_seed(123) - res = list(shuffled_flatmapped_dp) - shuffled_flatmapped_dp.set_seed(123) - self.assertEqual(list(shuffled_flatmapped_dp), res) - - # Reset Test: - shuffled_flatmapped_dp = tuple_source_dp.shuffled_flatmap() - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(shuffled_flatmapped_dp, n_elements_before_reset) - self.assertEqual(5, len(res_before_reset)) - - def test_round_robin_demux_iterdatapipe(self): - source_dp = IterableWrapper(list(range(23))) - with self.assertRaisesRegex(ValueError, "Expected `num_instaces`"): - _ = source_dp.round_robin_demux(0) - - # Functional Test - dp1, dp2, dp3 = source_dp.round_robin_demux(3) - self.assertEqual(list(range(0, 23, 3)), list(dp1)) - self.assertEqual(list(range(1, 23, 3)), list(dp2)) - 
self.assertEqual(list(range(2, 23, 3)), list(dp3)) - - # __len__ Test - self.assertEqual(len(dp1), 8) - self.assertEqual(len(dp2), 8) - self.assertEqual(len(dp3), 7) - - def test_unzipper_iterdatapipe(self): - source_dp = IterableWrapper([(i, i + 10, i + 20) for i in range(10)]) - - # Functional Test: unzips each sequence, using the class constructor - dp1, dp2, dp3 = UnZipper(source_dp, sequence_length=3) - self.assertEqual(list(range(10)), list(dp1)) - self.assertEqual(list(range(10, 20)), list(dp2)) - self.assertEqual(list(range(20, 30)), list(dp3)) - - # Functional Test: unzips each sequence, with `sequence_length` specified - dp1, dp2, dp3 = source_dp.unzip(sequence_length=3) - self.assertEqual(list(range(10)), list(dp1)) - self.assertEqual(list(range(10, 20)), list(dp2)) - self.assertEqual(list(range(20, 30)), list(dp3)) - - # Functional Test: skipping over specified values - dp2, dp3 = source_dp.unzip(sequence_length=3, columns_to_skip=[0]) - self.assertEqual(list(range(10, 20)), list(dp2)) - self.assertEqual(list(range(20, 30)), list(dp3)) - - (dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2], buffer_size=0) - self.assertEqual(list(range(10, 20)), list(dp2)) - - source_dp = IterableWrapper([(i, i + 10, i + 20, i + 30) for i in range(10)]) - dp2, dp3 = source_dp.unzip(sequence_length=4, columns_to_skip=[0, 3]) - self.assertEqual(list(range(10, 20)), list(dp2)) - self.assertEqual(list(range(20, 30)), list(dp3)) - - # Functional Test: one child DataPipe yields all values first, but buffer_size = 4 is too small, raising an error - source_dp = IterableWrapper([(i, i + 10) for i in range(10)]) - dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=4) - it1 = iter(dp1) - for _ in range(4): - next(it1) - with self.assertRaises(BufferError): - next(it1) - with self.assertRaises(BufferError): - list(dp2) - - dp1, dp2 = source_dp.unzip(sequence_length=2, buffer_size=4) - with self.assertRaises(BufferError): - list(dp2) - - # Reset Test: DataPipe resets when a new iterator is created, even if this datapipe hasn't been read - dp1, dp2 = source_dp.unzip(sequence_length=2) - _ = iter(dp1) - output2 = [] - with self.assertRaisesRegex(RuntimeError, r"iterator has been invalidated"): - for i, n2 in enumerate(dp2): - output2.append(n2) - if i == 4: - _ = iter(dp1) # This will reset all child DataPipes - self.assertEqual(list(range(10, 15)), output2) - - # Reset Test: DataPipe resets when some of it has been read - dp1, dp2 = source_dp.unzip(sequence_length=2) - output1, output2 = [], [] - for i, (n1, n2) in enumerate(zip(dp1, dp2)): - output1.append(n1) - output2.append(n2) - if i == 4: - with warnings.catch_warnings(record=True) as wa: - _ = iter(dp1) # Resets all child DataPipes - self.assertEqual(len(wa), 1) - self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") - break - for n1, n2 in zip(dp1, dp2): - output1.append(n1) - output2.append(n2) - self.assertEqual(list(range(5)) + list(range(10)), output1) - self.assertEqual(list(range(10, 15)) + list(range(10, 20)), output2) - - # Reset Test: DataPipe resets, even when some other child DataPipes are not read - source_dp = IterableWrapper([(i, i + 10, i + 20) for i in range(10)]) - dp1, dp2, dp3 = source_dp.unzip(sequence_length=3) - output1, output2 = list(dp1), list(dp2) - self.assertEqual(list(range(10)), output1) - self.assertEqual(list(range(10, 20)), output2) - with warnings.catch_warnings(record=True) as wa: - self.assertEqual(list(range(10)), list(dp1)) # Resets even though dp3 has not been 
read - self.assertEqual(len(wa), 1) - self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") - output3 = [] - for i, n3 in enumerate(dp3): - output3.append(n3) - if i == 4: - with warnings.catch_warnings(record=True) as wa: - output1 = list(dp1) # Resets even though dp3 is only partially read - self.assertEqual(len(wa), 1) - self.assertRegex(str(wa[0].message), r"Some child DataPipes are not exhausted") - self.assertEqual(list(range(20, 25)), output3) - self.assertEqual(list(range(10)), output1) - break - self.assertEqual(list(range(20, 30)), list(dp3)) # dp3 has to read from the start again - - # __len__ Test: Each DataPipe inherits the source datapipe's length - dp1, dp2, dp3 = source_dp.unzip(sequence_length=3) - self.assertEqual(len(source_dp), len(dp1)) - self.assertEqual(len(source_dp), len(dp2)) - self.assertEqual(len(source_dp), len(dp3)) - - def test_itertomap_mapdatapipe(self): - # Functional Test with None key_value_fn - values = list(range(10)) - keys = ["k" + str(i) for i in range(10)] - source_dp = IterableWrapper(list(zip(keys, values))) - - map_dp = source_dp.to_map_datapipe() - self.assertTrue(isinstance(map_dp, MapDataPipe)) - - # Lazy loading - self.assertTrue(map_dp._map is None) - - # __len__ Test: Each DataPipe inherits the source datapipe's length - self.assertEqual(len(map_dp), 10) - - # Functional Test - self.assertEqual(list(range(10)), [map_dp["k" + str(idx)] for idx in range(10)]) - self.assertFalse(map_dp._map is None) - - source_dp = IterableWrapper(range(10)) - - # TypeError test for invalid data type - map_dp = source_dp.to_map_datapipe() - with self.assertRaisesRegex(TypeError, "Cannot convert dictionary update element"): - _ = list(map_dp) - - # ValueError test for wrong length - map_dp = source_dp.to_map_datapipe(lambda d: (d,)) - with self.assertRaisesRegex(ValueError, "dictionary update sequence element has length"): - _ = list(map_dp) - - # Functional Test with key_value_fn - map_dp = source_dp.to_map_datapipe(lambda d: ("k" + str(d), d + 1)) - self.assertEqual(list(range(1, 11)), [map_dp["k" + str(idx)] for idx in range(10)]) - self.assertFalse(map_dp._map is None) - - # No __len__ from prior DataPipe - no_len_dp = source_dp.filter(lambda x: x % 2 == 0) - map_dp = no_len_dp.to_map_datapipe(lambda x: (x, x + 2)) - with warnings.catch_warnings(record=True) as wa: - length = len(map_dp) - self.assertEqual(length, 5) - self.assertEqual(len(wa), 1) - self.assertRegex(str(wa[0].message), r"Data from prior DataPipe") - - # Duplicate Key Test - dup_map_dp = source_dp.to_map_datapipe(lambda x: (x % 1, x)) - with warnings.catch_warnings(record=True) as wa: - dup_map_dp._load_map() - self.assertEqual(len(wa), 1) - self.assertRegex(str(wa[0].message), r"Found duplicate key") - - def test_mux_longest_iterdatapipe(self): - - # Functional Test: Elements are yielded one at a time from each DataPipe, until they are all exhausted - input_dp1 = IterableWrapper(range(4)) - input_dp2 = IterableWrapper(range(4, 8)) - input_dp3 = IterableWrapper(range(8, 12)) - output_dp = input_dp1.mux_longest(input_dp2, input_dp3) - expected_output = [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11] - self.assertEqual(len(expected_output), len(output_dp)) - self.assertEqual(expected_output, list(output_dp)) - - # Functional Test: Uneven input Data Pipes - input_dp1 = IterableWrapper([1, 2, 3, 4]) - input_dp2 = IterableWrapper([10]) - input_dp3 = IterableWrapper([100, 200, 300]) - output_dp = input_dp1.mux_longest(input_dp2, input_dp3) - expected_output = [1, 10, 100, 2, 200, 
3, 300, 4] - self.assertEqual(len(expected_output), len(output_dp)) - self.assertEqual(expected_output, list(output_dp)) - - # Functional Test: Empty Data Pipe - input_dp1 = IterableWrapper([0, 1, 2, 3]) - input_dp2 = IterableWrapper([]) - output_dp = input_dp1.mux_longest(input_dp2) - self.assertEqual(len(input_dp1), len(output_dp)) - self.assertEqual(list(input_dp1), list(output_dp)) - - # __len__ Test: raises TypeError when __len__ is called and an input doesn't have __len__ - input_dp1 = IterableWrapper(range(10)) - input_dp_no_len = IDP_NoLen(range(10)) - output_dp = input_dp1.mux_longest(input_dp_no_len) - with self.assertRaises(TypeError): - len(output_dp) - - def test_shard_expand(self): - - # Functional Test: ensure expansion generates the right outputs - def testexpand(s): - stage1 = IterableWrapper([s]) - stage2 = ShardExpander(stage1) - return list(iter(stage2)) - - def myexpand(lo, hi, fmt): - return [fmt.format(i) for i in range(lo, hi)] - - self.assertEqual(testexpand("ds-{000000..000009}.tar"), myexpand(0, 10, "ds-{:06d}.tar")) - self.assertEqual(testexpand("{0..9}"), myexpand(0, 10, "{}")) - self.assertEqual(testexpand("{0..999}"), myexpand(0, 1000, "{}")) - self.assertEqual(testexpand("{123..999}"), myexpand(123, 1000, "{}")) - self.assertEqual(testexpand("{000..999}"), myexpand(0, 1000, "{:03d}")) - with self.assertRaisesRegex(ValueError, r"must not start with 0"): - testexpand("{01..999}") - with self.assertRaisesRegex(ValueError, r"must be shorter"): - testexpand("{0000..999}") - with self.assertRaisesRegex(ValueError, r"bad range"): - testexpand("{999..123}") - self.assertEqual(testexpand("{0..1}{0..1}"), "00 01 10 11".split()) - - def test_combining_infinite_iterdatapipe(self): - r""" - Test combining DataPipe can properly exit at the end of iteration - with an infinite DataPipe as the input. - """ - - def _get_dp(length=10): - source_dp = IterableWrapper(list(range(length))) - inf_dp = IterableWrapper(list(range(length))).cycle() - return source_dp, inf_dp - - # zip - noinf_dp, inf_dp = _get_dp(10) - dp = inf_dp.zip(noinf_dp) - res = list(dp) - self.assertEqual(res, [(i, i) for i in range(10)]) - - # mux - noinf_dp, inf_dp = _get_dp(10) - dp = inf_dp.mux(noinf_dp) - res = list(dp) - self.assertEqual(res, [i for i in range(10) for _ in range(2)]) - - # zip_with_iter - noinf_dp, inf_dp = _get_dp(10) - dp = noinf_dp.zip_with_iter(inf_dp, key_fn=lambda x: x) - res = list(dp) - self.assertEqual(res, [(i, i) for i in range(10)]) - - def test_zip_longest_iterdatapipe(self): - - # Functional Test: raises TypeError when an input is not of type `IterDataPipe` - with self.assertRaises(TypeError): - input_dp1 = IterableWrapper(range(10)) - input_no_dp = list(range(10)) - output_dp = input_dp1.zip_longest(input_no_dp) # type: ignore[arg-type] - - # Functional Test: raises TypeError when an input does not have valid length - input_dp1 = IterableWrapper(range(10)) - input_dp_no_len = IDP_NoLen(range(5)) - output_dp = input_dp1.zip_longest(input_dp_no_len) - with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): - len(output_dp) - - # Functional Test: zips the results properly even when lengths are different - # (zips to the longest, filling missing values with default value None.) 
- input_dp1 = IterableWrapper(range(10)) - input_dp2 = IterableWrapper(range(5)) - output_dp = input_dp1.zip_longest(input_dp2) - exp = [(i, i) for i in range(5)] + [(i, None) for i in range(5, 10)] - self.assertEqual(list(output_dp), exp) - - # Functional Test: zips the results properly even when lengths are different - # (zips to the longest, filling missing values with user input) - input_dp1 = IterableWrapper(range(10)) - input_dp2 = IterableWrapper(range(5)) - output_dp = input_dp1.zip_longest(input_dp2, fill_value=-1) - exp = [(i, i) for i in range(5)] + [(i, -1) for i in range(5, 10)] - self.assertEqual(list(output_dp), exp) - - # __len__ Test: length matches the length of the longest input - self.assertEqual(len(output_dp), 10) - - def test_drop_iterdatapipe(self): - # tuple tests - input_dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)]) - - # Functional Test: single index drop for tuple elements - drop_dp = input_dp.drop(1) - self.assertEqual([(0, 2), (3, 5), (6, 8)], list(drop_dp)) - - # Functional Test: multiple indices drop for tuple elements - drop_dp = input_dp.drop([0, 2]) - self.assertEqual([(1,), (4,), (7,)], list(drop_dp)) - - # dict tests - input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3}, {"a": 3, "b": 4, "c": 5}, {"a": 5, "b": 6, "c": 7}]) - - # Functional Test: single key drop for dict elements - drop_dp = input_dp.drop("a") - self.assertEqual([{"b": 2, "c": 3}, {"b": 4, "c": 5}, {"b": 6, "c": 7}], list(drop_dp)) - - # Functional Test: multiple key drop for dict elements - drop_dp = input_dp.drop(["a", "b"]) - self.assertEqual([{"c": 3}, {"c": 5}, {"c": 7}], list(drop_dp)) - - # list tests - input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - - # Functional Test: single index drop for list elements - drop_dp = input_dp.drop(2) - self.assertEqual([[0, 1], [3, 4], [6, 7]], list(drop_dp)) - - # Functional Test: multiple indices drop for list elements - drop_dp = input_dp.drop([0, 1]) - self.assertEqual([[2], [5], [8]], list(drop_dp)) - - # Reset Test: - n_elements_before_reset = 2 - input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - drop_dp = input_dp.drop([0, 1]) - expected_res = [[2], [5], [8]] - res_before_reset, res_after_reset = reset_after_n_next_calls(drop_dp, n_elements_before_reset) - self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_res, res_after_reset) - - # __len__ Test: - input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - drop_dp = input_dp.drop([0, 1]) - self.assertEqual(3, len(drop_dp)) - - def test_slice_iterdatapipe(self): - # tuple tests - input_dp = IterableWrapper([(0, 1, 2), (3, 4, 5), (6, 7, 8)]) - - # Functional Test: slice with no stop and no step for tuple - slice_dp = input_dp.slice(1) - self.assertEqual([(1, 2), (4, 5), (7, 8)], list(slice_dp)) - - # Functional Test: slice with no step for tuple - slice_dp = input_dp.slice(0, 2) - self.assertEqual([(0, 1), (3, 4), (6, 7)], list(slice_dp)) - - # Functional Test: slice with step for tuple - slice_dp = input_dp.slice(0, 2, 2) - self.assertEqual([(0,), (3,), (6,)], list(slice_dp)) - - # Functional Test: slice with list of indices for tuple - slice_dp = input_dp.slice([0, 1]) - self.assertEqual([(0, 1), (3, 4), (6, 7)], list(slice_dp)) - - # list tests - input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - - # Functional Test: slice with no stop and no step for list - slice_dp = input_dp.slice(1) - self.assertEqual([[1, 2], [4, 5], [7, 8]], list(slice_dp)) - - # Functional Test: slice with no 
step for list - slice_dp = input_dp.slice(0, 2) - self.assertEqual([[0, 1], [3, 4], [6, 7]], list(slice_dp)) - - # Functional Test: slice with list of indices for list - slice_dp = input_dp.slice(0, 2) - self.assertEqual([[0, 1], [3, 4], [6, 7]], list(slice_dp)) - - # dict tests - input_dp = IterableWrapper([{"a": 1, "b": 2, "c": 3}, {"a": 3, "b": 4, "c": 5}, {"a": 5, "b": 6, "c": 7}]) - - # Functional Test: slice with key for dict - slice_dp = input_dp.slice("a") - self.assertEqual([{"a": 1}, {"a": 3}, {"a": 5}], list(slice_dp)) - - # Functional Test: slice with list of keys for dict - slice_dp = input_dp.slice(["a", "b"]) - self.assertEqual([{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}], list(slice_dp)) - - # __len__ Test: - input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - slice_dp = input_dp.slice(0, 2) - self.assertEqual(3, len(slice_dp)) - - # Reset Test: - n_elements_before_reset = 2 - input_dp = IterableWrapper([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - slice_dp = input_dp.slice([2]) - expected_res = [[2], [5], [8]] - res_before_reset, res_after_reset = reset_after_n_next_calls(slice_dp, n_elements_before_reset) - self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_res, res_after_reset) - - def test_flatten_iterdatapipe(self): - # tuple tests - - # Functional Test: flatten for an index - input_dp = IterableWrapper([(0, 1, (2, 3)), (4, 5, (6, 7)), (8, 9, (10, 11))]) - flatten_dp = input_dp.flatten(2) - self.assertEqual([(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)], list(flatten_dp)) - - # Functional Test: flatten for list of indices - input_dp = IterableWrapper([((0, 10), 1, (2, 3)), ((4, 14), 5, (6, 7)), ((8, 18), 9, (10, 11))]) - flatten_dp = input_dp.flatten([0, 2]) - self.assertEqual([(0, 10, 1, 2, 3), (4, 14, 5, 6, 7), (8, 18, 9, 10, 11)], list(flatten_dp)) - - # Functional Test: flatten all iters in the datapipe one level (no argument) - input_dp = IterableWrapper([(0, (1, 2)), (3, (4, 5)), (6, (7, 8))]) - flatten_dp = input_dp.flatten() - self.assertEqual([(0, 1, 2), (3, 4, 5), (6, 7, 8)], list(flatten_dp)) - - # list tests - - # Functional Test: flatten for an index - input_dp = IterableWrapper([[0, 1, [2, 3]], [4, 5, [6, 7]], [8, 9, [10, 11]]]) - flatten_dp = input_dp.flatten(2) - self.assertEqual([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], list(flatten_dp)) - - # Functional Test: flatten for list of indices - input_dp = IterableWrapper([[[0, 10], 1, [2, 3]], [[4, 14], 5, [6, 7]], [[8, 18], 9, [10, 11]]]) - flatten_dp = input_dp.flatten([0, 2]) - self.assertEqual([[0, 10, 1, 2, 3], [4, 14, 5, 6, 7], [8, 18, 9, 10, 11]], list(flatten_dp)) - - # Functional Test: flatten all iters in the datapipe one level (no argument) - input_dp = IterableWrapper([[0, [1, 2]], [3, [4, 5]], [6, [7, 8]]]) - flatten_dp = input_dp.flatten() - self.assertEqual([[0, 1, 2], [3, 4, 5], [6, 7, 8]], list(flatten_dp)) - - # Functional Test: string test, flatten all iters in the datapipe one level (no argument) - input_dp = IterableWrapper([["zero", ["one", "2"]], ["3", ["4", "5"]], ["6", ["7", "8"]]]) - flatten_dp = input_dp.flatten() - self.assertEqual([["zero", "one", "2"], ["3", "4", "5"], ["6", "7", "8"]], list(flatten_dp)) - - # dict tests - - # Functional Test: flatten for an index - input_dp = IterableWrapper([{"a": 1, "b": 2, "c": {"d": 3, "e": 4}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}]) - flatten_dp = input_dp.flatten("c") - self.assertEqual([{"a": 1, "b": 2, "d": 3, "e": 4}, {"a": 5, "b": 6, "d": 7, "e": 8}], 
list(flatten_dp)) - - # Functional Test: flatten for an index already flat - input_dp = IterableWrapper([{"a": 1, "b": 2, "c": {"d": 9, "e": 10}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}]) - flatten_dp = input_dp.flatten("a") - self.assertEqual( - [{"a": 1, "b": 2, "c": {"d": 9, "e": 10}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}], list(flatten_dp) - ) - - # Functional Test: flatten for list of indices - input_dp = IterableWrapper( - [ - {"a": {"f": 10, "g": 11}, "b": 2, "c": {"d": 3, "e": 4}}, - {"a": {"f": 10, "g": 11}, "b": 6, "c": {"d": 7, "e": 8}}, - ] - ) - flatten_dp = input_dp.flatten(["a", "c"]) - self.assertEqual( - [{"f": 10, "g": 11, "b": 2, "d": 3, "e": 4}, {"f": 10, "g": 11, "b": 6, "d": 7, "e": 8}], list(flatten_dp) - ) - - # Functional Test: flatten all iters in the datapipe one level (no argument) - input_dp = IterableWrapper([{"a": 1, "b": 2, "c": {"d": 3, "e": 4}}, {"a": 5, "b": 6, "c": {"d": 7, "e": 8}}]) - flatten_dp = input_dp.flatten() - self.assertEqual([{"a": 1, "b": 2, "d": 3, "e": 4}, {"a": 5, "b": 6, "d": 7, "e": 8}], list(flatten_dp)) - - # Functional Test: flatten all iters one level, multiple iters - input_dp = IterableWrapper( - [ - {"a": {"f": 10, "g": 11}, "b": 2, "c": {"d": 3, "e": 4}}, - {"a": {"f": 10, "g": 11}, "b": 6, "c": {"d": 7, "e": 8}}, - ] - ) - flatten_dp = input_dp.flatten() - self.assertEqual( - [{"f": 10, "g": 11, "b": 2, "d": 3, "e": 4}, {"f": 10, "g": 11, "b": 6, "d": 7, "e": 8}], list(flatten_dp) - ) - - # __len__ Test: - input_dp = IterableWrapper([(0, 1, (2, 3)), (4, 5, (6, 7)), (8, 9, (10, 11))]) - flatten_dp = input_dp.flatten(2) - self.assertEqual(3, len(flatten_dp)) - - # Reset Test: - n_elements_before_reset = 2 - input_dp = IterableWrapper([(0, 1, (2, 3)), (4, 5, (6, 7)), (8, 9, (10, 11))]) - flatten_dp = input_dp.flatten(2) - expected_res = [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11)] - res_before_reset, res_after_reset = reset_after_n_next_calls(flatten_dp, n_elements_before_reset) - self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_res, res_after_reset) - - def test_length_setter_iterdatapipe(self): - input_dp = IterableWrapper(range(10)) - - # Functional Test: Setting length doesn't change the content of the DataPipe - dp: IterDataPipe = input_dp.set_length(3) - self.assertEqual(list(range(10)), list(dp)) - - with self.assertRaises(AssertionError): - input_dp.set_length(-1) - - # __len__ Test: Length is as specified and propagates through - dp = input_dp.set_length(3).map(lambda x: x + 1) - self.assertEqual(3, len(dp)) - - # Reset Test: - n_elements_before_reset = 2 - dp = input_dp.set_length(3) - expected_res = list(range(10)) - res_before_reset, res_after_reset = reset_after_n_next_calls(dp, n_elements_before_reset) - self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_res, res_after_reset) - - def test_random_splitter_iterdatapipe(self): - - n_epoch = 2 - - # Functional Test: Split results are the same across epochs - dp = IterableWrapper(range(10)) - train, valid = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0) - results = [] - for _ in range(n_epoch): - res = list(train) - self.assertEqual(5, len(res)) - results.append(res) - self.assertEqual(results[0], results[1]) - valid_res = list(valid) - self.assertEqual(5, len(valid_res)) - self.assertEqual(list(range(10)), sorted(results[0] + valid_res)) - - # Functional Test: lengths can be known in advance because it splits evenly into integers. 
- self.assertEqual(5, len(train)) - self.assertEqual(5, len(valid)) - - # Functional Test: DataPipe can split into 3 DataPipes, and infer `total_length` when not given - dp = IterableWrapper(range(10)) - train, valid, test = dp.random_split(weights={"train": 0.6, "valid": 0.2, "test": 0.2}, seed=0) - results = [] - for _ in range(n_epoch): - res = list(train) - self.assertEqual(6, len(res)) - results.append(res) - self.assertEqual(results[0], results[1]) - valid_res = list(valid) - self.assertEqual(2, len(valid_res)) - test_res = list(test) - self.assertEqual(2, len(test_res)) - self.assertEqual(list(range(10)), sorted(results[0] + valid_res + test_res)) - - # Functional Test: lengths can be known in advance because it splits evenly into integers. - self.assertEqual(6, len(train)) - self.assertEqual(2, len(valid)) - self.assertEqual(2, len(test)) - - # Functional Test: Split can work even when weights do not split evenly into integers. - dp = IterableWrapper(range(13)) - train, valid, test = dp.random_split(weights={"train": 0.6, "valid": 0.2, "test": 0.2}, seed=0) - res = list(train) + list(valid) + list(test) - self.assertEqual(list(range(13)), sorted(res)) - - # Functional Test: lengths cannot be known in advance because the weights do not split evenly into integers. - with self.assertRaisesRegex(TypeError, "Lengths of the split cannot be known in advance"): - len(train) - - # Functional Test: Error when `total_length` cannot be inferred - nolen_dp = IDP_NoLen(range(10)) - with self.assertRaisesRegex(TypeError, "needs `total_length`"): - _, __ = nolen_dp.random_split(weights={"train": 0.5, "valid": 0.5}, seed=0) # type: ignore[call-arg] - - # Functional Test: `target` must match a key in the `weights` dict - dp = IterableWrapper(range(10)) - with self.assertRaisesRegex(KeyError, "does not match any key"): - _ = dp.random_split( - total_length=10, weights={"train": 0.5, "valid": 0.2, "test": 0.2}, seed=0, target="NOTINDICT" - ) - - # Functional Test: `target` is specified, and matches the results from before - dp = IterableWrapper(range(10)) - train = dp.random_split( - total_length=10, weights={"train": 0.6, "valid": 0.2, "test": 0.2}, seed=0, target="train" - ) - results2 = [] - for _ in range(n_epoch): - res = list(train) - self.assertEqual(6, len(res)) - results2.append(res) - self.assertEqual(results2[0], results2[1]) - self.assertEqual(results, results2) - - # Functional Test: `override_seed` works and changes the split result - train.override_seed(1) - seed_1_res = list(train) - self.assertNotEqual(results2[0], seed_1_res) - - # Functional Test: `override_seed` doesn't impact the current iteration, only the next one - temp_res = [] - for i, x in enumerate(train): - temp_res.append(x) - if i == 3: - train.override_seed(0) - self.assertEqual(seed_1_res, temp_res) # The current iteration should equal seed 1 result - self.assertEqual(results2[0], list(train)) # The next iteration should equal seed 0 result - - # Functional Test: Raise exception if both children are used at the same time - dp = IterableWrapper(range(10)) - train, valid = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0) - it_train = iter(train) - next(it_train) - it_valid = iter(valid) # This resets the DataPipe and invalidates the other iterator - next(it_valid) - with self.assertRaisesRegex(RuntimeError, "iterator has been invalidated"): - next(it_train) - next(it_valid) # No error, can keep going - - @skipIfNoCUDA - def test_pin_memory(self): - # Tensor - dp = IterableWrapper([(i, i + 1) for i in 
range(10)]).map(_convert_to_tensor).pin_memory() - self.assertTrue(all(d.is_pinned() for d in dp)) - - # List of Tensors - dp = IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).pin_memory() - self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for d0, d1 in dp)) - - # Dict of Tensors - dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).pin_memory() - self.assertTrue(all(v.is_pinned() for d in dp for v in d.values())) - - # NamedTuple - dp = IterableWrapper([NamedTensors(torch.tensor(i), torch.tensor(i + 1)) for i in range(10)]).pin_memory() - self.assertTrue(all(v.is_pinned() for d in dp for v in d)) - - # Dict of List of Tensors - dp = ( - IterableWrapper([{str(i): [(i - 1, i), (i, i + 1)]} for i in range(10)]) - .map(_convert_to_tensor) - .pin_memory() - ) - self.assertTrue(all(v.is_pinned() for d in dp for batch in d.values() for v in batch)) - - # List of Dict of Tensors - dp = IterableWrapper([{str(i): (i, i + 1)} for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory() - self.assertTrue(all(v.is_pinned() for batch in dp for d in batch for v in d.values())) - - # List of List of Tensors - dp = ( - IterableWrapper([[(i - 1, i), (i, i + 1)] for i in range(10)]).map(_convert_to_tensor).batch(2).pin_memory() - ) - self.assertTrue(all(d0.is_pinned() and d1.is_pinned() for batch in dp for d0, d1 in batch)) - - # Single str - dp = IterableWrapper(["hello", "world"]).batch(1).collate().pin_memory() - self.assertEqual(list(dp), [["hello"], ["world"]]) - - def test_async_map_batches(self): - batch_size = 16 - - def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_concurrency=32, flatten=True): - dp = IterableWrapper(input_data) - dp = dp.async_map_batches(async_fn, batch_size, input_col, output_col, max_concurrency, flatten) - self.assertEqual( - exp_res, - list(dp), - msg=f"Async map test with {async_fn=}, {input_col=}, {output_col=}, {max_concurrency=}", - ) - if flatten: - self.assertEqual(len(input_data), len(dp)) - - _helper(range(50), [i * 10 for i in range(50)], _async_mul_ten) - - # Smaller max_concurrency - _helper(range(50), [i * 10 for i in range(50)], _async_mul_ten, max_concurrency=6) - - # Tuple with input_col - _helper([(i, i) for i in range(50)], [(i * 10, i) for i in range(50)], _async_mul_ten, input_col=0) - _helper([(i, i) for i in range(50)], [(i, i * 10) for i in range(50)], _async_mul_ten, input_col=1) - # Tuple with input_col and output_col - _helper( - [(i, i) for i in range(50)], [(i, i * 10) for i in range(50)], _async_mul_ten, input_col=0, output_col=1 - ) - _helper( - [(i, i) for i in range(50)], [(i, i, i * 10) for i in range(50)], _async_mul_ten, input_col=0, output_col=-1 - ) - - # Dict with input_col - _helper( - [{"a": i, "b": i} for i in range(50)], - [{"a": i, "b": i * 10} for i in range(50)], - _async_mul_ten, - input_col="b", - ) - # Dict with input_col and output_col - _helper( - [{"a": i, "b": i} for i in range(50)], - [{"a": i * 10, "b": i} for i in range(50)], - _async_mul_ten, - input_col="b", - output_col="a", - ) - _helper( - [{"a": i, "b": i} for i in range(50)], - [{"a": i, "b": i, "c": i * 10} for i in range(50)], - _async_mul_ten, - input_col="b", - output_col="c", - ) - - # Multiple input_col - _helper( - [(i - 1, i, i + 1) for i in range(50)], - [((i - 1) * (i + 1), i) for i in range(50)], - _async_x_mul_y, - input_col=(0, 2), - ) - _helper( - [(i - 1, i, i + 1) for i in range(50)], - [(i, (i - 1) * (i + 1)) for i in range(50)], - 
_async_x_mul_y, - input_col=(2, 0), - ) - # Multiple input_col with output_col - _helper( - [(i - 1, i, i + 1) for i in range(50)], - [(i - 1, (i - 1) * (i + 1), i + 1) for i in range(50)], - _async_x_mul_y, - input_col=(0, 2), - output_col=1, - ) - # Skip over `flatten` operation - _helper( - range(32), - [[i * 10 for i in range(16)], [i * 10 for i in range(16, 32)]], - _async_mul_ten, - flatten=False, - ) - - # Test multiple asyncio eventloops - dp1 = IterableWrapper(range(50)) - dp1 = dp1.async_map_batches(_async_mul_ten, 16) - dp2 = IterableWrapper(range(50)) - dp2 = dp2.async_map_batches(_async_mul_ten, 16) - for v1, v2, exp in zip(dp1, dp2, [i * 10 for i in range(50)]): - self.assertEqual(v1, exp) - self.assertEqual(v2, exp) - - def test_threadpool_map(self): - target_length = 30 - input_dp = IterableWrapper(range(target_length)) - input_dp_parallel = IterableWrapper(range(target_length)) - - def fn(item, dtype=torch.float, *, sum=False): - data = torch.tensor(item, dtype=dtype) - return data if not sum else data.sum() - - # Functional Test: apply to each element correctly - map_dp = input_dp.threadpool_map(fn) - self.assertEqual(target_length, len(map_dp)) - for x, y in zip(map_dp, range(target_length)): - self.assertEqual(x, torch.tensor(y, dtype=torch.float)) - - # Functional Test: works with partial function - map_dp = input_dp.threadpool_map(partial(fn, dtype=torch.int, sum=True)) - for x, y in zip(map_dp, range(target_length)): - self.assertEqual(x, torch.tensor(y, dtype=torch.int).sum()) - - # __len__ Test: inherits length from source DataPipe - self.assertEqual(target_length, len(map_dp)) - - input_dp_nl = IDP_NoLen(range(target_length)) - map_dp_nl = input_dp_nl.threadpool_map(lambda x: x) - for x, y in zip(map_dp_nl, range(target_length)): - self.assertEqual(x, torch.tensor(y, dtype=torch.float)) - - # __len__ Test: inherits length from source DataPipe - raises error when invalid - with self.assertRaisesRegex(TypeError, r"instance doesn't have valid length$"): - len(map_dp_nl) - - # Test: two independent ThreadPoolExecutors running at the same time - map_dp_parallel = input_dp_parallel.threadpool_map(fn) - for x, y, z in zip(map_dp, map_dp_parallel, range(target_length)): - self.assertEqual(x, torch.tensor(z, dtype=torch.float)) - self.assertEqual(y, torch.tensor(z, dtype=torch.float)) - - # Reset Test: DataPipe resets properly - n_elements_before_reset = 5 - res_before_reset, res_after_reset = reset_after_n_next_calls(map_dp, n_elements_before_reset) - self.assertEqual(list(range(n_elements_before_reset)), res_before_reset) - self.assertEqual(list(range(target_length)), res_after_reset) - - @suppress_warnings # Suppress warning for lambda fn - def test_threadpool_map_tuple_list_with_col_iterdatapipe(self): - def fn_11(d): - return -d - - def fn_1n(d): - return -d, d - - def fn_n1(d0, d1): - return d0 + d1 - - def fn_nn(d0, d1): - return -d0, -d1, d0 + d1 - - def fn_n1_def(d0, d1=1): - return d0 + d1 - - def fn_n1_kwargs(d0, d1, **kwargs): - return d0 + d1 - - def fn_n1_pos(d0, d1, *args): - return d0 + d1 - - def fn_n1_sep_pos(d0, *args, d1): - return d0 + d1 - - def fn_cmplx(d0, d1=1, *args, d2, **kwargs): - return d0 + d1 - - p_fn_n1 = partial(fn_n1, d1=1) - p_fn_cmplx = partial(fn_cmplx, d2=2) - - def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): - for constr in (list, tuple): - datapipe = IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))]) - if ref_fn is None: - with self.assertRaises(error): - res_dp = 
datapipe.threadpool_map(fn, input_col, output_col) - list(res_dp) - else: - res_dp = datapipe.threadpool_map(fn, input_col, output_col) - ref_dp = datapipe.map(ref_fn) - if constr is list: - ref_dp = ref_dp.map(list) - self.assertEqual(list(res_dp), list(ref_dp), "First test failed") - # Reset - self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed") - - _helper(lambda data: data, fn_n1_def, 0, 1) - _helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2) - _helper(lambda data: data, p_fn_n1, 0, 1) - _helper(lambda data: data, p_fn_cmplx, 0, 1) - _helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2) - _helper(lambda data: (data[0] + data[1],), fn_n1_pos, [0, 1, 2]) - - # Replacing with one input column and default output column - _helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1) - _helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1) - # The index of input column is out of range - _helper(None, fn_1n, 3, error=IndexError) - # Unmatched input columns with fn arguments - _helper(None, fn_n1, 1, error=ValueError) - _helper(None, fn_n1, [0, 1, 2], error=ValueError) - _helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError) - _helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError) - _helper(None, fn_cmplx, 0, 1, ValueError) - _helper(None, fn_n1_pos, 1, error=ValueError) - _helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError) - _helper(None, p_fn_n1, [0, 1], error=ValueError) - _helper(None, fn_1n, [1, 2], error=ValueError) - # _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError) - _helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError) - # Fn has keyword-only arguments - _helper(None, fn_n1_kwargs, 1, error=ValueError) - _helper(None, fn_cmplx, [0, 1], 2, ValueError) - - # Replacing with multiple input columns and default output column (the left-most input column) - _helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0]) - _helper(lambda data: (data[0], (-data[2], -data[1], data[2] + data[1])), fn_nn, [2, 1]) - - # output_col can only be specified when input_col is not None - _helper(None, fn_n1, None, 1, error=ValueError) - # output_col can only be single-element list or tuple - _helper(None, fn_n1, None, [0, 1], error=ValueError) - # Single-element list as output_col - _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, [0]) - # Replacing with one input column and single specified output column - _helper(lambda data: (-data[1], data[1], data[2]), fn_11, 1, 0) - _helper(lambda data: (data[0], data[1], (-data[1], data[1])), fn_1n, 1, 2) - # The index of output column is out of range - _helper(None, fn_1n, 1, 3, error=IndexError) - _helper(lambda data: (data[0], data[0] + data[2], data[2]), fn_n1, [0, 2], 1) - _helper(lambda data: ((-data[1], -data[2], data[1] + data[2]), data[1], data[2]), fn_nn, [1, 2], 0) - - # Appending the output at the end - _helper(lambda data: (*data, -data[1]), fn_11, 1, -1) - _helper(lambda data: (*data, (-data[1], data[1])), fn_1n, 1, -1) - _helper(lambda data: (*data, data[0] + data[2]), fn_n1, [0, 2], -1) - _helper(lambda data: (*data, (-data[1], -data[2], data[1] + data[2])), fn_nn, [1, 2], -1) - - # Handling built-in functions (e.g. 
`dict`, `iter`, `int`, `str`) whose signatures cannot be inspected - _helper(lambda data: (str(data[0]), data[1], data[2]), str, 0) - _helper(lambda data: (data[0], data[1], int(data[2])), int, 2) - - @suppress_warnings # Suppress warning for lambda fn - def test_threadpool_map_dict_with_col_iterdatapipe(self): - def fn_11(d): - return -d - - def fn_1n(d): - return -d, d - - def fn_n1(d0, d1): - return d0 + d1 - - def fn_nn(d0, d1): - return -d0, -d1, d0 + d1 - - def fn_n1_def(d0, d1=1): - return d0 + d1 - - p_fn_n1 = partial(fn_n1, d1=1) - - def fn_n1_pos(d0, d1, *args): - return d0 + d1 - - def fn_n1_kwargs(d0, d1, **kwargs): - return d0 + d1 - - def fn_kwonly(*, d0, d1): - return d0 + d1 - - def fn_has_nondefault_kwonly(d0, *, d1): - return d0 + d1 - - def fn_cmplx(d0, d1=1, *args, d2, **kwargs): - return d0 + d1 - - p_fn_cmplx = partial(fn_cmplx, d2=2) - - # Prevent modification in-place to support resetting - def _dict_update(data, newdata, remove_idx=None): - _data = dict(data) - _data.update(newdata) - if remove_idx: - for idx in remove_idx: - del _data[idx] - return _data - - def _helper(ref_fn, fn, input_col=None, output_col=None, error=None): - datapipe = IterableWrapper([{"x": 0, "y": 1, "z": 2}, {"x": 3, "y": 4, "z": 5}, {"x": 6, "y": 7, "z": 8}]) - if ref_fn is None: - with self.assertRaises(error): - res_dp = datapipe.threadpool_map(fn, input_col, output_col) - list(res_dp) - else: - res_dp = datapipe.threadpool_map(fn, input_col, output_col) - ref_dp = datapipe.map(ref_fn) - self.assertEqual(list(res_dp), list(ref_dp), "First test failed") - # Reset - self.assertEqual(list(res_dp), list(ref_dp), "Test after reset failed") - - _helper(lambda data: data, fn_n1_def, "x", "y") - _helper(lambda data: data, p_fn_n1, "x", "y") - _helper(lambda data: data, p_fn_cmplx, "x", "y") - _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), p_fn_cmplx, ["x", "y", "z"], "z") - - _helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ["x", "y"], "z") - - _helper(None, fn_n1_pos, "x", error=ValueError) - _helper(None, fn_n1_kwargs, "x", error=ValueError) - # non-default kw-only args - _helper(None, fn_kwonly, ["x", "y"], error=ValueError) - _helper(None, fn_has_nondefault_kwonly, ["x", "y"], error=ValueError) - _helper(None, fn_cmplx, ["x", "y"], error=ValueError) - - # Replacing with one input column and default output column - _helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y") - _helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y") - # The key of input column is not in dict - _helper(None, fn_1n, "a", error=KeyError) - # Unmatched input columns with fn arguments - _helper(None, fn_n1, "y", error=ValueError) - _helper(None, fn_1n, ["x", "y"], error=ValueError) - _helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError) - _helper(None, p_fn_n1, ["x", "y"], error=ValueError) - _helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError) - # Replacing with multiple input columns and default output column (the left-most input column) - _helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"]) - _helper( - lambda data: _dict_update(data, {"z": (-data["z"], -data["y"], data["y"] + data["z"])}, ["y"]), - fn_nn, - ["z", "y"], - ) - - # output_col can only be specified when input_col is not None - _helper(None, fn_n1, None, "x", error=ValueError) - # output_col can only be single-element list or tuple - _helper(None, fn_n1, None, ["x", "y"], error=ValueError) - # 
Single-element list as output_col - _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", ["x"]) - # Replacing with one input column and single specified output column - _helper(lambda data: _dict_update(data, {"x": -data["y"]}), fn_11, "y", "x") - _helper(lambda data: _dict_update(data, {"z": (-data["y"], data["y"])}), fn_1n, "y", "z") - _helper(lambda data: _dict_update(data, {"y": data["x"] + data["z"]}), fn_n1, ["x", "z"], "y") - _helper( - lambda data: _dict_update(data, {"x": (-data["y"], -data["z"], data["y"] + data["z"])}), - fn_nn, - ["y", "z"], - "x", - ) - - # Adding new key to dict for the output - _helper(lambda data: _dict_update(data, {"a": -data["y"]}), fn_11, "y", "a") - _helper(lambda data: _dict_update(data, {"a": (-data["y"], data["y"])}), fn_1n, "y", "a") - _helper(lambda data: _dict_update(data, {"a": data["x"] + data["z"]}), fn_n1, ["x", "z"], "a") - _helper( - lambda data: _dict_update(data, {"a": (-data["y"], -data["z"], data["y"] + data["z"])}), - fn_nn, - ["y", "z"], - "a", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_linter.py b/test/test_linter.py deleted file mode 100644 index 48f5e0a09..000000000 --- a/test/test_linter.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import unittest - -from torchdata.dataloader2.linter import _check_shuffle_before_sharding - -from torchdata.datapipes.iter import IterableWrapper, ShardingFilter, Shuffler - - -def dummy_fn(x): - return x - - -class LinterTest(unittest.TestCase): - def test_sharding_shuffle(self): - source_dp = IterableWrapper(list(range(20))) - - # Single path - dp = source_dp.map(dummy_fn).shuffle() - self.assertTrue(_check_shuffle_before_sharding(dp)) - dp = source_dp.map(dummy_fn) - self.assertTrue(_check_shuffle_before_sharding(dp)) - - dp = source_dp.map(dummy_fn).shuffle().sharding_filter() - self.assertTrue(_check_shuffle_before_sharding(dp)) - - dp = source_dp.map(dummy_fn).sharding_filter() - self.assertFalse(_check_shuffle_before_sharding(dp)) - - dp = source_dp.map(dummy_fn).sharding_filter().shuffle() - self.assertFalse(_check_shuffle_before_sharding(dp)) - - # Multi paths - def _multi_path_dp_1(shuffle): - s_dp = source_dp.shuffle() if shuffle else source_dp - dp1, dp2 = s_dp.unzip(2) - dp1 = dp1.sharding_filter() - dp2 = dp2.map(dummy_fn).sharding_filter() - dp = dp1.zip(dp2) - return dp - - self.assertTrue(_check_shuffle_before_sharding(_multi_path_dp_1(True))) - self.assertFalse(_check_shuffle_before_sharding(_multi_path_dp_1(False))) - - def _multi_path_dp_2(shuffle): - s_dp = source_dp.shuffle() if shuffle else source_dp - dp1, dp2 = s_dp.unzip(2) - dp1 = dp1.map(dummy_fn) - dp = dp1.zip(dp2).sharding_filter() - return dp - - self.assertTrue(_check_shuffle_before_sharding(_multi_path_dp_2(True))) - self.assertFalse(_check_shuffle_before_sharding(_multi_path_dp_2(False))) - - def _multi_path_dp_3(shuffle): - dp1, dp2 = source_dp.unzip(2) - dp1 = dp1.shuffle() if shuffle else dp1 - dp1 = dp1.map(dummy_fn).sharding_filter() - dp2 = dp2.shuffle() if shuffle else dp1 - dp2 = dp2.sharding_filter() - dp = dp1.zip(dp2).map(dummy_fn) - return dp - - self.assertTrue(_check_shuffle_before_sharding(_multi_path_dp_3(True))) - self.assertFalse(_check_shuffle_before_sharding(_multi_path_dp_3(False))) - - # Partial paths - dp1, dp2 = source_dp.unzip(2) - dp1 = 
dp1.shuffle().map(dummy_fn) - dp = dp1.zip(dp2).sharding_filter() - - self.assertFalse(_check_shuffle_before_sharding(dp)) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_local_io.py b/test/test_local_io.py deleted file mode 100644 index 313b5e361..000000000 --- a/test/test_local_io.py +++ /dev/null @@ -1,923 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import bz2 -import functools -import hashlib -import io -import itertools -import lzma -import os -import subprocess -import tarfile -import tempfile -import time -import unittest -import warnings -import zipfile -from functools import partial - -from json.decoder import JSONDecodeError - -import expecttest - -from _utils._common_utils_for_test import create_temp_dir, create_temp_files, get_name, reset_after_n_next_calls - -from torch.utils.data import DataLoader - -from torchdata.dataloader2.adapter import CacheTimeout -from torchdata.datapipes.iter import ( - Bz2FileLoader, - CSVDictParser, - CSVParser, - Decompressor, - FileLister, - FileOpener, - HashChecker, - IoPathFileLister, - IoPathFileOpener, - IoPathSaver, - IterableWrapper, - IterDataPipe, - JsonParser, - RarArchiveLoader, - Saver, - StreamReader, - TarArchiveLoader, - WebDataset, - XzFileLoader, - ZipArchiveLoader, -) - -try: - import iopath - import torch - - HAS_IOPATH = True -except ImportError: - HAS_IOPATH = False -skipIfNoIoPath = unittest.skipIf(not HAS_IOPATH, "no iopath") - -try: - import rarfile - - HAS_RAR_TOOLS = True - try: - rarfile.tool_setup() - subprocess.run(("rar", "-?"), check=True) - except (rarfile.RarCannotExec, subprocess.CalledProcessError): - HAS_RAR_TOOLS = False -except (ModuleNotFoundError, FileNotFoundError): - HAS_RAR_TOOLS = False -skipIfNoRarTools = unittest.skipIf(not HAS_RAR_TOOLS, "no rar tools") - -try: - import portalocker - - HAS_PORTALOCKER = True -except ImportError: - HAS_PORTALOCKER = False -skipIfNoPortalocker = unittest.skipIf(not HAS_PORTALOCKER, "No portalocker installed") - - -def filepath_fn(temp_dir_name, name: str) -> str: - return os.path.join(temp_dir_name, os.path.basename(name)) - - -def _unbatch(x): - return x[0] - - -def _noop(x): - return x - - -class TestDataPipeLocalIO(expecttest.TestCase): - def setUp(self): - self.temp_dir = create_temp_dir() - self.temp_files = create_temp_files(self.temp_dir) - self.temp_sub_dir = create_temp_dir(self.temp_dir.name) - self.temp_sub_files = create_temp_files(self.temp_sub_dir, 4, False) - - self.temp_dir_2 = create_temp_dir() - self.temp_files_2 = create_temp_files(self.temp_dir_2) - self.temp_sub_dir_2 = create_temp_dir(self.temp_dir_2.name) - self.temp_sub_files_2 = create_temp_files(self.temp_sub_dir_2, 4, False) - - def tearDown(self): - try: - self.temp_sub_dir.cleanup() - self.temp_dir.cleanup() - self.temp_sub_dir_2.cleanup() - self.temp_dir_2.cleanup() - except Exception as e: - warnings.warn(f"TestDataPipeLocalIO was not able to cleanup temp dir due to {e}") - - def _custom_files_set_up(self, files): - for fname, content in files.items(): - temp_file_path = os.path.join(self.temp_dir.name, fname) - with open(temp_file_path, "w") as f: - f.write(content) - - def _compressed_files_comparison_helper(self, expected_files, result, check_length: bool = True): - if check_length: - self.assertEqual(len(expected_files), len(result)) - for res, expected_file in 
itertools.zip_longest(result, expected_files): - self.assertTrue(res is not None and expected_file is not None) - self.assertEqual(os.path.basename(res[0]), os.path.basename(expected_file)) - with open(expected_file, "rb") as f: - self.assertEqual(res[1].read(), f.read()) - res[1].close() - - def _unordered_compressed_files_comparison_helper(self, expected_files, result, check_length: bool = True): - expected_names_to_files = {os.path.basename(f): f for f in expected_files} - if check_length: - self.assertEqual(len(expected_files), len(result)) - for res in result: - fname = os.path.basename(res[0]) - self.assertTrue(fname is not None) - self.assertTrue(fname in expected_names_to_files) - with open(expected_names_to_files[fname], "rb") as f: - self.assertEqual(res[1].read(), f.read()) - res[1].close() - - def test_csv_parser_iterdatapipe(self): - def make_path(fname): - return f"{self.temp_dir.name}/{fname}" - - csv_files = {"1.csv": "key,item\na,1\nb,2", "empty.csv": "", "empty2.csv": "\n"} - self._custom_files_set_up(csv_files) - datapipe1 = IterableWrapper([make_path(fname) for fname in ["1.csv", "empty.csv", "empty2.csv"]]) - datapipe2 = FileOpener(datapipe1, mode="b") - datapipe3 = datapipe2.map(get_name) - - # Functional Test: yield one row at time from each file, skipping over empty content - csv_parser_dp = datapipe3.parse_csv() - expected_res = [["key", "item"], ["a", "1"], ["b", "2"], []] - self.assertEqual(expected_res, list(csv_parser_dp)) - - # Functional Test: yield one row at time from each file, skipping over empty content and header - csv_parser_dp = datapipe3.parse_csv(skip_lines=1) - expected_res = [["a", "1"], ["b", "2"]] - self.assertEqual(expected_res, list(csv_parser_dp)) - - # Functional Test: yield one row at time from each file with file name, skipping over empty content - csv_parser_dp = datapipe3.parse_csv(return_path=True) - expected_res = [("1.csv", ["key", "item"]), ("1.csv", ["a", "1"]), ("1.csv", ["b", "2"]), ("empty2.csv", [])] - self.assertEqual(expected_res, list(csv_parser_dp)) - - # Functional Test: yield one row at time from each file as tuple instead of list, skipping over empty content - csv_parser_dp = datapipe3.parse_csv(as_tuple=True) - expected_res = [("key", "item"), ("a", "1"), ("b", "2"), ()] - self.assertEqual(expected_res, list(csv_parser_dp)) - - # Reset Test: - csv_parser_dp = CSVParser(datapipe3, return_path=True) - n_elements_before_reset = 2 - expected_res = [("1.csv", ["key", "item"]), ("1.csv", ["a", "1"]), ("1.csv", ["b", "2"]), ("empty2.csv", [])] - res_before_reset, res_after_reset = reset_after_n_next_calls(csv_parser_dp, n_elements_before_reset) - self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_res, res_after_reset) - - # __len__ Test: length isn't implemented since it cannot be known ahead of time - with self.assertRaisesRegex(TypeError, "has no len"): - len(csv_parser_dp) - - def test_csv_dict_parser_iterdatapipe(self): - def get_name(path_and_stream): - return os.path.basename(path_and_stream[0]), path_and_stream[1] - - csv_files = {"1.csv": "key,item\na,1\nb,2", "empty.csv": "", "empty2.csv": "\n"} - self._custom_files_set_up(csv_files) - datapipe1 = FileLister(self.temp_dir.name, "*.csv") - datapipe2 = FileOpener(datapipe1, mode="b") - datapipe3 = datapipe2.map(get_name) - - # Functional Test: yield one row at a time as dict, with the first row being the header (key) - csv_dict_parser_dp = datapipe3.parse_csv_as_dict() - expected_res1 = [{"key": "a", "item": "1"}, 
{"key": "b", "item": "2"}] - self.assertEqual(expected_res1, list(csv_dict_parser_dp)) - - # Functional Test: yield one row at a time as dict, skip over first row, with the second row being the header - csv_dict_parser_dp = datapipe3.parse_csv_as_dict(skip_lines=1) - expected_res2 = [{"a": "b", "1": "2"}] - self.assertEqual(expected_res2, list(csv_dict_parser_dp)) - - # Functional Test: yield one row at a time as dict with file name, and the first row being the header (key) - csv_dict_parser_dp = datapipe3.parse_csv_as_dict(return_path=True) - expected_res3 = [("1.csv", {"key": "a", "item": "1"}), ("1.csv", {"key": "b", "item": "2"})] - self.assertEqual(expected_res3, list(csv_dict_parser_dp)) - - # Reset Test - csv_dict_parser_dp = CSVDictParser(datapipe3) - expected_res4 = [{"key": "a", "item": "1"}, {"key": "b", "item": "2"}] - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(csv_dict_parser_dp, n_elements_before_reset) - self.assertEqual(expected_res4[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_res4, res_after_reset) - - # __len__ Test: length isn't implemented since it cannot be known ahead of time - with self.assertRaisesRegex(TypeError, "has no len"): - len(csv_dict_parser_dp) - - def test_hash_checker_iterdatapipe(self): - hash_dict = {} - - def fill_hash_dict(): - for path in self.temp_files: - with open(path) as f: - hash_func = hashlib.sha256() - content = f.read().encode("utf-8") - hash_func.update(content) - hash_dict[path] = hash_func.hexdigest() - - fill_hash_dict() - - datapipe1 = FileLister(self.temp_dir.name, "*") - datapipe2 = FileOpener(datapipe1, mode="b") - hash_check_dp = HashChecker(datapipe2, hash_dict) - - expected_res = list(datapipe2) - - # Functional Test: Ensure the DataPipe values are unchanged if the hashes are the same - for (expected_path, expected_stream), (actual_path, actual_stream) in zip(expected_res, hash_check_dp): - self.assertEqual(expected_path, actual_path) - self.assertEqual(expected_stream.read(), actual_stream.read()) - - # Functional Test: Ensure the rewind option works, and the stream is empty when there is no rewind - hash_check_dp_no_reset = HashChecker(datapipe2, hash_dict, rewind=False) - for (expected_path, _), (actual_path, actual_stream) in zip(expected_res, hash_check_dp_no_reset): - self.assertEqual(expected_path, actual_path) - self.assertEqual(b"", actual_stream.read()) - - # Functional Test: Error when file/path is not in hash_dict - hash_check_dp = HashChecker(datapipe2, {}) - it = iter(hash_check_dp) - with self.assertRaisesRegex(RuntimeError, "Unspecified hash for file"): - next(it) - - # Functional Test: Error when the hash is different - hash_dict[self.temp_files[0]] = "WRONG HASH" - hash_check_dp = HashChecker(datapipe2, hash_dict) - with self.assertRaisesRegex(RuntimeError, "does not match"): - list(hash_check_dp) - - # Reset Test: - fill_hash_dict() # Reset the dict with correct values because we changed it in the last test case - hash_check_dp = datapipe2.check_hash(hash_dict) - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(hash_check_dp, n_elements_before_reset) - for (expected_path, expected_stream), (actual_path, actual_stream) in zip(datapipe2, res_before_reset): - self.assertEqual(expected_path, actual_path) - self.assertEqual(expected_stream.read(), actual_stream.read()) - for (expected_path, expected_stream), (actual_path, actual_stream) in zip(datapipe2, res_after_reset): - 
self.assertEqual(expected_path, actual_path) - self.assertEqual(expected_stream.read(), actual_stream.read()) - - # __len__ Test: returns the length of source DataPipe - with self.assertRaisesRegex(TypeError, "FileOpenerIterDataPipe instance doesn't have valid length"): - len(hash_check_dp) - - def test_json_parser_iterdatapipe(self): - def is_empty_json(path_and_stream): - return path_and_stream[0] == "empty.json" - - def is_nonempty_json(path_and_stream): - return path_and_stream[0] != "empty.json" - - json_files = { - "1.json": '["foo", {"bar":["baz", null, 1.0, 2]}]', - "empty.json": "", - "2.json": '{"__complex__": true, "real": 1, "imag": 2}', - } - self._custom_files_set_up(json_files) - datapipe1 = IterableWrapper([f"{self.temp_dir.name}/{fname}" for fname in ["empty.json", "1.json", "2.json"]]) - datapipe2 = FileOpener(datapipe1, mode="b") - datapipe3 = datapipe2.map(get_name) - datapipe_empty = datapipe3.filter(is_empty_json) - datapipe_nonempty = datapipe3.filter(is_nonempty_json) - - empty_json_dp = datapipe_empty.parse_json_files() - it = iter(empty_json_dp) - # Functional Test: dp fails when empty JSON file is given - with self.assertRaisesRegex(JSONDecodeError, "Expecting value"): - next(it) - - # Functional Test: dp yields one json file at a time - json_dp = datapipe_nonempty.parse_json_files() - expected_res = [ - ("1.json", ["foo", {"bar": ["baz", None, 1.0, 2]}]), - ("2.json", {"__complex__": True, "real": 1, "imag": 2}), - ] - self.assertEqual(expected_res, list(json_dp)) - - # Reset Test: - json_dp = JsonParser(datapipe_nonempty) - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(json_dp, n_elements_before_reset) - self.assertEqual(expected_res[:n_elements_before_reset], res_before_reset) - self.assertEqual(expected_res, res_after_reset) - - # __len__ Test: length isn't implemented since it cannot be known ahead of time - with self.assertRaisesRegex(TypeError, "len"): - len(json_dp) - - # kwargs Test: - json_dp = JsonParser(datapipe_nonempty, parse_int=str) - expected_res = [ - ("1.json", ["foo", {"bar": ["baz", None, 1.0, "2"]}]), - ("2.json", {"__complex__": True, "real": "1", "imag": "2"}), - ] - self.assertEqual(expected_res, list(json_dp)) - - def test_saver_iterdatapipe(self): - # Functional Test: Saving some data - name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"} - source_dp = IterableWrapper(sorted(name_to_data.items())) - saver_dp = source_dp.save_to_disk(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb") - res_file_paths = list(saver_dp) - expected_paths = [filepath_fn(self.temp_dir.name, name) for name in name_to_data.keys()] - self.assertEqual(expected_paths, res_file_paths) - for name in name_to_data.keys(): - p = filepath_fn(self.temp_dir.name, name) - with open(p) as f: - self.assertEqual(name_to_data[name], f.read().encode()) - - # Reset Test: - saver_dp = Saver(source_dp, filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb") - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(saver_dp, n_elements_before_reset) - self.assertEqual( - [filepath_fn(self.temp_dir.name, "1.txt"), filepath_fn(self.temp_dir.name, "2.txt")], res_before_reset - ) - self.assertEqual(expected_paths, res_after_reset) - for name in name_to_data.keys(): - p = filepath_fn(self.temp_dir.name, name) - with open(p) as f: - self.assertEqual(name_to_data[name], f.read().encode()) - - # __len__ Test: returns the length of source DataPipe - self.assertEqual(3, 
len(saver_dp)) - - def _write_test_tar_files(self): - path = os.path.join(self.temp_dir.name, "test_tar.tar") - with tarfile.open(path, "w:tar") as tar: - tar.add(self.temp_files[0]) - tar.add(self.temp_files[1]) - tar.add(self.temp_files[2]) - - def _write_test_tar_gz_files(self): - path = os.path.join(self.temp_dir.name, "test_gz.tar.gz") - with tarfile.open(path, "w:gz") as tar: - tar.add(self.temp_files[0]) - tar.add(self.temp_files[1]) - tar.add(self.temp_files[2]) - - def test_tar_archive_reader_iterdatapipe(self): - self._write_test_tar_files() - datapipe1 = FileLister(self.temp_dir.name, "*.tar") - datapipe2 = FileOpener(datapipe1, mode="b") - tar_loader_dp = TarArchiveLoader(datapipe2) - - self._write_test_tar_gz_files() - datapipe_gz_1 = FileLister(self.temp_dir.name, "*.tar.gz") - datapipe_gz_2 = FileOpener(datapipe_gz_1, mode="b") - gz_reader_dp = TarArchiveLoader(datapipe_gz_2) - - # Functional Test: Read extracted files before reaching the end of the tarfile - self._compressed_files_comparison_helper(self.temp_files, tar_loader_dp, check_length=False) - self._compressed_files_comparison_helper(self.temp_files, gz_reader_dp, check_length=False) - - # Load from decompressed file stream - decomp_dp = datapipe_gz_2.decompress() - decomp_reader_dp = TarArchiveLoader(decomp_dp) - self._compressed_files_comparison_helper(self.temp_files, decomp_reader_dp, check_length=False) - - # Functional Test: Read extracted files after reaching the end of the tarfile - data_refs = list(tar_loader_dp) - self._compressed_files_comparison_helper(self.temp_files, data_refs) - data_refs_gz = list(gz_reader_dp) - self._compressed_files_comparison_helper(self.temp_files, data_refs_gz) - - # Reset Test: reset the DataPipe after reading part of it - tar_loader_dp = datapipe2.load_from_tar() - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(tar_loader_dp, n_elements_before_reset) - # Check result accumulated before reset - self._compressed_files_comparison_helper(self.temp_files[:n_elements_before_reset], res_before_reset) - # Check result accumulated after reset - self._compressed_files_comparison_helper(self.temp_files, res_after_reset) - - # __len__ Test: doesn't have valid length - with self.assertRaisesRegex(TypeError, "instance doesn't have valid length"): - len(tar_loader_dp) - - def _write_test_zip_files(self): - path = os.path.join(self.temp_dir.name, "test_zip.zip") - with zipfile.ZipFile(path, "w") as myzip: - myzip.write(self.temp_files[0], arcname=os.path.basename(self.temp_files[0])) - myzip.write(self.temp_files[1], arcname=os.path.basename(self.temp_files[1])) - myzip.write(self.temp_files[2], arcname=os.path.basename(self.temp_files[2])) - - def test_zip_archive_reader_iterdatapipe(self): - self._write_test_zip_files() - datapipe1 = FileLister(self.temp_dir.name, "*.zip") - datapipe2 = FileOpener(datapipe1, mode="b") - zip_loader_dp = ZipArchiveLoader(datapipe2) - - # Functional Test: read extracted files before reaching the end of the zipfile - self._compressed_files_comparison_helper(self.temp_files, zip_loader_dp, check_length=False) - - # Functional Test: read extracted files after reaching the end of the zipile - data_refs = list(zip_loader_dp) - self._compressed_files_comparison_helper(self.temp_files, data_refs) - - # Reset Test: reset the DataPipe after reading part of it - zip_loader_dp = datapipe2.load_from_zip() - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(zip_loader_dp, 
n_elements_before_reset) - # Check the results accumulated before reset - self._compressed_files_comparison_helper(self.temp_files[:n_elements_before_reset], res_before_reset) - # Check the results accumulated after reset - self._compressed_files_comparison_helper(self.temp_files, res_after_reset) - - # __len__ Test: doesn't have valid length - with self.assertRaisesRegex(TypeError, "instance doesn't have valid length"): - len(zip_loader_dp) - - def _write_test_xz_files(self): - for path in self.temp_files: - fname = os.path.basename(path) - temp_xzfile_pathname = os.path.join(self.temp_dir.name, f"{fname}.xz") - with open(path) as f: - with lzma.open(temp_xzfile_pathname, "w") as xz: - xz.write(f.read().encode("utf-8")) - - def test_xz_archive_reader_iterdatapipe(self): - # Worth noting that the .tar and .zip tests write multiple files into the same compressed file - # Whereas we create multiple .xz files in the same directories below. - self._write_test_xz_files() - datapipe1 = FileLister(self.temp_dir.name, "*.xz") - datapipe2 = FileOpener(datapipe1, mode="b") - xz_loader_dp = XzFileLoader(datapipe2) - - # Functional Test: Read extracted files before reaching the end of the xzfile - self._unordered_compressed_files_comparison_helper(self.temp_files, xz_loader_dp, check_length=False) - - # Functional Test: Read extracted files after reaching the end of the xzfile - data_refs = list(xz_loader_dp) - self._unordered_compressed_files_comparison_helper(self.temp_files, data_refs) - - # Reset Test: reset the DataPipe after reading part of it - xz_loader_dp = datapipe2.load_from_xz() - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(xz_loader_dp, n_elements_before_reset) - # Check result accumulated before reset - self.assertEqual(n_elements_before_reset, len(res_before_reset)) - self._unordered_compressed_files_comparison_helper(self.temp_files, res_before_reset, check_length=False) - # Check result accumulated after reset - self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset) - - # Reset Test: Ensure the order is consistent between iterations - for r1, r2 in zip(list(xz_loader_dp), list(xz_loader_dp)): - self.assertEqual(r1[0], r2[0]) - - # __len__ Test: doesn't have valid length - with self.assertRaisesRegex(TypeError, "instance doesn't have valid length"): - len(xz_loader_dp) - - def _write_test_bz2_files(self): - for path in self.temp_files: - fname = os.path.basename(path) - temp_bz2file_pathname = os.path.join(self.temp_dir.name, f"{fname}.bz2") - with open(path) as f: - with bz2.open(temp_bz2file_pathname, "w") as f_bz2: - f_bz2.write(f.read().encode("utf-8")) - - def test_bz2_archive_reader_iterdatapipe(self): - self._write_test_bz2_files() - filelist_dp = FileLister(self.temp_dir.name, "*.bz2") - fileopen_dp = FileOpener(filelist_dp, mode="b") - bz2_loader_dp = Bz2FileLoader(fileopen_dp) - - # Functional Test: Read extracted files before reaching the end of the bz2file - self._unordered_compressed_files_comparison_helper(self.temp_files, bz2_loader_dp, check_length=False) - - # Functional Test: Read extracted files after reaching the end of the bz2file - data_refs = list(bz2_loader_dp) - self._unordered_compressed_files_comparison_helper(self.temp_files, data_refs) - - # Reset Test: reset the DataPipe after reading part of it - bz2_loader_dp = fileopen_dp.load_from_bz2() - n_elements_before_reset = 1 - res_before_reset, res_after_reset = reset_after_n_next_calls(bz2_loader_dp, n_elements_before_reset) - # Check 
result accumulated before reset - self.assertEqual(n_elements_before_reset, len(res_before_reset)) - self._unordered_compressed_files_comparison_helper(self.temp_files, res_before_reset, check_length=False) - # Check result accumulated after reset - self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset) - - # Reset Test: Ensure the order is consistent between iterations - - for r1, r2 in zip(list(bz2_loader_dp), list(bz2_loader_dp)): - self.assertEqual(r1[0], r2[0]) - - # __len__ Test: doesn't have valid length - with self.assertRaisesRegex(TypeError, "instance doesn't have valid length"): - len(bz2_loader_dp) - - def _decompressor_tar_test_helper(self, expected_files, tar_decompress_dp): - for _file, child_obj in tar_decompress_dp: - for expected_file, tarinfo in zip(expected_files, child_obj): - if not tarinfo.isfile(): - continue - extracted_fobj = child_obj.extractfile(tarinfo) - with open(expected_file, "rb") as f: - self.assertEqual(f.read(), extracted_fobj.read()) - - def _decompressor_xz_test_helper(self, xz_decompress_dp): - for xz_file_name, xz_stream in xz_decompress_dp: - expected_file = xz_file_name[:-3] - with open(expected_file, "rb") as f: - self.assertEqual(f.read(), xz_stream.read()) - - def _decompressor_bz2_test_helper(self, bz2_decompress_dp): - for bz2_file_name, bz2_stream in bz2_decompress_dp: - expected_file = bz2_file_name.rsplit(".", 1)[0] - with open(expected_file, "rb") as f: - self.assertEqual(f.read(), bz2_stream.read()) - - def _write_single_gz_file(self): - import gzip - - with gzip.open(f"{self.temp_dir.name}/temp.gz", "wb") as k: - with open(self.temp_files[0], "rb") as f: - k.write(f.read()) - - def test_decompressor_iterdatapipe(self): - self._write_test_tar_files() - self._write_test_tar_gz_files() - self._write_single_gz_file() - self._write_test_zip_files() - self._write_test_xz_files() - self._write_test_bz2_files() - - # Functional Test: work with .tar files - tar_file_dp = FileLister(self.temp_dir.name, "*.tar") - tar_load_dp = FileOpener(tar_file_dp, mode="b") - tar_decompress_dp = Decompressor(tar_load_dp, file_type="tar") - self._decompressor_tar_test_helper(self.temp_files, tar_decompress_dp) - - # Functional test: work with .tar.gz files - tar_gz_file_dp = FileLister(self.temp_dir.name, "*.tar.gz") - tar_gz_load_dp = FileOpener(tar_gz_file_dp, mode="b") - tar_gz_decompress_dp = Decompressor(tar_gz_load_dp, file_type="tar") - self._decompressor_tar_test_helper(self.temp_files, tar_gz_decompress_dp) - - # Functional Test: work with .gz files - gz_file_dp = IterableWrapper([f"{self.temp_dir.name}/temp.gz"]) - gz_load_dp = FileOpener(gz_file_dp, mode="b") - gz_decompress_dp = Decompressor(gz_load_dp, file_type="gzip") - for _, gz_stream in gz_decompress_dp: - with open(self.temp_files[0], "rb") as f: - self.assertEqual(f.read(), gz_stream.read()) - - # Functional Test: work with .zip files - zip_file_dp = FileLister(self.temp_dir.name, "*.zip") - zip_load_dp = FileOpener(zip_file_dp, mode="b") - zip_decompress_dp = zip_load_dp.decompress(file_type="zip") - for _, zip_stream in zip_decompress_dp: - for fname in self.temp_files: - with open(fname, "rb") as f: - self.assertEqual(f.read(), zip_stream.read(name=os.path.basename(fname))) - - # Functional Test: work with .xz files - xz_file_dp = FileLister(self.temp_dir.name, "*.xz") - xz_load_dp = FileOpener(xz_file_dp, mode="b") - xz_decompress_dp = Decompressor(xz_load_dp, file_type="lzma") - self._decompressor_xz_test_helper(xz_decompress_dp) - - # Functional Test: work 
with .bz2 files - bz2_file_dp = FileLister(self.temp_dir.name, "*.bz2") - bz2_load_dp = FileOpener(bz2_file_dp, mode="b") - bz2_decompress_dp = Decompressor(bz2_load_dp, file_type="bz2") - self._decompressor_bz2_test_helper(bz2_decompress_dp) - - # Functional Test: work without file type as input for .tar files - tar_decompress_dp = Decompressor(tar_load_dp, file_type=None) - self._decompressor_tar_test_helper(self.temp_files, tar_decompress_dp) - - # Functional Test: work without file type as input for .xz files - xz_decompress_dp = Decompressor(xz_load_dp) - self._decompressor_xz_test_helper(xz_decompress_dp) - - # Functional Test: work without file type as input for .tar.gz files - tar_gz_decompress_dp = Decompressor(tar_gz_load_dp, file_type=None) - self._decompressor_tar_test_helper(self.temp_files, tar_gz_decompress_dp) - - # Functional Test: work without file type as input for .bz2 files - bz2_decompress_dp = Decompressor(bz2_load_dp, file_type=None) - self._decompressor_bz2_test_helper(bz2_decompress_dp) - - # Functional Test: Compression Type is works for both upper and lower case strings - tar_decompress_dp = Decompressor(tar_load_dp, file_type="TAr") - self._decompressor_tar_test_helper(self.temp_files, tar_decompress_dp) - - # Functional Test: Compression Type throws error for invalid file type - with self.assertRaisesRegex(ValueError, "not a valid CompressionType"): - Decompressor(tar_load_dp, file_type="ABC") - - # Reset Test: Ensure the order is consistent between iterations - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(xz_decompress_dp, n_elements_before_reset) - self._decompressor_xz_test_helper(res_before_reset) - self._decompressor_xz_test_helper(res_after_reset) - - # __len__ Test: doesn't have valid length - with self.assertRaisesRegex(TypeError, "has no len"): - len(tar_decompress_dp) - - def _write_text_files(self): - name_to_data = {"1.text": b"DATA", "2.text": b"DATA", "3.text": b"DATA"} - source_dp = IterableWrapper(sorted(name_to_data.items())) - saver_dp = source_dp.save_to_disk(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb") - list(saver_dp) - - @staticmethod - def _slow_fn(tmpdirname, x): - with open(os.path.join(tmpdirname, str(os.getpid())), "w") as pid_fh: - pid_fh.write("anything") - time.sleep(10) - return (x, "str") - - @skipIfNoPortalocker - def test_disk_cache_locks(self): - with tempfile.TemporaryDirectory() as tmpdirname: - file_name = os.path.join(tmpdirname, "test.bin") - dp = IterableWrapper([file_name]) - dp = dp.on_disk_cache(filepath_fn=_noop) - dp = dp.map(functools.partial(self._slow_fn, tmpdirname)) - dp = dp.end_caching(mode="t", filepath_fn=_noop, timeout=120) - dp = FileOpener(dp) - dp = StreamReader(dp) - dl = DataLoader(dp, num_workers=10, multiprocessing_context="spawn", batch_size=1, collate_fn=_unbatch) - result = list(dl) - all_files = [] - for (_, _, filenames) in os.walk(tmpdirname): - all_files += filenames - # We expect only two files, one with pid and 'downloaded' one - self.assertEqual(2, len(all_files)) - self.assertEqual("str", result[0][1]) - - # cleanup cached files - for f in os.listdir(tmpdirname): - os.remove(os.path.join(tmpdirname, f)) - - dp = CacheTimeout(2)(dp) # Calling adapter manually to work with classic DataLoader - dl = DataLoader(dp, num_workers=10, multiprocessing_context="spawn", batch_size=1, collate_fn=_unbatch) - with self.assertRaisesRegex(Exception, "OnDiskCache Exception"): - result = list(dl) - - # TODO(120): this test currently only 
covers reading from local - # filesystem. It needs to be modified once test data can be stored on - # gdrive/onedrive - @skipIfNoIoPath - def test_io_path_file_lister_iterdatapipe(self): - datapipe = IoPathFileLister(root=self.temp_sub_dir.name) - - # check all file paths within sub_folder are listed - for path in datapipe: - self.assertTrue(path in self.temp_sub_files) - - datapipe = IterableWrapper([self.temp_sub_dir.name]) - datapipe = datapipe.list_files_by_iopath() - for path in datapipe: - self.assertTrue(path in self.temp_sub_files) - - @skipIfNoIoPath - def test_io_path_file_lister_iterdatapipe_with_list(self): - datapipe = IoPathFileLister(root=[self.temp_sub_dir.name, self.temp_sub_dir_2.name]) - - file_lister = list(datapipe) - file_lister.sort() - all_temp_files = list(self.temp_sub_files + self.temp_sub_files_2) - all_temp_files.sort() - - # check all file paths within sub_folder are listed - self.assertEqual(file_lister, all_temp_files) - - datapipe = IterableWrapper([self.temp_sub_dir.name, self.temp_sub_dir_2.name]) - datapipe = datapipe.list_files_by_iopath() - results = list(datapipe) - results.sort() - self.assertEqual(results, all_temp_files) - - @skipIfNoIoPath - def test_io_path_file_loader_iterdatapipe(self): - datapipe1 = IoPathFileLister(root=self.temp_sub_dir.name) - datapipe2 = IoPathFileOpener(datapipe1) - - # check contents of file match - for _, f in datapipe2: - self.assertEqual(f.read(), "0123456789abcdef") - - # Reset Test: Ensure the resulting streams are still readable after the DataPipe is reset/exhausted - self._write_text_files() - lister_dp = FileLister(self.temp_dir.name, "*.text") - iopath_file_opener_dp = lister_dp.open_files_by_iopath(mode="rb") - - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(iopath_file_opener_dp, n_elements_before_reset) - self.assertEqual(2, len(res_before_reset)) - self.assertEqual(3, len(res_after_reset)) - for _name, stream in res_before_reset: - self.assertEqual(b"DATA", stream.read()) - for _name, stream in res_after_reset: - self.assertEqual(b"DATA", stream.read()) - - @skipIfNoIoPath - def test_io_path_saver_iterdatapipe(self): - # Functional Test: Saving some data - name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"} - source_dp = IterableWrapper(sorted(name_to_data.items())) - saver_dp = source_dp.save_by_iopath(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb") - res_file_paths = list(saver_dp) - expected_paths = [filepath_fn(self.temp_dir.name, name) for name in name_to_data.keys()] - self.assertEqual(expected_paths, res_file_paths) - for name in name_to_data.keys(): - p = filepath_fn(self.temp_dir.name, name) - with open(p) as f: - self.assertEqual(name_to_data[name], f.read().encode()) - - # Reset Test: - saver_dp = IoPathSaver(source_dp, filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb") - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(saver_dp, n_elements_before_reset) - self.assertEqual( - [filepath_fn(self.temp_dir.name, "1.txt"), filepath_fn(self.temp_dir.name, "2.txt")], res_before_reset - ) - self.assertEqual(expected_paths, res_after_reset) - for name in name_to_data.keys(): - p = filepath_fn(self.temp_dir.name, name) - with open(p) as f: - self.assertEqual(name_to_data[name], f.read().encode()) - - # __len__ Test: returns the length of source DataPipe - self.assertEqual(3, len(saver_dp)) - - @skipIfNoIoPath - def test_io_path_saver_file_lock(self): - # Same filename 
with different name - name_to_data = {"1.txt": b"DATA1", "1.txt": b"DATA2", "2.txt": b"DATA3", "2.txt": b"DATA4"} # noqa: F601 - - # Add sharding_filter to shard data into 2 - source_dp = IterableWrapper(list(name_to_data.items())).sharding_filter() - - # Use appending as the mode - saver_dp = source_dp.save_by_iopath(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="ab") - - import torch.utils.data.graph_settings - - from torch.utils.data import DataLoader - - num_workers = 2 - line_lengths = [] - dl = DataLoader(saver_dp, num_workers=num_workers, multiprocessing_context="spawn") - for filename in dl: - with open(filename[0]) as f: - lines = f.readlines() - x = len(lines) - line_lengths.append(x) - self.assertEqual(x, 1) - - self.assertEqual(num_workers, len(line_lengths)) - - def _write_test_rar_files(self): - # `rarfile` can only read but not write .rar archives so we use to system utilities - rar_archive_name = os.path.join(self.temp_dir.name, "test_rar") - subprocess.run(("rar", "a", rar_archive_name + ".rar", *self.temp_files), check=True) - - # Nested RAR - subprocess.run(("rar", "a", rar_archive_name + "1.rar", self.temp_files[0]), check=True) - subprocess.run(("rar", "a", rar_archive_name + "2.rar", *self.temp_files[1:]), check=True) - subprocess.run( - ("rar", "a", rar_archive_name + "_nested.rar", rar_archive_name + "1.rar", rar_archive_name + "2.rar"), - check=True, - ) - - # Nested RAR in TAR - with tarfile.open(rar_archive_name + "_nested.tar", "w:tar") as tar: - tar.add(rar_archive_name + "1.rar") - tar.add(rar_archive_name + "2.rar") - - @skipIfNoRarTools - def test_rar_archive_loader(self): - self._write_test_rar_files() - - datapipe1 = IterableWrapper([os.path.join(self.temp_dir.name, "test_rar.rar")]) - datapipe2 = FileOpener(datapipe1, mode="b") - rar_loader_dp = RarArchiveLoader(datapipe2) - - # Functional Test: read extracted files before reaching the end of the rarfile - self._unordered_compressed_files_comparison_helper(self.temp_files, rar_loader_dp, check_length=False) - - # Functional Test: read extracted files after reaching the end of the rarfile - data_refs = list(rar_loader_dp) - self._unordered_compressed_files_comparison_helper(self.temp_files, data_refs) - - # Reset Test: reset the DataPipe after reading part of it - rar_loader_dp = datapipe2.load_from_rar() - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(rar_loader_dp, n_elements_before_reset) - # Check the results accumulated before reset - self._unordered_compressed_files_comparison_helper(self.temp_files[:n_elements_before_reset], res_before_reset) - # Check the results accumulated after reset - self._unordered_compressed_files_comparison_helper(self.temp_files, res_after_reset) - - # __len__ Test: doesn't have valid length - with self.assertRaisesRegex(TypeError, "instance doesn't have valid length"): - len(rar_loader_dp) - - # Nested RAR - datapipe1 = IterableWrapper([os.path.join(self.temp_dir.name, "test_rar_nested.rar")]) - datapipe2 = FileOpener(datapipe1, mode="b") - rar_loader_dp_1 = RarArchiveLoader(datapipe2) - rar_loader_dp_2 = RarArchiveLoader(rar_loader_dp_1) - - with self.assertRaisesRegex(ValueError, "Nested RAR archive is not supported"): - list(rar_loader_dp_2) - - # Nested RAR in TAR - datapipe1 = IterableWrapper([os.path.join(self.temp_dir.name, "test_rar_nested.tar")]) - datapipe2 = FileOpener(datapipe1, mode="b") - tar_loader_dp = TarArchiveLoader(datapipe2) - rar_loader_dp = RarArchiveLoader(tar_loader_dp) - - # Functional 
Test: read extracted files before reaching the end of the rarfile - self._unordered_compressed_files_comparison_helper(self.temp_files, rar_loader_dp, check_length=False) - - # Functional Test: read extracted files after reaching the end of the rarfile - data_refs = list(rar_loader_dp) - self._unordered_compressed_files_comparison_helper(self.temp_files, data_refs) - - def _add_data_to_wds_tar(self, archive, name, value): - if isinstance(value, str): - value = value.encode() - info = tarfile.TarInfo(name) - info.size = len(value) - archive.addfile(info, io.BytesIO(value)) - - def _create_wds_tar(self, dest, nsamples): - with tarfile.open(dest, mode="w") as archive: - for i in range(nsamples): - self._add_data_to_wds_tar(archive, f"data/{i}.txt", f"text{i}") - self._add_data_to_wds_tar(archive, f"data/{i}.bin", f"bin{i}") - - def test_webdataset(self) -> None: - # Functional Test: groups samples correctly - source_dp = IterableWrapper( - # simulated tar file content - [ - ("/path/to/file1.jpg", b"1"), - ("/path/to/_something_", b"nothing"), - ("/path/to/file1.cls", b"2"), - ("/path/to/file2.jpg", b"3"), - ("/path/to/file2.cls", b"4"), - ] - ) - web_dataset = WebDataset(source_dp) - self.assertEqual( - # expected grouped output - [ - {".jpg": b"1", ".cls": b"2", "__key__": "/path/to/file1"}, - {".jpg": b"3", ".cls": b"4", "__key__": "/path/to/file2"}, - ], - list(web_dataset), - ) - - def test_webdataset2(self) -> None: - # Setup - nsamples = 10 - self._create_wds_tar(os.path.join(self.temp_dir.name, "wds.tar"), nsamples) - - def decode(item): - key, value = item - if key.endswith(".txt"): - return key, value.read().decode("utf-8") - if key.endswith(".bin"): - return key, value.read().decode("utf-8") - - datapipe1 = FileLister(self.temp_dir.name, "wds*.tar") - datapipe2 = FileOpener(datapipe1, mode="b") - dataset = datapipe2.load_from_tar().map(decode).webdataset() - items = list(dataset) - assert len(items) == nsamples - assert items[0][".txt"] == "text0" - assert items[9][".bin"] == "bin9" - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_mapdatapipe.py b/test/test_mapdatapipe.py deleted file mode 100644 index b6e6c0b9b..000000000 --- a/test/test_mapdatapipe.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -import expecttest -from torchdata.datapipes.iter import MapToIterConverter -from torchdata.datapipes.map import InMemoryCacheHolder, MapDataPipe, SequenceWrapper, UnZipper - - -class TestMapDataPipe(expecttest.TestCase): - def test_unzipper_mapdatapipe(self) -> None: - source_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(10)]) - - # Functional Test: unzips each sequence, with `sequence_length` specified - dp1: MapDataPipe - dp2: MapDataPipe - dp3: MapDataPipe - dp1, dp2, dp3 = UnZipper(source_dp, sequence_length=3) # type: ignore[misc] - self.assertEqual(list(range(10)), list(dp1)) - self.assertEqual(list(range(10, 20)), list(dp2)) - self.assertEqual(list(range(20, 30)), list(dp3)) - - # Functional Test: skipping over specified values - dp2, dp3 = source_dp.unzip(sequence_length=3, columns_to_skip=[0]) - self.assertEqual(list(range(10, 20)), list(dp2)) - self.assertEqual(list(range(20, 30)), list(dp3)) - - (dp2,) = source_dp.unzip(sequence_length=3, columns_to_skip=[0, 2]) - self.assertEqual(list(range(10, 20)), list(dp2)) - - source_dp = SequenceWrapper([(i, i + 10, i + 20, i + 30) for i in range(10)]) - dp2, dp3 = source_dp.unzip(sequence_length=4, columns_to_skip=[0, 3]) - self.assertEqual(list(range(10, 20)), list(dp2)) - self.assertEqual(list(range(20, 30)), list(dp3)) - - # __len__ Test: the lengths of child DataPipes are correct - self.assertEqual((10, 10), (len(dp2), len(dp3))) - - def test_map_to_iter_converter_datapipe(self) -> None: - # Functional Test: ensure the conversion without indices input is correct - source_dp = SequenceWrapper(range(10)) - iter_dp = source_dp.to_iter_datapipe() - self.assertEqual(list(range(10)), list(iter_dp)) - - # Functional Test: ensure conversion with custom indices is correct - source_dp2 = SequenceWrapper({"a": 0, "b": 1, "c": 2}) - iter_dp2 = MapToIterConverter(source_dp2, indices=["a", "b", "c"]) - self.assertEqual([0, 1, 2], list(iter_dp2)) - - # __len__ Test: the lengths of the output is correct - self.assertEqual(10, len(iter_dp)) - self.assertEqual(3, len(iter_dp2)) - - def test_in_memory_cache_holder_mapdatapipe(self) -> None: - source_dp = SequenceWrapper(range(10)) - cache_dp = source_dp.in_memory_cache() - - # Functional Test: Cache DP should just return the data without changing the values - self.assertEqual(list(range(10)), list(cache_dp)) - - # Functional Test: Ensure the objects are the same ones from source DataPipe - cache_dp = InMemoryCacheHolder(source_dp) # type: ignore[arg-type] - res1 = list(cache_dp) - res2 = list(cache_dp) - self.assertTrue(id(source) == id(cache) for source, cache in zip(source_dp, res1)) - self.assertTrue(id(source) == id(cache) for source, cache in zip(source_dp, res2)) - - # __len__ Test: inherits length from source_dp - self.assertEqual(10, len(cache_dp)) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_period.py b/test/test_period.py deleted file mode 100644 index ef0a2dae7..000000000 --- a/test/test_period.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import io -import os -import unittest - -import expecttest - -from torchdata.datapipes.iter import GDriveReader, IterableWrapper, OnlineReader - - -# This TestCase is created due to the limited quota to access google drive -class TestDataPipePeriod(expecttest.TestCase): - def test_gdrive_iterdatapipe(self): - - amazon_review_url = "https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM" - expected_file_name = "amazon_review_polarity_csv.tar.gz" - expected_MD5_hash = "fe39f8b653cada45afd5792e0f0e8f9b" - query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True} - timeout = 120 - gdrive_reader_dp = GDriveReader(IterableWrapper([amazon_review_url]), timeout=timeout, **query_params) - - # Functional Test: test if the GDrive Reader can download and read properly - reader_dp = gdrive_reader_dp.readlines() - it = iter(reader_dp) - path, line = next(it) - self.assertEqual(expected_file_name, os.path.basename(path)) - self.assertTrue(line != b"") - - # Reset Test: gdrive_reader_dp has been read, but we reset when calling check_hash() - check_cache_dp = gdrive_reader_dp.check_hash({expected_file_name: expected_MD5_hash}, "md5", rewind=False) - it = iter(check_cache_dp) - path, stream = next(it) - self.assertEqual(expected_file_name, os.path.basename(path)) - self.assertTrue(io.BufferedReader, type(stream)) - - # __len__ Test: returns the length of source DataPipe - source_dp = IterableWrapper([amazon_review_url]) - gdrive_dp = GDriveReader(source_dp) - self.assertEqual(1, len(gdrive_dp)) - - # Error Test: test if the GDrive Reader raises an error when the url is invalid - error_url = "https://drive.google.com/uc?export=download&id=filedoesnotexist" - http_error_dp = GDriveReader(IterableWrapper([error_url]), timeout=timeout) - with self.assertRaisesRegex( - Exception, r"404.+https://drive.google.com/uc\?export=download&id=filedoesnotexist" - ): - next(iter(http_error_dp.readlines())) - - # Feature skip-error Test: test if the GDrive Reader skips urls causing problems - gdrive_skip_error_dp = GDriveReader( - IterableWrapper([error_url, amazon_review_url]), timeout=timeout, skip_on_error=True - ) - reader_dp = gdrive_skip_error_dp.readlines() - with self.assertWarnsRegex( - Warning, r"404.+https://drive.google.com/uc\?export=download&id=filedoesnotexist.+skipping" - ): - it = iter(reader_dp) - path, line = next(it) - self.assertEqual(expected_file_name, os.path.basename(path)) - self.assertTrue(line != b"") - - def test_online_iterdatapipe(self): - - license_file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE" - amazon_review_url = "https://drive.google.com/uc?export=download&id=0Bz8a_Dbh9QhbaW12WVVZS2drcnM" - expected_license_file_name = "LICENSE" - expected_amazon_file_name = "amazon_review_polarity_csv.tar.gz" - expected_license_MD5_hash = "bb9675028dd39d2dd2bf71002b93e66c" - expected_amazon_MD5_hash = "fe39f8b653cada45afd5792e0f0e8f9b" - query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True} - timeout = 120 - - file_hash_dict = { - license_file_url: expected_license_MD5_hash, - expected_amazon_file_name: expected_amazon_MD5_hash, - } - - # Functional Test: can read from GDrive links - online_reader_dp = OnlineReader(IterableWrapper([amazon_review_url]), timeout=timeout, **query_params) - reader_dp = online_reader_dp.readlines() - it = iter(reader_dp) - path, line = next(it) - self.assertEqual(expected_amazon_file_name, os.path.basename(path)) - self.assertTrue(line != b"") - - # Functional Test: can 
read from other links - online_reader_dp = OnlineReader(IterableWrapper([license_file_url])) - reader_dp = online_reader_dp.readlines() - it = iter(reader_dp) - path, line = next(it) - self.assertEqual(expected_license_file_name, os.path.basename(path)) - self.assertTrue(line != b"") - - # Reset Test: reset online_reader_dp by calling check_hash - check_cache_dp = online_reader_dp.check_hash(file_hash_dict, "md5", rewind=False) - it = iter(check_cache_dp) - path, stream = next(it) - self.assertEqual(expected_license_file_name, os.path.basename(path)) - self.assertTrue(io.BufferedReader, type(stream)) - - # Functional Test: works with multiple URLs of different sources - online_reader_dp = OnlineReader(IterableWrapper([license_file_url, amazon_review_url])) - check_cache_dp = online_reader_dp.check_hash(file_hash_dict, "md5", rewind=False) - it = iter(check_cache_dp) - for expected_file_name, (path, stream) in zip([expected_license_file_name, expected_amazon_file_name], it): - self.assertEqual(expected_file_name, os.path.basename(path)) - self.assertTrue(io.BufferedReader, type(stream)) - - # __len__ Test: returns the length of source DataPipe - self.assertEqual(2, len(online_reader_dp)) - - # Error Test: test if the Online Reader raises an error when the url is invalid - error_url_http = "https://github.com/pytorch/data/this/url/dont/exist" - online_error_dp = OnlineReader(IterableWrapper([error_url_http]), timeout=timeout) - with self.assertRaisesRegex(Exception, f"404.+{error_url_http}"): - next(iter(online_error_dp.readlines())) - - error_url_gdrive = "https://drive.google.com/uc?export=download&id=filedoesnotexist" - online_error_dp = OnlineReader(IterableWrapper([error_url_gdrive]), timeout=timeout) - with self.assertRaisesRegex( - Exception, r"404.+https://drive.google.com/uc\?export=download&id=filedoesnotexist" - ): - next(iter(online_error_dp.readlines())) - - # Feature skip-error Test: test if the Online Reader skips urls causing problems - online_skip_error_dp = OnlineReader( - IterableWrapper([error_url_http, error_url_gdrive, license_file_url]), timeout=timeout, skip_on_error=True - ) - reader_dp = online_skip_error_dp.readlines() - with self.assertWarnsRegex(Warning, f"404.+{error_url_http}.+skipping"): - it = iter(reader_dp) - path, line = next(it) - self.assertEqual(expected_license_file_name, os.path.basename(path)) - self.assertTrue(b"BSD" in line) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_remote_io.py b/test/test_remote_io.py deleted file mode 100644 index 86a818974..000000000 --- a/test/test_remote_io.py +++ /dev/null @@ -1,392 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import io -import json -import os -import subprocess -import unittest -import warnings -from unittest.mock import patch - -import expecttest - -from _utils._common_utils_for_test import check_hash_fn, create_temp_dir, IS_M1, IS_WINDOWS -from torch.utils.data import DataLoader -from torchdata.datapipes.iter import ( - FileOpener, - FSSpecFileLister, - FSSpecFileOpener, - HttpReader, - IterableWrapper, - OnDiskCacheHolder, - S3FileLister, - S3FileLoader, -) -from torchdata.datapipes.iter.load.online import _get_proxies - -try: - import fsspec - - HAS_FSSPEC = True -except ImportError: - HAS_FSSPEC = False - -try: - import s3fs - - HAS_FSSPEC_S3 = True -except ImportError: - HAS_FSSPEC_S3 = False -skipIfNoFSSpecS3 = unittest.skipIf(not (HAS_FSSPEC and HAS_FSSPEC_S3), "no FSSpec with S3fs") - -try: - import adlfs - - HAS_FSSPEC_AZ = True -except ImportError: - HAS_FSSPEC_AZ = False -skipIfNoFSSpecAZ = unittest.skipIf(not (HAS_FSSPEC and HAS_FSSPEC_AZ), "no FSSpec with adlfs") - -try: - from torchdata._torchdata import S3Handler - - HAS_AWS = True -except ImportError: - HAS_AWS = False -skipIfAWS = unittest.skipIf(HAS_AWS, "AWSSDK Enabled") -skipIfNoAWS = unittest.skipIf(not HAS_AWS, "No AWSSDK Enabled") - -try: - import portalocker - - HAS_PORTALOCKER = True -except ImportError: - HAS_PORTALOCKER = False -skipIfNoPortalocker = unittest.skipIf(not HAS_PORTALOCKER, "No portalocker installed") - - -class TestDataPipeRemoteIO(expecttest.TestCase): - def setUp(self): - self.temp_dir = create_temp_dir() - - def tearDown(self): - try: - self.temp_dir.cleanup() - except Exception as e: - warnings.warn(f"TestDataPipeRemoteIO was not able to cleanup temp dir due to {e}") - - def test_http_reader_iterdatapipe(self): - - file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE" - expected_file_name = "LICENSE" - expected_MD5_hash = "bb9675028dd39d2dd2bf71002b93e66c" - query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True} - timeout = 120 - http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params) - - # Functional Test: test if the Http Reader can download and read properly - reader_dp = http_reader_dp.readlines() - it = iter(reader_dp) - path, line = next(it) - self.assertEqual(expected_file_name, os.path.basename(path)) - self.assertTrue(b"BSD" in line) - - # Reset Test: http_reader_dp has been read, but we reset when calling check_hash() - check_cache_dp = http_reader_dp.check_hash({file_url: expected_MD5_hash}, "md5", rewind=False) - it = iter(check_cache_dp) - path, stream = next(it) - self.assertEqual(expected_file_name, os.path.basename(path)) - self.assertTrue(io.BufferedReader, type(stream)) - - # __len__ Test: returns the length of source DataPipe - self.assertEqual(1, len(http_reader_dp)) - - # Error Test: test if the Http Reader raises an error when the url is invalid - error_url = "https://github.com/pytorch/data/this/url/dont/exist" - http_error_dp = HttpReader(IterableWrapper([error_url]), timeout=timeout) - with self.assertRaisesRegex(Exception, f"404.+{error_url}"): - next(iter(http_error_dp.readlines())) - - # Feature skip-error Test: test if the Http Reader skips urls causing problems - http_skip_error_dp = HttpReader(IterableWrapper([error_url, file_url]), timeout=timeout, skip_on_error=True) - reader_dp = http_skip_error_dp.readlines() - with self.assertWarnsRegex(Warning, f"404.+{error_url}.+skipping"): - it = iter(reader_dp) - path, line = next(it) - self.assertEqual(expected_file_name, 
os.path.basename(path)) - self.assertTrue(b"BSD" in line) - - # test if GET-request is done with correct arguments - with patch("requests.Session.get") as mock_get: - http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params) - _ = next(iter(http_reader_dp)) - mock_get.assert_called_with( - file_url, - timeout=timeout, - proxies=_get_proxies(), - stream=True, - auth=query_params["auth"], - allow_redirects=query_params["allow_redirects"], - ) - - @skipIfNoPortalocker - def test_on_disk_cache_holder_iterdatapipe(self): - tar_file_url = "https://raw.githubusercontent.com/pytorch/data/main/test/_fakedata/csv.tar.gz" - expected_file_name = os.path.join(self.temp_dir.name, "csv.tar.gz") - expected_MD5_hash = "42cd45e588dbcf64c65751fbf0228af9" - tar_hash_dict = {expected_file_name: expected_MD5_hash} - - tar_file_dp = IterableWrapper([tar_file_url]) - - with self.assertRaisesRegex(RuntimeError, "Expected `OnDiskCacheHolder` existing"): - _ = tar_file_dp.end_caching() - - def _filepath_fn(url): - filename = os.path.basename(url) - return os.path.join(self.temp_dir.name, filename) - - tar_cache_dp = tar_file_dp.on_disk_cache( - filepath_fn=_filepath_fn, - hash_dict=tar_hash_dict, - hash_type="md5", - ) - - # DataPipe Constructor - tar_cache_dp = HttpReader(tar_cache_dp) - - # Start iteration without `end_caching` - with self.assertRaisesRegex(RuntimeError, "Please call"): - _ = list(tar_cache_dp) - - # Both filepath_fn and same_filepath_fn are set - with self.assertRaisesRegex(ValueError, "`filepath_fn` is mutually"): - _ = tar_cache_dp.end_caching(mode="wb", filepath_fn=_filepath_fn, same_filepath_fn=True) - - tar_cache_dp = tar_cache_dp.end_caching(mode="wb", same_filepath_fn=True) - - # File doesn't exist on disk - self.assertFalse(os.path.exists(expected_file_name)) - - path = list(tar_cache_dp)[0] - - # File is cached to disk - self.assertTrue(os.path.exists(expected_file_name)) - self.assertEqual(expected_file_name, path) - self.assertTrue(check_hash_fn(expected_file_name, expected_MD5_hash)) - - # Modify the downloaded file to trigger downloading again - with open(expected_file_name, "w") as f: - f.write("0123456789abcdef") - - self.assertFalse(check_hash_fn(expected_file_name, expected_MD5_hash)) - path = list(tar_cache_dp)[0] - self.assertTrue(check_hash_fn(expected_file_name, expected_MD5_hash)) - - # Call `end_caching` again - with self.assertRaisesRegex(RuntimeError, "`end_caching` can only be invoked once"): - _ = tar_cache_dp.end_caching() - - # Cache decompressed archive but only check root directory - root_dir = "temp" - file_cache_dp = OnDiskCacheHolder( - tar_cache_dp, filepath_fn=lambda tar_path: os.path.join(os.path.dirname(tar_path), root_dir) - ) - remember_cache_dp_object = file_cache_dp - file_cache_dp = FileOpener(file_cache_dp, mode="rb").load_from_tar() - - file_cache_dp = file_cache_dp.end_caching( - mode="wb", - filepath_fn=lambda file_path: os.path.join(self.temp_dir.name, root_dir, os.path.basename(file_path)), - ) - - cached_it = iter(file_cache_dp) - for i in range(3): - expected_csv_path = os.path.join(self.temp_dir.name, root_dir, f"{i}.csv") - - # File doesn't exist on disk - # Check disabled due to some elements of prefetching inside of on_disck_cache - # self.assertFalse(os.path.exists(expected_csv_path)) - - csv_path = next(cached_it) - - # File is cached to disk - self.assertTrue(os.path.exists(expected_csv_path)) - self.assertEqual(expected_csv_path, csv_path) - - # This is the situation when previous process had no canche to 
release promise file on the file lists, - # as we are in same pid, we need to force iterators to finish by deleting or exhausing them - del cached_it - - if not IS_WINDOWS: - dl = DataLoader(file_cache_dp, num_workers=3, multiprocessing_context="fork", batch_size=1) - expected = [[os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")] for i in range(3)] * 3 - res = list(dl) - self.assertEqual(sorted(expected), sorted(res)) - - remember_cache_dp_object._download_everything = True - workers = 100 - dl = DataLoader(file_cache_dp, num_workers=workers, multiprocessing_context="fork", batch_size=1) - expected = [[os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")] for i in range(3)] * workers - res = list(dl) - self.assertEqual(sorted(expected), sorted(res)) - - def __get_s3_cnt(self, s3_pths: list, recursive=True): - """Return the count of the total objects collected from a list s3 paths""" - tot_objs = set() - for p in s3_pths: - pth_parts = p.split("s3://")[1].split("/", 1) - if len(pth_parts) == 1: - bkt_name, prefix = pth_parts[0], "" - else: - bkt_name, prefix = pth_parts - - aws_cmd = f"aws --output json s3api list-objects --bucket {bkt_name} --no-sign-request" - if prefix.strip(): - aws_cmd += f" --prefix {prefix}" - if not recursive: - aws_cmd += " --delimiter /" - - res = subprocess.run(aws_cmd, shell=True, check=True, capture_output=True) - json_res = json.loads(res.stdout) - if "Contents" in json_res: - objs = [v["Key"] for v in json_res["Contents"]] - else: - objs = [v["Prefix"] for v in json_res["CommonPrefixes"]] - tot_objs |= set(objs) - - return len(tot_objs) - - @skipIfNoFSSpecS3 - def test_fsspec_io_iterdatapipe(self): - input_list = [ - ["s3://ai2-public-datasets"], # bucket without '/' - ["s3://ai2-public-datasets/charades/"], # bucket with '/' - [ - "s3://ai2-public-datasets/charades/Charades_v1.zip", - "s3://ai2-public-datasets/charades/Charades_v1_flow.tar", - "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar", - "s3://ai2-public-datasets/charades/Charades_v1_480.zip", - ], # multiple files - ] - for urls in input_list: - fsspec_lister_dp = FSSpecFileLister(IterableWrapper(urls), anon=True) - self.assertEqual( - sum(1 for _ in fsspec_lister_dp), self.__get_s3_cnt(urls, recursive=False), f"{urls} failed" - ) - - url = "s3://ai2-public-datasets/charades/" - fsspec_loader_dp = FSSpecFileOpener(FSSpecFileLister(IterableWrapper([url]), anon=True), anon=True) - res = list(fsspec_loader_dp) - self.assertEqual(len(res), 18, f"{input} failed") - - @unittest.skipIf(True, "Needs authentications. 
See: https://github.com/pytorch/data/issues/904") - @skipIfNoFSSpecAZ - def test_fsspec_azure_blob(self): - url = "public/curated/covid-19/ecdc_cases/latest/ecdc_cases.csv" - account_name = "pandemicdatalake" - azure_prefixes = ["abfs", "az"] - fsspec_loader_dp = {} - - for prefix in azure_prefixes: - fsspec_lister_dp = FSSpecFileLister(f"{prefix}://{url}", account_name=account_name) - fsspec_loader_dp[prefix] = FSSpecFileOpener(fsspec_lister_dp, account_name=account_name).parse_csv() - - res_abfs = list(fsspec_loader_dp["abfs"])[0] - res_az = list(fsspec_loader_dp["az"])[0] - self.assertEqual(res_abfs, res_az, f"{input} failed") - - @skipIfAWS - @unittest.skip("S3 IterDataPipes are deprecated") - def test_disabled_s3_io_iterdatapipe(self): - file_urls = ["s3://ai2-public-datasets"] - with self.assertRaisesRegex(ModuleNotFoundError, "TorchData must be built with"): - _ = S3FileLister(IterableWrapper(file_urls)) - with self.assertRaisesRegex(ModuleNotFoundError, "TorchData must be built with"): - _ = S3FileLoader(IterableWrapper(file_urls)) - - @skipIfNoAWS - @unittest.skip("S3 IterDataPipes are deprecated") - @unittest.skipIf(IS_M1, "PyTorch M1 CI Machine doesn't allow accessing") - def test_s3_io_iterdatapipe(self): - # S3FileLister: different inputs - input_list = [ - ["s3://ai2-public-datasets"], # bucket without '/' - ["s3://ai2-public-datasets/"], # bucket with '/' - ["s3://ai2-public-datasets/charades"], # folder without '/' - ["s3://ai2-public-datasets/charades/"], # folder without '/' - ["s3://ai2-public-datasets/charad"], # prefix - [ - "s3://ai2-public-datasets/charades/Charades_v1", - "s3://ai2-public-datasets/charades/Charades_vu17", - ], # prefixes - ["s3://ai2-public-datasets/charades/Charades_v1.zip"], # single file - [ - "s3://ai2-public-datasets/charades/Charades_v1.zip", - "s3://ai2-public-datasets/charades/Charades_v1_flow.tar", - "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar", - "s3://ai2-public-datasets/charades/Charades_v1_480.zip", - ], # multiple files - [ - "s3://ai2-public-datasets/charades/Charades_v1.zip", - "s3://ai2-public-datasets/charades/Charades_v1_flow.tar", - "s3://ai2-public-datasets/charades/Charades_v1_rgb.tar", - "s3://ai2-public-datasets/charades/Charades_v1_480.zip", - "s3://ai2-public-datasets/charades/Charades_vu17", - ], # files + prefixes - ] - for input in input_list: - s3_lister_dp = S3FileLister(IterableWrapper(input), region="us-west-2") - self.assertEqual(sum(1 for _ in s3_lister_dp), self.__get_s3_cnt(input), f"{input} failed") - - # S3FileLister: prefixes + different region - file_urls = [ - "s3://aft-vbi-pds/bin-images/111", - "s3://aft-vbi-pds/bin-images/222", - ] - s3_lister_dp = S3FileLister(IterableWrapper(file_urls), request_timeout_ms=10000, region="us-east-1") - self.assertEqual(sum(1 for _ in s3_lister_dp), 2212, f"{input} failed") - - # S3FileLister: incorrect inputs - input_list = [ - [""], - ["ai2-public-datasets"], - ["s3://"], - ["s3:///bin-images"], - ] - for input in input_list: - with self.assertRaises(ValueError, msg=f"{input} should raise ValueError."): - s3_lister_dp = S3FileLister(IterableWrapper(input), region="us-east-1") - for _ in s3_lister_dp: - pass - - input = [["s3://aft-vbi-pds/bin-images/100730.jpg"], 1] - s3_loader_dp = S3FileLoader(input[0], region="us-east-1") - self.assertEqual(sum(1 for _ in s3_loader_dp), input[1], f"{input[0]} failed") - - # S3FileLoader: incorrect inputs - input_list = [ - [""], - ["ai2-public-datasets"], - ["s3://"], - ["s3:///bin-images"], - 
["s3://ai2-public-datasets/bin-image"], - ] - for input in input_list: - with self.assertRaises(ValueError, msg=f"{input} should raise ValueError."): - s3_loader_dp = S3FileLoader(input, region="us-east-1") - for _ in s3_loader_dp: - pass - - # integration test - input = [["s3://charades-tar-shards/"], 10] - s3_lister_dp = S3FileLister(IterableWrapper(input[0]), region="us-west-2") - s3_loader_dp = S3FileLoader(s3_lister_dp, region="us-west-2") - self.assertEqual(sum(1 for _ in s3_loader_dp), input[1], f"{input[0]} failed") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_s3io.py b/test/test_s3io.py deleted file mode 100644 index 92ef577af..000000000 --- a/test/test_s3io.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest -from unittest.mock import MagicMock, patch - -import expecttest -from torch.testing._internal.common_utils import IS_SANDCASTLE -from torchdata.datapipes.iter import IterableWrapper, S3FileLister - -skipIfSandcastle = unittest.skipIf(IS_SANDCASTLE, "Skip for internal testing") - - -@skipIfSandcastle -@unittest.skip("S3 IterDataPipes are deprecated") -@patch("torchdata._torchdata") -class TestS3FileListerIterDataPipe(expecttest.TestCase): - def test_list_files(self, mock_torchdata): - s3handler_mock = MagicMock() - mock_torchdata.S3Handler.return_value = s3handler_mock - s3handler_mock.list_files = MagicMock( - side_effect=[["s3://bucket-name/folder/a.txt", "s3://bucket-name/folder/b.csv"], []] - ) - s3_prefixes = IterableWrapper(["s3://bucket-name/folder/"]) - dp_s3_urls = S3FileLister(s3_prefixes) - assert list(dp_s3_urls) == ["s3://bucket-name/folder/a.txt", "s3://bucket-name/folder/b.csv"] - - def test_list_files_with_filter_mask(self, mock_torchdata): - s3handler_mock = MagicMock() - mock_torchdata.S3Handler.return_value = s3handler_mock - s3handler_mock.list_files = MagicMock( - side_effect=[["s3://bucket-name/folder/a.txt", "s3://bucket-name/folder/b.csv"], []] - ) - s3_prefixes = IterableWrapper(["s3://bucket-name/folder/"]) - dp_s3_urls = S3FileLister(s3_prefixes, masks="*.csv") - assert list(dp_s3_urls) == ["s3://bucket-name/folder/b.csv"] diff --git a/test/test_seed_generator.py b/test/test_seed_generator.py deleted file mode 100644 index da0fa4d0d..000000000 --- a/test/test_seed_generator.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -from torchdata.dataloader2.random import SeedGenerator -from torchdata.dataloader2.random._philox import PhiloxEngine - - -class TestPhilox(unittest.TestCase): - def test_philox_engine_generate(self): - prng = PhiloxEngine() - with self.assertRaisesRegex(AssertionError, "Please provide seed"): - prng.generate() - - prng.seed(123) - s0 = [prng.generate() for _ in range(10)] - - # Same seed - prng = PhiloxEngine(seed=123) - s1 = [prng.generate() for _ in range(10)] - self.assertEqual(s0, s1) - - # Reset - prng.seed(123) - s2 = [prng.generate() for _ in range(10)] - self.assertEqual(s1, s2) - - # Different seeds - prng = PhiloxEngine(seed=321) - s3 = [prng.generate() for _ in range(10)] - self.assertNotEqual(s0, s3) - - def test_philox_engine_spawn(self): - prng = PhiloxEngine() - with self.assertRaisesRegex(AssertionError, "Expected a non-negative value"): - prng.spawn(-1) - with self.assertRaisesRegex(AssertionError, "Please provide seed"): - prng.spawn(0) - - prng.seed(123) - s0 = [prng.spawn(i)._seed for i in range(10)] - - # Same seed - prng = PhiloxEngine(seed=123) - s1 = [prng.spawn(i)._seed for i in range(10)] - self.assertEqual(s0, s1) - - # Generate after spawn - sprng1 = prng.spawn(1) - sprng2 = prng.spawn(1) - ss1 = [sprng1.generate() for _ in range(10)] - ss2 = [sprng2.generate() for _ in range(10)] - self.assertEqual(ss1, ss2) - - sprng3 = prng.spawn(2) - ss3 = [sprng3.generate() for _ in range(10)] - self.assertNotEqual(ss1, ss3) - - # Reset - prng.seed(123) - s2 = [prng.spawn(i)._seed for i in range(10)] - self.assertEqual(s1, s2) - - # Different seeds - prng = PhiloxEngine(seed=321) - s3 = [prng.spawn(i)._seed for i in range(10)] - self.assertNotEqual(s0, s3) - - -class TestSeedGenerator(unittest.TestCase): - def test_seed_generator_generate(self): - # Generate seeds - sg = SeedGenerator(123) - s0 = [sg.generate_seed() for _ in range(10)] - - # Reset - sg.seed(123) - s1 = [sg.generate_seed() for _ in range(10)] - self.assertEqual(s0, s1) - - # Different Seeds - sg.seed(321) - s2 = [sg.generate_seed() for _ in range(10)] - self.assertNotEqual(s0, s2) - - def test_seed_generator_spawn(self): - sg = SeedGenerator(123) - - # Spawn new Seed Generators - sg1 = sg.spawn(1) - sg2 = sg.spawn(2) - - for _ in range(10): - self.assertNotEqual(sg1.generate_seed(), sg2.generate_seed()) - # Generate shared seeds - self.assertEqual(sg1.generate_shared_seed(), sg2.generate_shared_seed()) - - sg1_1 = sg.spawn(1) - sg1_2 = sg.spawn(1) - for _ in range(10): - self.assertEqual(sg1_1.generate_seed(), sg1_2.generate_seed()) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_serialization.py b/test/test_serialization.py deleted file mode 100644 index 783f38329..000000000 --- a/test/test_serialization.py +++ /dev/null @@ -1,456 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import os -import pickle -import unittest -import warnings -from functools import partial -from io import StringIO -from operator import itemgetter -from typing import List - -import expecttest -import torchdata.datapipes.iter as iterdp -import torchdata.datapipes.map as mapdp -from _utils._common_utils_for_test import create_temp_dir, create_temp_files -from torch.utils._import_utils import dill_available -from torchdata.datapipes.iter import IterableWrapper -from torchdata.datapipes.map import SequenceWrapper - -if dill_available(): - import dill - - dill.extend(use_dill=False) - -try: - import datasets -except ImportError: - datasets = None - -try: - import fsspec -except ImportError: - fsspec = None - -try: - import iopath -except ImportError: - iopath = None - -try: - import subprocess - - import rarfile - - try: - rarfile.tool_setup() - subprocess.run(("rar", "-?"), check=True) - except (rarfile.RarCannotExec, subprocess.CalledProcessError): - rarfile = None -except (ModuleNotFoundError, FileNotFoundError): - rarfile = None - -try: - import torcharrow - import torcharrow.dtypes as dt - - DTYPE = dt.Struct([dt.Field("Values", dt.int32)]) -except ImportError: - torcharrow = None - dt = None - DTYPE = None - - -def _fake_batch_fn(batch): - return [d + 1 for d in batch] - - -def _fake_fn_ls(x): - return [x, x] - - -def _filepath_fn(name: str, dir) -> str: - return os.path.join(dir, os.path.basename(name)) - - -def _filter_by_module_availability(datapipes): - filter_set = set() - if datasets is None: - filter_set.update([iterdp.HuggingFaceHubReader]) - if fsspec is None: - filter_set.update([iterdp.FSSpecFileLister, iterdp.FSSpecFileOpener, iterdp.FSSpecSaver]) - if iopath is None: - filter_set.update([iterdp.IoPathFileLister, iterdp.IoPathFileOpener, iterdp.IoPathSaver]) - if rarfile is None: - filter_set.update([iterdp.RarArchiveLoader]) - if torcharrow is None or not dill_available(): - filter_set.update([iterdp.DataFrameMaker, iterdp.ParquetDataFrameLoader]) - return [dp for dp in datapipes if dp[0] not in filter_set] - - -def _convert_to_tensor(data): - return torch.tensor(data) - - -class TestIterDataPipeSerialization(expecttest.TestCase): - def setUp(self) -> None: - self.temp_dir = create_temp_dir() - self.temp_files = create_temp_files(self.temp_dir) - self.temp_sub_dir = create_temp_dir(self.temp_dir.name) - self.temp_sub_files = create_temp_files(self.temp_sub_dir, 4, False) - - def tearDown(self) -> None: - try: - self.temp_sub_dir.cleanup() - self.temp_dir.cleanup() - except Exception as e: - warnings.warn(f"TestIterDataPipeSerialization was not able to cleanup temp dir due to {e}") - - def _serialization_test_helper(self, datapipe, use_dill): - if use_dill: - serialized_dp = dill.dumps(datapipe) - deserialized_dp = dill.loads(serialized_dp) - else: - serialized_dp = pickle.dumps(datapipe) - deserialized_dp = pickle.loads(serialized_dp) - try: - self.assertEqual(list(datapipe), list(deserialized_dp)) - except AssertionError as e: - print(f"{datapipe} is failing.") - raise e - - def _serialization_dataframe_test_helper(self, datapipe, use_dill): - if use_dill: - serialized_dp = dill.dumps(datapipe) - deserialized_dp = dill.loads(serialized_dp) - else: - serialized_dp = pickle.dumps(datapipe) - deserialized_dp = pickle.loads(serialized_dp) - for df1, df2 in zip(datapipe, deserialized_dp): - for exp, act in zip(df1, df2): - self.assertEqual(exp, act) - - def _serialization_test_for_single_dp(self, dp, use_dill, is_dataframe=False): - test_helper_fn = 
self._serialization_dataframe_test_helper if is_dataframe else self._serialization_test_helper - # 1. Testing for serialization before any iteration starts - test_helper_fn(dp, use_dill) - # 2. Testing for serialization afterDataPipe is partially read - it = iter(dp) - _ = next(it) - test_helper_fn(dp, use_dill) - # 3. Testing for serialization after DataPipe is fully read - it = iter(dp) - _ = list(it) - test_helper_fn(dp, use_dill) - - def _serialization_test_for_dp_with_children(self, dp1, dp2, use_dill): - # 1. Testing for serialization before any iteration starts - self._serialization_test_helper(dp1, use_dill=use_dill) - self._serialization_test_helper(dp2, use_dill=use_dill) - # 2. Testing for serialization after DataPipe is partially read - it1, it2 = iter(dp1), iter(dp2) - _, _ = next(it1), next(it2) - self._serialization_test_helper(dp1, use_dill=use_dill) - self._serialization_test_helper(dp2, use_dill=use_dill) - # 2.5. Testing for serialization after one child DataPipe is fully read - # (Only for DataPipes with children DataPipes) - it1 = iter(dp1) - _ = list(it1) # fully read one child - self._serialization_test_helper(dp1, use_dill=use_dill) - self._serialization_test_helper(dp2, use_dill=use_dill) - # 3. Testing for serialization after DataPipe is fully read - it2 = iter(dp2) - _ = list(it2) # fully read the other child - self._serialization_test_helper(dp1, use_dill=use_dill) - self._serialization_test_helper(dp2, use_dill=use_dill) - - def test_serializable(self) -> None: - # A tuple of 4 objects - # (DataPipeConstructor, custom_input_datapipe=None, dp_args=(), dp_kwargs={}) - picklable_datapipes: List = [ - (iterdp.BatchMapper, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), (_fake_batch_fn, 2, 1), {}), - (iterdp.BucketBatcher, IterableWrapper([0, 0, 0, 0, 0, 0, 0]), (5,), {}), - (iterdp.Bz2FileLoader, None, (), {}), - ( - iterdp.CSVDictParser, - IterableWrapper( - [("f1", StringIO("Label,1,1\nLabel,2,2\nLabel,3,3")), ("f2", StringIO("L,1,1\r\nL,2,2\r\nL,3,3"))] - ), - (), - {}, - ), - ( - iterdp.CSVParser, - IterableWrapper( - [("f1", StringIO("Label,1,1\nLabel,2,2\nLabel,3,3")), ("f2", StringIO("L,1,1\r\nL,2,2\r\nL,3,3"))] - ), - (), - {}, - ), - (iterdp.Cycler, None, (2,), {}), - (iterdp.DataFrameMaker, IterableWrapper([(i,) for i in range(3)]), (), {"dtype": DTYPE}), - (iterdp.Decompressor, None, (), {}), - (iterdp.Dropper, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}), - (iterdp.Enumerator, None, (2,), {}), - (iterdp.FlatMapper, None, (_fake_fn_ls,), {}), - (iterdp.ShuffledFlatMapper, None, (_fake_fn_ls,), {"buffer_size": 1}), - (iterdp.Flattener, IterableWrapper([(0, (0, 1)), (0, (0, 1)), (0, (0, 1)), (0, (0, 1))]), ([1]), {}), - (iterdp.FSSpecFileLister, ".", (), {}), - (iterdp.FSSpecFileOpener, None, (), {}), - ( - iterdp.FSSpecSaver, - IterableWrapper([("1.txt", b"DATA1"), ("2.txt", b"DATA2"), ("3.txt", b"DATA3")]), - (), - {"mode": "wb", "filepath_fn": partial(_filepath_fn, dir=self.temp_dir.name)}, - ), - (iterdp.GDriveReader, None, (), {}), - (iterdp.HashChecker, None, ({},), {}), - (iterdp.Header, None, (3,), {}), - (iterdp.HttpReader, None, (), {}), - (iterdp.HuggingFaceHubReader, None, (), {}), - # TODO(593): (ejguan): Deterministic serialization is required - # (iterdp.InBatchShuffler, IterableWrapper(range(10)).batch(3), (), {}), - (iterdp.InMemoryCacheHolder, None, (), {}), - (iterdp.IndexAdder, IterableWrapper([{"a": 1, "b": 2}, {"c": 3, "a": 1}]), ("label",), {}), - (iterdp.IoPathFileLister, ".", (), {}), - (iterdp.IoPathFileOpener, 
None, (), {}), - ( - iterdp.IoPathSaver, - IterableWrapper([("1.txt", b"DATA1"), ("2.txt", b"DATA2"), ("3.txt", b"DATA3")]), - (), - {"mode": "wb", "filepath_fn": partial(_filepath_fn, dir=self.temp_dir.name)}, - ), - ( - iterdp.IterKeyZipper, - IterableWrapper([("a", 100), ("b", 200), ("c", 300)]), - (IterableWrapper([("a", 1), ("b", 2), ("c", 3)]), itemgetter(0), itemgetter(0)), - {}, - ), - ( - iterdp.JsonParser, - IterableWrapper( - [ - ("1.json", StringIO('["fo", {"ba":["baz", null, 1.0, 2]}]')), - ("2.json", StringIO('{"__cx__": true, "r": 1, "i": 2}')), - ] - ), - (), - {}, - ), - (iterdp.LengthSetter, None, (3,), {}), - ( - iterdp.LineReader, - IterableWrapper( - [("file1", StringIO("Line1\nLine2")), ("file2", StringIO("Line2,1\r\nLine2,2\r\nLine2,3"))] - ), - (), - {}, - ), - (iterdp.MapToIterConverter, SequenceWrapper(range(10)), (), {}), - ( - iterdp.MaxTokenBucketizer, - IterableWrapper(["1", "22", "1", "4444", "333", "1", "22", "22", "333"]), - (4,), - {}, - ), - ( - iterdp.MapKeyZipper, - IterableWrapper([("a", 1), ("b", 2), ("c", 3)]), - (SequenceWrapper({"a": 100, "b": 200, "c": 300}), itemgetter(0)), - {}, - ), - ( - iterdp.MultiplexerLongest, - IterableWrapper(range(10)), - (), - {}, - ), - (iterdp.OnDiskCacheHolder, None, (), {}), - (iterdp.OnlineReader, None, (), {}), - ( - iterdp.ParagraphAggregator, - IterableWrapper([("f1", "L1"), ("f1", "L2"), ("f2", "21"), ("f2", "22")]), - (), - {}, - ), - (iterdp.Prefetcher, None, (), {}), - (iterdp.ParquetDataFrameLoader, None, (), {"dtype": DTYPE}), - (iterdp.RarArchiveLoader, None, (), {}), - ( - iterdp.Rows2Columnar, - IterableWrapper([[{"a": 1}, {"b": 2, "a": 1}], [{"a": 1, "b": 200}, {"c": 3}]]), - (), - {}, - ), - (iterdp.Repeater, None, (2,), {}), - (iterdp.SampleMultiplexer, {IterableWrapper([0] * 10): 0.5, IterableWrapper([1] * 10): 0.5}, (), {}), - ( - iterdp.Saver, - IterableWrapper([("1.txt", b"DATA1"), ("2.txt", b"DATA2"), ("3.txt", b"DATA3")]), - (), - {"mode": "wb", "filepath_fn": partial(_filepath_fn, dir=self.temp_dir.name)}, - ), - (iterdp.Slicer, IterableWrapper([(0, 0), (0, 0), (0, 0), (0, 0)]), ([1]), {}), - (iterdp.TarArchiveLoader, None, (), {}), - # TODO(594): Add serialization tests for optional DataPipe - # (iterdp.TFRecordLoader, None, (), {}), - (iterdp.ThreadPoolMapper, None, (_fake_fn_ls,), {}), - (iterdp.UnZipper, IterableWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}), - (iterdp.WebDataset, IterableWrapper([("foo.txt", b"1"), ("bar.txt", b"2")]), (), {}), - (iterdp.XzFileLoader, None, (), {}), - (iterdp.ZipArchiveLoader, None, (), {}), - (iterdp.ZipperLongest, IterableWrapper(range(10)), (), {}), - ] - - picklable_datapipes = _filter_by_module_availability(picklable_datapipes) - - # Skipping value comparison for these DataPipes - # Most of them return streams not comparable by `self.assertEqual` - # Others are similar to caching where the outputs depend on other DataPipes - dp_skip_comparison = { - iterdp.Bz2FileLoader, - iterdp.Decompressor, - iterdp.FileOpener, - iterdp.FSSpecFileOpener, - iterdp.GDriveReader, - iterdp.IoPathFileOpener, - iterdp.HashChecker, - iterdp.HttpReader, - iterdp.HuggingFaceHubReader, - iterdp.OnDiskCacheHolder, - iterdp.OnlineReader, - iterdp.ParquetDataFrameLoader, - iterdp.SampleMultiplexer, - iterdp.RarArchiveLoader, - iterdp.TarArchiveLoader, - iterdp.TFRecordLoader, - iterdp.XzFileLoader, - iterdp.ZipArchiveLoader, - } - # These DataPipes produce multiple DataPipes as outputs and those should be compared - dp_compare_children = 
{iterdp.UnZipper} - - for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes: - try: - # Creating input (usually a DataPipe) for the specific dpipe being tested - if custom_input is None: - custom_input = IterableWrapper(range(10)) - - if dpipe in dp_skip_comparison: # Mke sure they are picklable and loadable (no value comparison) - datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] - serialized_dp = pickle.dumps(datapipe) - _ = pickle.loads(serialized_dp) - elif dpipe in dp_compare_children: # DataPipes that have children - dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] - self._serialization_test_for_dp_with_children(dp1, dp2, use_dill=False) - else: # Single DataPipe that requires comparison - datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] - is_dataframe = issubclass(dpipe, (iterdp.DataFrameMaker, iterdp.ParquetDataFrameLoader)) - self._serialization_test_for_single_dp(datapipe, use_dill=False, is_dataframe=is_dataframe) - except Exception as e: - print(f"{dpipe} is failing.") - raise e - - def test_serializable_with_dill(self) -> None: - """Only for DataPipes that take in a function as argument""" - input_dp = IterableWrapper(range(10)) - ref_idp = IterableWrapper(range(10)) - ref_mdp = SequenceWrapper(range(10)) - - unpicklable_datapipes: List = [ - (iterdp.BatchMapper, (lambda batch: [d + 1 for d in batch], 2), {}), - (iterdp.FlatMapper, (lambda x: [x, x],), {}), - (iterdp.ShuffledFlatMapper, (lambda x: [x, x],), {"buffer_size": 1}), - (iterdp.IterKeyZipper, (ref_idp, lambda x: x, None, True, 100), {}), - (iterdp.MapKeyZipper, (ref_mdp, lambda x: x), {}), - (iterdp.OnDiskCacheHolder, (lambda x: x,), {}), - (iterdp.ParagraphAggregator, (lambda x: x,), {}), - (iterdp.ThreadPoolMapper, (lambda x: x,), {}), - ] - # Skipping value comparison for these DataPipes - dp_skip_comparison = {iterdp.OnDiskCacheHolder, iterdp.ParagraphAggregator} - for dpipe, dp_args, dp_kwargs in unpicklable_datapipes: - if dill_available(): - try: - if dpipe in dp_skip_comparison: # Make sure they are picklable/loadable (no value comparison) - datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] - serialized_dp = dill.dumps(datapipe) - _ = dill.loads(serialized_dp) - else: - datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] - self._serialization_test_for_single_dp(datapipe, use_dill=True) - except Exception as e: - print(f"{dpipe} is failing.") - raise e - - else: - dp_no_attribute_error = (iterdp.OnDiskCacheHolder,) - try: - with self.assertWarnsRegex(UserWarning, r"^Local function is not supported by pickle"): - datapipe = dpipe(input_dp, *dp_args, **dp_kwargs) # type: ignore[call-arg] - if isinstance(datapipe, dp_no_attribute_error): - _ = pickle.dumps(datapipe) - else: - with self.assertRaises(AttributeError): - _ = pickle.dumps(datapipe) - except Exception as e: - print(f"{dpipe} is failing.") - raise e - - -class TestMapDataPipeSerialization(expecttest.TestCase): - def _serialization_test_helper(self, datapipe): - serialized_dp = pickle.dumps(datapipe) - deserialized_dp = pickle.loads(serialized_dp) - try: - self.assertEqual(list(datapipe), list(deserialized_dp)) - except AssertionError as e: - print(f"{datapipe} is failing.") - raise e - - def _serialization_test_for_dp_with_children(self, dp1, dp2): - self._serialization_test_helper(dp1) - self._serialization_test_helper(dp2) - - def test_serializable(self) -> None: - picklable_datapipes: List = [ - 
(mapdp.InMemoryCacheHolder, None, (), {}), - (mapdp.IterToMapConverter, IterableWrapper([(i, i) for i in range(10)]), (), {}), - (mapdp.UnZipper, SequenceWrapper([(i, i + 10) for i in range(10)]), (), {"sequence_length": 2}), - ] - - dp_skip_comparison = set() - # These DataPipes produce multiple DataPipes as outputs and those should be compared - dp_compare_children = {mapdp.UnZipper} - - for dpipe, custom_input, dp_args, dp_kwargs in picklable_datapipes: - try: - # Creating input (usually a DataPipe) for the specific dpipe being tested - if custom_input is None: - custom_input = SequenceWrapper(range(10)) - - if dpipe in dp_skip_comparison: # Mke sure they are picklable and loadable (no value comparison) - datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] - serialized_dp = pickle.dumps(datapipe) - _ = pickle.loads(serialized_dp) - elif dpipe in dp_compare_children: # DataPipes that have children - dp1, dp2 = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] - self._serialization_test_for_dp_with_children(dp1, dp2) - else: # Single DataPipe that requires comparison - datapipe = dpipe(custom_input, *dp_args, **dp_kwargs) # type: ignore[call-arg] - self._serialization_test_helper(datapipe) - except Exception as e: - print(f"{dpipe} is failing.") - raise e - - def test_serializable_with_dill(self) -> None: - """Only for DataPipes that take in a function as argument""" - pass - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_text_examples.py b/test/test_text_examples.py deleted file mode 100644 index deae65867..000000000 --- a/test/test_text_examples.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import unittest - -import torch.multiprocessing as mp - -from torch.testing._internal.common_utils import slowTest -from torch.utils.data import DataLoader - -from torchtext.datasets import AG_NEWS, AmazonReviewPolarity, IMDB, SQuAD1, SQuAD2, SST2 - - -# TODO(124): Replace the following tests with the corresponding tests in TorchText -class TestTextExamples(unittest.TestCase): - def _test_helper(self, fn): - dp = fn() - for stage_dp in dp: - _ = list(stage_dp) - - @staticmethod - def _collate_fn(batch): - return batch - - def _test_DL_helper(self, fn): - mp.set_sharing_strategy("file_system") - dp = fn() - for stage_dp in dp: - dl = DataLoader( - stage_dp, - batch_size=8, - num_workers=4, - collate_fn=TestTextExamples._collate_fn, - multiprocessing_context="spawn", - ) - _ = list(dl) - - def test_SST(self) -> None: - self._test_helper(SST2) - self._test_DL_helper(SST2) - - def test_AG_NEWS(self) -> None: - self._test_helper(AG_NEWS) - self._test_DL_helper(AG_NEWS) - - @slowTest - def test_AmazonReviewPolarity(self) -> None: - self._test_helper(AmazonReviewPolarity) - self._test_DL_helper(AmazonReviewPolarity) - - @slowTest - def test_IMDB(self) -> None: - self._test_helper(IMDB) - self._test_DL_helper(IMDB) - - def test_SQuAD1(self) -> None: - self._test_helper(SQuAD1) - self._test_DL_helper(SQuAD1) - - def test_SQuAD2(self) -> None: - self._test_helper(SQuAD2) - self._test_DL_helper(SQuAD2) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_tfrecord.py b/test/test_tfrecord.py deleted file mode 100644 index eb660993f..000000000 --- a/test/test_tfrecord.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import os -import unittest -import warnings -from functools import partial - -import expecttest - -import torch - -from _utils._common_utils_for_test import IS_M1, reset_after_n_next_calls -from torchdata.datapipes.iter import ( - FileLister, - FileOpener, - FSSpecFileLister, - FSSpecFileOpener, - FSSpecSaver, - IterableWrapper, - TFRecordLoader, -) - -try: - import google.protobuf as _protobuf - - del _protobuf - HAS_PROTOBUF = True -except ImportError: - HAS_PROTOBUF = False -skipIfNoPROTOBUF = unittest.skipIf(not HAS_PROTOBUF, "no google protobuf") - - -class TestDataPipeTFRecord(expecttest.TestCase): - def setUp(self) -> None: - self.temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "_fakedata", "tfrecord") - - def assertArrayEqual(self, arr1, arr2): - if isinstance(arr1, list): - arr1 = torch.stack(arr1) - if isinstance(arr2, list): - arr2 = torch.stack(arr2) - torch.testing.assert_close(arr1, arr2, check_dtype=False) - - def _ground_truth_data(self) -> None: - for i in range(4): - x = torch.arange(i * 10, (i + 1) * 10) - yield { - "x_float": x, - "x_int": (x * 10).long(), - "x_byte": [b"test str"], - } - - def _ground_truth_seq_data(self) -> None: - for i in range(4): - x = torch.arange(i * 10, (i + 1) * 10) - rep = 2 * i + 3 - yield {"x_float": x, "x_int": (x * 10).long(), "x_byte": [b"test str"]}, { - "x_float_seq": [x] * rep, - "x_int_seq": [(x * 10).long()] * rep, - "x_byte_seq": [[b"test str"]] * rep, - } - - @skipIfNoPROTOBUF - @unittest.skipIf( - IS_M1, "Protobuf 3.19.* is not supported on MacOS M1, but Tensorflow is incompatible with Protobuf 4" - ) - @torch.no_grad() - def test_tfrecord_loader_example_iterdatapipe(self) -> None: - filename = f"{self.temp_dir}/example.tfrecord" - datapipe1 = IterableWrapper([filename]) - datapipe2 = FileOpener(datapipe1, mode="b") - - # Functional Test: test if the returned data is correct - tfrecord_parser = datapipe2.load_from_tfrecord() - result = list(tfrecord_parser) - self.assertEqual(len(result), 4) - expected_res = final_expected_res = list(self._ground_truth_data()) - for true_data, loaded_data in zip(expected_res, result): - self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) - for key in ["x_float", "x_int"]: - self.assertArrayEqual(true_data[key], loaded_data[key]) - self.assertEqual(len(loaded_data["x_byte"]), 1) - self.assertEqual(true_data["x_byte"][0], loaded_data["x_byte"][0]) - - # Functional Test: test if the shape of the returned data is correct when using spec - tfrecord_parser = datapipe2.load_from_tfrecord( - { - "x_float": ((5, 2), torch.float64), - "x_int": ((5, 2), torch.int32), - "x_byte": (tuple(), None), - } - ) - result = list(tfrecord_parser) - self.assertEqual(len(result), 4) - expected_res = [ - { - "x_float": x["x_float"].reshape(5, 2), - "x_int": x["x_int"].reshape(5, 2), - "x_byte": x["x_byte"][0], - } - for x in self._ground_truth_data() - ] - for true_data, loaded_data in zip(expected_res, result): - self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) - self.assertArrayEqual(true_data["x_float"], loaded_data["x_float"].float()) - self.assertArrayEqual(true_data["x_int"], loaded_data["x_int"].long()) - self.assertEqual(loaded_data["x_float"].dtype, torch.float64) - self.assertEqual(loaded_data["x_int"].dtype, torch.int32) - self.assertEqual(true_data["x_byte"], loaded_data["x_byte"]) - - # Functional Test: ignore features missing from spec - tfrecord_parser = datapipe2.load_from_tfrecord( - { - "x_float": ((10,), torch.float32), - } - ) - result = 
list(tfrecord_parser) - self.assertEqual(len(result), 4) - expected_res = [ - { - "x_float": x["x_float"], - } - for x in self._ground_truth_data() - ] - for true_data, loaded_data in zip(expected_res, result): - self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) - self.assertArrayEqual(true_data["x_float"], loaded_data["x_float"].float()) - - # Functional Test: raises error if missing spec feature - with self.assertRaises(RuntimeError): - tfrecord_parser = datapipe2.load_from_tfrecord( - { - "x_float_unknown": ((5, 2), torch.float64), - "x_int": ((5, 2), torch.int32), - "x_byte": (tuple(), None), - } - ) - result = list(tfrecord_parser) - - # Reset Test: - tfrecord_parser = TFRecordLoader(datapipe2) - expected_res = final_expected_res - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(tfrecord_parser, n_elements_before_reset) - self.assertEqual(len(expected_res[:n_elements_before_reset]), len(res_before_reset)) - for true_data, loaded_data in zip(expected_res[:n_elements_before_reset], res_before_reset): - self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) - for key in ["x_float", "x_int"]: - self.assertArrayEqual(true_data[key], loaded_data[key]) - self.assertEqual(true_data["x_byte"][0], loaded_data["x_byte"][0]) - self.assertEqual(len(expected_res), len(res_after_reset)) - for true_data, loaded_data in zip(expected_res, res_after_reset): - self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) - for key in ["x_float", "x_int"]: - self.assertArrayEqual(true_data[key], loaded_data[key]) - self.assertEqual(true_data["x_byte"][0], loaded_data["x_byte"][0]) - - # __len__ Test: length isn't implemented since it cannot be known ahead of time - with self.assertRaisesRegex(TypeError, "doesn't have valid length"): - len(tfrecord_parser) - - @skipIfNoPROTOBUF - @unittest.skipIf( - IS_M1, "Protobuf 3.19.* is not supported on MacOS M1, but Tensorflow is incompatible with Protobuf 4" - ) - @torch.no_grad() - def test_tfrecord_loader_sequence_example_iterdatapipe(self) -> None: - filename = f"{self.temp_dir}/sequence_example.tfrecord" - datapipe1 = IterableWrapper([filename]) - datapipe2 = FileOpener(datapipe1, mode="b") - - # Functional Test: test if the returned data is correct - tfrecord_parser = datapipe2.load_from_tfrecord() - result = list(tfrecord_parser) - self.assertEqual(len(result), 4) - expected_res = final_expected_res = list(self._ground_truth_seq_data()) - for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, result): - self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) - for key in ["x_float", "x_int"]: - self.assertArrayEqual(true_data_ctx[key], loaded_data[key]) - self.assertEqual(len(true_data_seq[key + "_seq"]), len(loaded_data[key + "_seq"])) - self.assertIsInstance(loaded_data[key + "_seq"], list) - for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]): - self.assertArrayEqual(a1, a2) - self.assertEqual(true_data_ctx["x_byte"], loaded_data["x_byte"]) - self.assertListEqual(true_data_seq["x_byte_seq"], loaded_data["x_byte_seq"]) - - # Functional Test: test if the shape of the returned data is correct when using spec - tfrecord_parser = datapipe2.load_from_tfrecord( - { - "x_float": ((5, 2), torch.float64), - "x_int": ((5, 2), torch.int32), - "x_byte": (tuple(), None), - "x_float_seq": ((-1, 5, 2), torch.float64), - "x_int_seq": ((-1, 5, 2), torch.int32), - "x_byte_seq": ((-1,), None), - } - ) - result = 
list(tfrecord_parser) - self.assertEqual(len(result), 4) - - expected_res = [ - ( - { - "x_float": x["x_float"].reshape(5, 2), - "x_int": x["x_int"].reshape(5, 2), - "x_byte": x["x_byte"][0], - }, - { - "x_float_seq": [y.reshape(5, 2) for y in z["x_float_seq"]], - "x_int_seq": [y.reshape(5, 2) for y in z["x_int_seq"]], - "x_byte_seq": [y[0] for y in z["x_byte_seq"]], - }, - ) - for x, z in self._ground_truth_seq_data() - ] - for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, result): - self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) - for key in ["x_float", "x_int"]: - l_loaded_data = loaded_data[key] - if key == "x_float": - l_loaded_data = l_loaded_data.float() - else: - l_loaded_data = l_loaded_data.int() - self.assertArrayEqual(true_data_ctx[key], l_loaded_data) - self.assertArrayEqual(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]) - self.assertEqual(true_data_ctx["x_byte"], loaded_data["x_byte"]) - self.assertListEqual(true_data_seq["x_byte_seq"], loaded_data["x_byte_seq"]) - - # Functional Test: ignore features missing from spec - tfrecord_parser = datapipe2.load_from_tfrecord( - { - "x_float": ((10,), torch.float32), - } - ) - result = list(tfrecord_parser) - self.assertEqual(len(result), 4) - expected_res = [ - { - "x_float": x["x_float"], - } - for x, z in self._ground_truth_seq_data() - ] - for true_data, loaded_data in zip(expected_res, result): - self.assertSetEqual(set(true_data.keys()), set(loaded_data.keys())) - self.assertArrayEqual(true_data["x_float"], loaded_data["x_float"].float()) - - # Functional Test: raises error if missing spec feature - with self.assertRaises(RuntimeError): - tfrecord_parser = datapipe2.load_from_tfrecord( - {"x_float_unknown": ((5, 2), torch.float64), "x_int": ((5, 2), torch.int32), "x_byte": None} - ) - result = list(tfrecord_parser) - - # Reset Test: - tfrecord_parser = TFRecordLoader(datapipe2) - expected_res = final_expected_res - n_elements_before_reset = 2 - res_before_reset, res_after_reset = reset_after_n_next_calls(tfrecord_parser, n_elements_before_reset) - self.assertEqual(len(expected_res[:n_elements_before_reset]), len(res_before_reset)) - for (true_data_ctx, true_data_seq), loaded_data in zip( - expected_res[:n_elements_before_reset], res_before_reset - ): - self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) - for key in ["x_float", "x_int"]: - self.assertArrayEqual(true_data_ctx[key], loaded_data[key]) - self.assertEqual(len(true_data_seq[key + "_seq"]), len(loaded_data[key + "_seq"])) - self.assertIsInstance(loaded_data[key + "_seq"], list) - for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]): - self.assertArrayEqual(a1, a2) - self.assertEqual(true_data_ctx["x_byte"], loaded_data["x_byte"]) - self.assertListEqual(true_data_seq["x_byte_seq"], loaded_data["x_byte_seq"]) - self.assertEqual(len(expected_res), len(res_after_reset)) - for (true_data_ctx, true_data_seq), loaded_data in zip(expected_res, res_after_reset): - self.assertSetEqual(set(true_data_ctx.keys()).union(true_data_seq.keys()), set(loaded_data.keys())) - for key in ["x_float", "x_int"]: - self.assertArrayEqual(true_data_ctx[key], loaded_data[key]) - self.assertEqual(len(true_data_seq[key + "_seq"]), len(loaded_data[key + "_seq"])) - self.assertIsInstance(loaded_data[key + "_seq"], list) - for a1, a2 in zip(true_data_seq[key + "_seq"], loaded_data[key + "_seq"]): - self.assertArrayEqual(a1, a2) - 
self.assertEqual(true_data_ctx["x_byte"], loaded_data["x_byte"]) - self.assertListEqual(true_data_seq["x_byte_seq"], loaded_data["x_byte_seq"]) - - # __len__ Test: length isn't implemented since it cannot be known ahead of time - with self.assertRaisesRegex(TypeError, "doesn't have valid length"): - len(tfrecord_parser) - - -if __name__ == "__main__": - unittest.main() diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt deleted file mode 100644 index ff388a209..000000000 --- a/third_party/CMakeLists.txt +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# ---[ pybind11 -if(USE_SYSTEM_PYBIND11) - find_package(pybind11 CONFIG) - if(NOT pybind11_FOUND) - find_package(pybind11) - endif() - if(NOT pybind11_FOUND) - message(FATAL "Cannot find system pybind11") - endif() -else() - message(STATUS "Using third_party/pybind11.") - set(pybind11_INCLUDE_DIRS ${CMAKE_CURRENT_LIST_DIR}/pybind11/include) - install(DIRECTORY ${pybind11_INCLUDE_DIRS} - DESTINATION ${CMAKE_INSTALL_PREFIX} - FILES_MATCHING PATTERN "*.h") -endif() -message(STATUS "pybind11 include dirs: " "${pybind11_INCLUDE_DIRS}") -add_library(pybind::pybind11 INTERFACE IMPORTED) -set_property(TARGET pybind::pybind11 PROPERTY - INTERFACE_INCLUDE_DIRECTORIES ${pybind11_INCLUDE_DIRS}) -set_property(TARGET pybind::pybind11 PROPERTY - INTERFACE_SYSTEM_INCLUDE_DIRECTORIES ${pybind11_INCLUDE_DIRS}) -set_property(TARGET pybind::pybind11 PROPERTY - INTERFACE_LINK_LIBRARIES python::python) - -# ---[ aws-sdk-cpp -if(USE_SYSTEM_AWS_SDK_CPP) - find_package(AWSSDK REQUIRED COMPONENTS s3 transfer) - if(NOT AWSSDK_FOUND) - message(FATAL "Cannot find system aws-sdk-cpp") - endif() -else() - message(STATUS "Using third_party/aws-sdk-cpp.") - - set(aws_cpp_sdk_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}/aws-sdk-cpp) - set(aws_cpp_sdk_INSTALL "${CMAKE_CURRENT_LIST_DIR}/aws_sdk") - set(AWSSDK_INCLUDE_DIRS "${aws_cpp_sdk_INSTALL}/include") - - set( - AWSSDK_LIBS - aws-cpp-sdk-transfer aws-cpp-sdk-s3 aws-cpp-sdk-core aws-crt-cpp - aws-c-mqtt aws-c-event-stream aws-c-s3 aws-c-auth aws-c-http aws-c-io aws-c-compression aws-c-cal aws-c-sdkutils aws-checksums aws-c-common - ) - - foreach(lib ${AWSSDK_LIBS}) - if(WIN32) - list(APPEND AWSSDK_LIBRARIES ${aws_cpp_sdk_INSTALL}/lib/${lib}.lib) - else() - list(APPEND AWSSDK_LIBRARIES ${aws_cpp_sdk_INSTALL}/lib/lib${lib}.a) - endif() - endforeach() - - if(UNIX AND NOT APPLE) - list(APPEND AWSSDK_LIBRARIES "${aws_cpp_sdk_INSTALL}/lib/libs2n.a") - endif() - - include(ExternalProject) - ExternalProject_Add( - aws_sdk - SOURCE_DIR ${aws_cpp_sdk_SOURCE_DIR} - INSTALL_DIR ${aws_cpp_sdk_INSTALL} - CMAKE_ARGS "-DBUILD_SHARED_LIBS=OFF" "-DBUILD_ONLY=transfer;s3" "-DENABLE_TESTING=OFF" "-DCMAKE_BUILD_TYPE=Release" "-DCMAKE_INSTALL_PREFIX=${aws_cpp_sdk_INSTALL}" "-DCMAKE_INSTALL_LIBDIR=lib" - BUILD_BYPRODUCTS ${AWSSDK_LIBRARIES} - ) -endif() - -message(STATUS "aws-sdk-cpp include dirs: " "${AWSSDK_INCLUDE_DIRS}") diff --git a/third_party/CONTRIBUTING.md b/third_party/CONTRIBUTING.md deleted file mode 100644 index e595d94be..000000000 --- a/third_party/CONTRIBUTING.md +++ /dev/null @@ -1,46 +0,0 @@ -# Third-Party Libraries - -The `third_party` directory contains all of the (optional) dependency libraries used by `torchdata`. 
It relies on
-the `CMake` system to compile and integrate them into a C-extension module that can be found as `torchdata/_torchdata`.
-
-`torchdata` also relies on [`pybind11`](https://github.com/pybind/pybind11) to expose the C++ API in Python. Please refer
-to this [directory](https://github.com/pytorch/data/tree/main/torchdata/csrc) for more detail.
-
-## Integration Requirements
-
-### Soft Dependency vs Hard Dependency
-
-`TorchData`, as a data-processing library, provides a number of `DataPipes` that are integrated with different
-third-party libraries to handle specific use cases. For example, `datasets` is imported to load datasets from
-`HuggingFace`, and `fsspec` is imported to provide a unified API to access and load from local or remote file systems.
-Those dependencies are optional, and their availability should only be verified when they are used/referenced by users
-during `DataPipe` construction. You can find examples
-[here](https://github.com/pytorch/data/blob/bb78231e5f87620385cb2f91cda87e7f9414eb4a/torchdata/datapipes/iter/load/huggingface.py#L57-L62)
-and
-[here](https://github.com/pytorch/data/blob/d19858202df7e8b75765074259e6023f539cbf3f/torchdata/datapipes/iter/load/fsspec.py#L59).
-They are contrasted with core dependencies, which must be installed along with `torchdata`.
-
-- Optional features
-  - For a Python library, please follow the examples above to add a soft dependency to `torchdata`.
-  - For a C library, a compilation flag should be provided to users to enable or disable compilation and integration via
-    `CMake`. For example,
-    [`BUILD_S3`](https://github.com/pytorch/data/blob/87d6dc3d6b0df6829cc2813a0ca033accfa9d795/torchdata/csrc/CMakeLists.txt#L7)
-    is provided for `AWSSDK`.
-- Core features
-  - Add the Python library as a hard dependency to
-    [`requirements.txt`](https://github.com/pytorch/data/blob/main/requirements.txt).
-  - Add the C library as a submodule under the `third_party` directory and compile it against the `torchdata`
-    C-extension as usual.
-
-### Static Linking vs Dynamic Linking
-
-For third-party C libraries whose runtime libraries are available on both PyPI and Conda across platforms
-(Linux, MacOS, Windows and Python 3.8-3.10), it is recommended to use dynamic linking against the `torchdata`
-C-extension.
-
-For third-party libraries that are not available on PyPI and Conda, please add them as submodules under the
-`third_party` directory and statically compile them into the `torchdata` C-extension.
-
-Notes:
-
-- On `Linux`, static libraries are required to follow the `manylinux2014` protocol (equivalent of `manylinux_2_17`)
-  when they are integrated with `torchdata`.
diff --git a/third_party/aws-sdk-cpp b/third_party/aws-sdk-cpp
deleted file mode 160000
index b53b56b7d..000000000
--- a/third_party/aws-sdk-cpp
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit b53b56b7d3ce2386c487f3b2c8da9b6c9d59f3dd
diff --git a/third_party/pybind11 b/third_party/pybind11
deleted file mode 160000
index 80dc998ef..000000000
--- a/third_party/pybind11
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 80dc998efced8ceb2be59756668a7e90e8bef917
diff --git a/torchdata/__init__.py b/torchdata/__init__.py
index 59c1a7282..c5b746688 100644
--- a/torchdata/__init__.py
+++ b/torchdata/__init__.py
@@ -4,51 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import importlib - try: from .version import __version__ # noqa: F401 except ImportError: pass - -__all__ = [ - "datapipes", - "janitor", -] - -# Please keep this list sorted -assert __all__ == sorted(__all__) - - -# Lazy import all modules -def __getattr__(name): - if name == "janitor": - return importlib.import_module(".datapipes.utils." + name, __name__) - else: - try: - return importlib.import_module("." + name, __name__) - except ModuleNotFoundError: - if name in globals(): - return globals()[name] - else: - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - - -_warning_shown = False - - -def deprecation_warning(): - global _warning_shown - if not _warning_shown: - _warning_shown = True - import warnings - - warnings.warn( - "\n################################################################################\n" - "WARNING!\n" - "The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a\n" - "future torchdata release! Please see https://github.com/pytorch/data/issues/1196\n" - "to learn more and leave feedback.\n" - "################################################################################\n", - stacklevel=2, - ) diff --git a/torchdata/_constants.py b/torchdata/_constants.py deleted file mode 100644 index 45bc4c8e3..000000000 --- a/torchdata/_constants.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# Use the same timeout as PyTorch Distributed -default_timeout_in_s = 30 * 60 - -default_dl2_worker_join_timeout_in_s = 20 diff --git a/torchdata/_extension.py b/torchdata/_extension.py deleted file mode 100644 index 26416e8bc..000000000 --- a/torchdata/_extension.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import importlib.machinery -import os -from pathlib import Path - - -_LIB_DIR = Path(__file__).parent - - -def _init_extension(): - lib_dir = os.path.dirname(__file__) - - # TODO(631): If any extension had dependency of shared library, - # in order to support load these shred libraries dynamically, - # we need to add logic to load dll path on Windows - # See: https://github.com/pytorch/pytorch/blob/master/torch/__init__.py#L56-L140 - - loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) - - extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) # type: ignore[arg-type] - ext_specs = extfinder.find_spec("_torchdata") - - if ext_specs is None: - return - - from torchdata import _torchdata as _torchdata - - -_init_extension() diff --git a/torchdata/_torchdata/__init__.pyi b/torchdata/_torchdata/__init__.pyi deleted file mode 100644 index fcea8c6c5..000000000 --- a/torchdata/_torchdata/__init__.pyi +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import List - -from torchdata import _extension # noqa: F401 - -# TODO: Add pyi generate script -class S3Handler: - def __init__(self, request_timeout_ms: int, region: str) -> None: ... - def s3_read(self, url: str) -> bytes: ... 
- def list_files(self, prefix: str) -> List[str]: ... - def set_buffer_size(self, buffer_size: int) -> None: ... - def set_multi_part_download(self, multi_part_download: bool) -> None: ... - def clear_marker(self) -> None: ... diff --git a/torchdata/_utils.py b/torchdata/_utils.py deleted file mode 100644 index 6eea9bd30..000000000 --- a/torchdata/_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import sys -import traceback - - -class KeyErrorMessage(str): - r"""str subclass that returns itself in repr""" - - def __repr__(self): - return self - - -class ExceptionWrapper: - r""" - Wraps an exception with traceback to communicate across threads/processes - """ - - def __init__(self, exc_info=None, where: str = "in background"): - if exc_info is None: - exc_info = sys.exc_info() - self.exc_type = exc_info[0] - self.exc_msg = "".join(traceback.format_exception(*exc_info)) - self.where = where - - def reraise(self): - r""" - Reraises the wrapped exception in the current thread/process - """ - # Format a message such as: "Caught ValueError in DataLoader worker - # process 2. Original Traceback:", followed by the traceback. - msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" - if self.exc_type == KeyError: - # KeyError calls repr() on its argument (usually a dict key). This - # makes stack traces unreadable. It will not be changed in Python - # (https://bugs.python.org/issue2651), so we work around it. - msg = KeyErrorMessage(msg) - elif getattr(self.exc_type, "message", None): - # Some exceptions have first argument as non-str but explicitly - # have message field - raise self.exc_type(message=msg) - try: - exception = self.exc_type(msg) - except TypeError: - # If the exception takes multiple arguments, don't try to - # instantiate since we don't know how to - raise RuntimeError(msg) from None - raise exception diff --git a/torchdata/csrc/CMakeLists.txt b/torchdata/csrc/CMakeLists.txt deleted file mode 100644 index f4a003bfb..000000000 --- a/torchdata/csrc/CMakeLists.txt +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -if(BUILD_S3) - message(STATUS "Building S3 IO functionality") - - # To make the right CPython is built with on GitHub Actions, - # see https://github.com/actions/setup-python/issues/121#issuecomment-1014500503 - set(Python_FIND_FRAMEWORK "LAST") - - if(WIN32) - find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Interpreter Development) - set(ADDITIONAL_ITEMS Python3::Python) - else() - find_package(Python3 COMPONENTS Interpreter Development) - endif() - - # AWSSDK Dependencies - if(WIN32) - set(AWSSDK_DEP_LIBRARIES "Userenv;version;ws2_32;Bcrypt;Wininet;winhttp;Crypt32;Secur32;NCrypt;Shlwapi") - elseif(APPLE) - set(AWSSDK_DEP_LIBRARIES pthread curl) - elseif(UNIX) - set(AWSSDK_DEP_LIBRARIES pthread) - if(STATIC_DEPS) - set(OPENSSL_USE_STATIC_LIBS TRUE) - endif() - include(FindZLIB) - list(APPEND AWSSDK_DEP_LIBRARIES ${ZLIB_LIBRARIES}) - include(FindCURL) - list(APPEND AWSSDK_DEP_LIBRARIES ${CURL_LIBRARIES}) - include(FindOpenSSL) - list(APPEND AWSSDK_DEP_LIBRARIES ${OPENSSL_SSL_LIBRARIES} ${OPENSSL_CRYPTO_LIBRARIES}) - endif() - message(STATUS "AWSSDK DEPENDENCIES AWSSDK_DEP_LIBRARIES: ${AWSSDK_DEP_LIBRARIES}") - - set(CMAKE_POSITION_INDEPENDENT_CODE ON) - - set( - EXTENSION_SOURCES - pybind/pybind.cpp - pybind/S3Handler/S3Handler.cpp - ) - - add_library(_torchdata MODULE ${EXTENSION_SOURCES}) - - if(NOT USE_SYSTEM_AWS_SDK_CPP) - add_dependencies(_torchdata aws_sdk) - endif() - - target_include_directories( - _torchdata - PRIVATE - ${PROJECT_SOURCE_DIR} - ${Python_INCLUDE_DIR} - ${pybind11_INCLUDE_DIRS} - ${AWSSDK_INCLUDE_DIRS} - ) - - target_link_libraries( - _torchdata - PRIVATE - ${Python_LIBRARIES} - ${ADDITIONAL_ITEMS} - ${AWSSDK_LIBRARIES} - ${AWSSDK_DEP_LIBRARIES} - ) - - set_target_properties(_torchdata PROPERTIES PREFIX "") - if (MSVC) - set_target_properties(_torchdata PROPERTIES SUFFIX ".pyd") - endif(MSVC) - if (APPLE) - set_target_properties(_torchdata PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") - endif(APPLE) - -endif() diff --git a/torchdata/csrc/pybind/S3Handler/S3Handler.cpp b/torchdata/csrc/pybind/S3Handler/S3Handler.cpp deleted file mode 100644 index ee0a69651..000000000 --- a/torchdata/csrc/pybind/S3Handler/S3Handler.cpp +++ /dev/null @@ -1,402 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include "S3Handler.h" - -namespace torchdata { - -namespace { - -static const size_t S3DefaultBufferSize = 128 * 1024 * 1024; // 128 MB -static const uint64_t S3DefaultMultiPartDownloadChunkSize = - 5 * 1024 * 1024; // 5 MB -static const int executorPoolSize = 25; -static const std::string S3DefaultMarker = ""; - -std::shared_ptr setUpS3Config( - const long requestTimeoutMs, - const std::string region) { - std::shared_ptr cfg = - std::shared_ptr( - new Aws::Client::ClientConfiguration()); - Aws::String config_file; - const char* config_file_env = getenv("AWS_CONFIG_FILE"); - if (config_file_env) { - config_file = config_file_env; - } else { - const char* home_env = getenv("HOME"); - if (home_env) { - config_file = home_env; - config_file += "/.aws/config"; - } - } - Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file); - loader.Load(); - - const char* use_https = getenv("S3_USE_HTTPS"); - if (use_https) { - if (use_https[0] == '0') { - cfg->scheme = Aws::Http::Scheme::HTTP; - } else { - cfg->scheme = Aws::Http::Scheme::HTTPS; - } - } - const char* verify_ssl = getenv("S3_VERIFY_SSL"); - if (verify_ssl) { - if (verify_ssl[0] == '0') { - cfg->verifySSL = false; - } else { - cfg->verifySSL = true; - } - } - const char* endpoint_url = getenv("S3_ENDPOINT_URL"); - if (endpoint_url) { - cfg->endpointOverride = endpoint_url; - } - if (region != "") { - cfg->region = region; - } else { - const char* env_region = getenv("AWS_REGION"); - if (env_region) { - cfg->region = env_region; - } - } - if (requestTimeoutMs > -1) { - cfg->requestTimeoutMs = requestTimeoutMs; - } - return cfg; -} - -void ShutdownClient(std::shared_ptr* s3_client) { - if (s3_client != nullptr) { - delete s3_client; - Aws::SDKOptions options; - Aws::ShutdownAPI(options); - } -} - -void ShutdownTransferManager( - std::shared_ptr* transfer_manager) { - if (transfer_manager != nullptr) { - delete transfer_manager; - } -} - -void ShutdownExecutor(Aws::Utils::Threading::PooledThreadExecutor* executor) { - if (executor != nullptr) { - delete executor; - } -} - -void parseS3Path( - const Aws::String& fname, - Aws::String* bucket, - Aws::String* object) { - if (fname.empty()) { - throw std::invalid_argument("The filename cannot be an empty string."); - } - - if (fname.size() < 5 || fname.substr(0, 5) != "s3://") { - throw std::invalid_argument("The filename must start with the S3 scheme."); - } - - std::string path = fname.substr(5); - - if (path.empty()) { - throw std::invalid_argument("The filename cannot be an empty string."); - } - - size_t pos = path.find_first_of('/'); - if (pos == 0) { - throw std::invalid_argument("The filename does not contain a bucket name."); - } - - *bucket = path.substr(0, pos); - *object = path.substr(pos + 1); - if (pos == std::string::npos) { - *object = ""; - } -} - -class S3FS { - private: - std::string bucket_name_; - std::string object_name_; - bool use_multi_part_download_; - std::shared_ptr s3_client_; - std::shared_ptr transfer_manager_; - - public: - S3FS( - const std::string& bucket, - const std::string& object, - const bool use_multi_part_download, - std::shared_ptr transfer_manager, - std::shared_ptr s3_client) - : bucket_name_(bucket), - object_name_(object), - use_multi_part_download_(use_multi_part_download), - transfer_manager_(transfer_manager), - s3_client_(s3_client) {} - - size_t Read(uint64_t offset, size_t n, char* buffer) { - if (use_multi_part_download_) { - return ReadTransferManager(offset, n, buffer); - } else { - return ReadS3Client(offset, n, buffer); - } 
- } - - size_t ReadS3Client(uint64_t offset, size_t n, char* buffer) { - Aws::S3::Model::GetObjectRequest getObjectRequest; - - getObjectRequest.WithBucket(bucket_name_.c_str()) - .WithKey(object_name_.c_str()); - - std::string bytes = "bytes="; - bytes += std::to_string(offset) + "-" + std::to_string(offset + n - 1); - - getObjectRequest.SetRange(bytes.c_str()); - - // When you don’t want to load the entire file into memory, - // you can use IOStreamFactory in AmazonWebServiceRequest to pass a - // lambda to create a string stream. - getObjectRequest.SetResponseStreamFactory( - []() { return Aws::New("S3IOAllocationTag"); }); - // get the object - Aws::S3::Model::GetObjectOutcome getObjectOutcome = - s3_client_->GetObject(getObjectRequest); - - if (!getObjectOutcome.IsSuccess()) { - Aws::S3::S3Error error = getObjectOutcome.GetError(); - std::cout << "ERROR: " << error.GetExceptionName() << ": " - << error.GetMessage() << std::endl; - return 0; - } else { - n = getObjectOutcome.GetResult().GetContentLength(); - // read data as a block: - getObjectOutcome.GetResult().GetBody().read(buffer, n); - return n; - } - } - - size_t ReadTransferManager(uint64_t offset, size_t n, char* buffer) { - auto create_stream_fn = [&]() { // create stream lambda fn - return Aws::New( - "S3ReadStream", - Aws::New( - "S3ReadStream", reinterpret_cast(buffer), n)); - }; // This buffer is what we used to initialize streambuf and is in memory - - std::shared_ptr downloadHandle = - transfer_manager_.get()->DownloadFile( - bucket_name_.c_str(), - object_name_.c_str(), - offset, - n, - create_stream_fn); - downloadHandle->WaitUntilFinished(); - - if (downloadHandle->GetStatus() != - Aws::Transfer::TransferStatus::COMPLETED) { - const Aws::Client::AWSError error = - downloadHandle->GetLastError(); - std::cout << "ERROR: " << error.GetExceptionName() << ": " - << error.GetMessage() << std::endl; - return 0; - } else { - return downloadHandle->GetBytesTransferred(); - } - } -}; - -} // namespace - -std::shared_ptr S3Handler::s3_handler_cfg_; - -S3Handler::S3Handler(const long requestTimeoutMs, const std::string region) - : s3_client_(nullptr, ShutdownClient), - transfer_manager_(nullptr, ShutdownTransferManager), - executor_(nullptr, ShutdownExecutor) { - initialization_lock_ = std::shared_ptr(new std::mutex()); - - // Load reading parameters - buffer_size_ = S3DefaultBufferSize; - const char* bufferSizeStr = getenv("S3_BUFFER_SIZE"); - if (bufferSizeStr) { - buffer_size_ = std::stoull(bufferSizeStr); - } - use_multi_part_download_ = true; - const char* use_multi_part_download_char = getenv("S3_MULTI_PART_DOWNLOAD"); - if (use_multi_part_download_char) { - std::string use_multi_part_download_str(use_multi_part_download_char); - if (use_multi_part_download_str == "OFF") { - use_multi_part_download_ = false; - } - } - - Aws::SDKOptions options; - Aws::InitAPI(options); - S3Handler::s3_handler_cfg_ = setUpS3Config(requestTimeoutMs, region); - InitializeS3Client(); - - last_marker_ = S3DefaultMarker; -} - -S3Handler::~S3Handler() {} - -void S3Handler::InitializeS3Client() { - std::lock_guard lock(*initialization_lock_); - s3_client_ = std::shared_ptr(new Aws::S3::S3Client( - *S3Handler::s3_handler_cfg_, - Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, - false)); -} - -void S3Handler::InitializeExecutor() { - executor_ = Aws::MakeShared( - "executor", executorPoolSize); -} - -void S3Handler::InitializeTransferManager() { - std::shared_ptr s3_client = GetS3Client(); - std::lock_guard lock(*initialization_lock_); - 
- Aws::Transfer::TransferManagerConfiguration transfer_config( - GetExecutor().get()); - transfer_config.s3Client = s3_client; - // This buffer is what we used to initialize streambuf and is in memory - transfer_config.bufferSize = S3DefaultMultiPartDownloadChunkSize; - transfer_config.transferBufferMaxHeapSize = - (executorPoolSize + 1) * S3DefaultMultiPartDownloadChunkSize; - transfer_manager_ = Aws::Transfer::TransferManager::Create(transfer_config); -} - -std::shared_ptr S3Handler::GetS3Client() { - if (s3_client_.get() == nullptr) { - InitializeS3Client(); - } - return s3_client_; -} - -std::shared_ptr -S3Handler::GetExecutor() { - if (executor_.get() == nullptr) { - InitializeExecutor(); - } - return executor_; -} - -std::shared_ptr -S3Handler::GetTransferManager() { - if (transfer_manager_.get() == nullptr) { - InitializeTransferManager(); - } - return transfer_manager_; -} - -size_t S3Handler::GetFileSize( - const std::string& bucket, - const std::string& object) { - Aws::S3::Model::HeadObjectRequest headObjectRequest; - headObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str()); - Aws::S3::Model::HeadObjectOutcome headObjectOutcome = - GetS3Client()->HeadObject(headObjectRequest); - if (headObjectOutcome.IsSuccess()) { - return headObjectOutcome.GetResult().GetContentLength(); - } else { - Aws::String const& error_aws = headObjectOutcome.GetError().GetMessage(); - std::string error_str(error_aws.c_str(), error_aws.size()); - throw std::invalid_argument(error_str); - return 0; - } -} - -void S3Handler::ClearMarker() { - last_marker_ = S3DefaultMarker; -} - -void S3Handler::S3Read(const std::string& file_url, std::string* result) { - std::string bucket, object; - parseS3Path(file_url, &bucket, &object); - S3FS s3fs( - bucket, - object, - use_multi_part_download_, - GetTransferManager(), - GetS3Client()); - - uint64_t offset = 0; - uint64_t result_size = 0; - uint64_t file_size = GetFileSize(bucket, object); - size_t part_count = (std::max)( - static_cast((file_size + buffer_size_ - 1) / buffer_size_), - static_cast(1)); - result->resize(file_size); - - for (int i = 0; i < part_count; i++) { - offset = result_size; - - size_t buf_len = std::min(buffer_size_, file_size - result_size); - - size_t read_len = - s3fs.Read(offset, buf_len, (char*)(result->data()) + offset); - - result_size += read_len; - - if (result_size == file_size) { - break; - } - - if (read_len != buf_len) { - std::cout << "Result size and buffer size did not match"; - break; - } - } -} - -void S3Handler::ListFiles( - const std::string& file_url, - std::vector* filenames) { - Aws::String bucket, prefix; - parseS3Path(file_url, &bucket, &prefix); - - Aws::S3::Model::ListObjectsRequest listObjectsRequest; - listObjectsRequest.WithBucket(bucket).WithPrefix(prefix).WithMarker( - last_marker_); - - Aws::S3::Model::ListObjectsOutcome listObjectsOutcome = - GetS3Client()->ListObjects(listObjectsRequest); - if (!listObjectsOutcome.IsSuccess()) { - Aws::String const& error_aws = listObjectsOutcome.GetError().GetMessage(); - throw std::invalid_argument(error_aws); - } - - Aws::Vector objects = - listObjectsOutcome.GetResult().GetContents(); - if (objects.empty()) { - return; - } - for (const Aws::S3::Model::Object& object : objects) { - if (object.GetKey().back() == '/') // ignore folders - - { - continue; - } - Aws::String entry = "s3://" + bucket + "/" + object.GetKey(); - filenames->push_back(entry.c_str()); - } - last_marker_ = objects.back().GetKey(); - - // extreme cases when all objects are folders - if 
(filenames->size() == 0) { - ListFiles(file_url, filenames); - } -} - -} // namespace torchdata diff --git a/torchdata/csrc/pybind/S3Handler/S3Handler.h b/torchdata/csrc/pybind/S3Handler/S3Handler.h deleted file mode 100644 index 58b3cbf9a..000000000 --- a/torchdata/csrc/pybind/S3Handler/S3Handler.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "precompile.h" - -namespace torchdata { - -// In memory stream implementation -class S3UnderlyingStream : public Aws::IOStream { - public: - using Base = Aws::IOStream; - - // provide a customer controlled streambuf, so as to put all transferred - // data into this in memory buffer. - S3UnderlyingStream(std::streambuf* buf) : Base(buf) {} - - virtual ~S3UnderlyingStream() = default; -}; - -class S3Handler { - private: - static std::shared_ptr s3_handler_cfg_; - - std::shared_ptr initialization_lock_; - std::shared_ptr s3_client_; - std::shared_ptr executor_; - std::shared_ptr transfer_manager_; - - Aws::String last_marker_; - size_t buffer_size_; - bool use_multi_part_download_; - - void InitializeS3Client(); - void InitializeExecutor(); - void InitializeTransferManager(); - - std::shared_ptr GetS3Client(); - std::shared_ptr GetExecutor(); - std::shared_ptr GetTransferManager(); - size_t GetFileSize(const std::string& bucket, const std::string& object); - - public: - S3Handler(const long requestTimeoutMs, const std::string region); - ~S3Handler(); - - void SetLastMarker(const Aws::String last_marker) { - this->last_marker_ = last_marker; - } - void SetBufferSize(const uint64_t buffer_size) { - this->buffer_size_ = buffer_size; - } - void SetMultiPartDownload(const bool multi_part_download) { - this->use_multi_part_download_ = multi_part_download; - } - void ClearMarker(); - - long GetRequestTimeoutMs() const { - return s3_handler_cfg_->requestTimeoutMs; - } - Aws::String GetRegion() const { - return s3_handler_cfg_->region; - } - Aws::String GetLastMarker() const { - return last_marker_; - } - bool GetUseMultiPartDownload() const { - return use_multi_part_download_; - } - size_t GetBufferSize() const { - return buffer_size_; - } - - void S3Read(const std::string& file_url, std::string* result); - void ListFiles( - const std::string& file_url, - std::vector* filenames); -}; - -} // namespace torchdata diff --git a/torchdata/csrc/pybind/S3Handler/precompile.h b/torchdata/csrc/pybind/S3Handler/precompile.h deleted file mode 100644 index e62d7b38c..000000000 --- a/torchdata/csrc/pybind/S3Handler/precompile.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifndef TORCHDATA_S3_IO_H -#define TORCHDATA_S3_IO_H - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#endif // TORCHDATA_S3_IO_H diff --git a/torchdata/csrc/pybind/pybind.cpp b/torchdata/csrc/pybind/pybind.cpp deleted file mode 100644 index 94d4b08ab..000000000 --- a/torchdata/csrc/pybind/pybind.cpp +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. 
and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -#include -#include - -#include - -namespace py = pybind11; -using torchdata::S3Handler; - -PYBIND11_MODULE(_torchdata, m) { - py::class_(m, "S3Handler") - .def(py::init()) - .def( - "s3_read", - [](S3Handler* self, const std::string& file_url) { - std::string result; - self->S3Read(file_url, &result); - return py::bytes(result); - }) - .def( - "list_files", - [](S3Handler* self, const std::string& file_url) { - std::vector filenames; - self->ListFiles(file_url, &filenames); - return filenames; - }) - .def( - "set_buffer_size", - [](S3Handler* self, const uint64_t buffer_size) { - self->SetBufferSize(buffer_size); - }) - .def( - "set_multi_part_download", - [](S3Handler* self, const bool multi_part_download) { - self->SetMultiPartDownload(multi_part_download); - }) - .def("clear_marker", [](S3Handler* self) { self->ClearMarker(); }) - .def(py::pickle( - [](const S3Handler& s3_handler) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple( - s3_handler.GetRequestTimeoutMs(), - s3_handler.GetRegion(), - s3_handler.GetLastMarker(), - s3_handler.GetUseMultiPartDownload(), - s3_handler.GetBufferSize()); - }, - [](py::tuple t) { // __setstate__ - if (t.size() != 5) - throw std::runtime_error("Invalid state!"); - - /* Create a new C++ instance */ - S3Handler s3_handler(t[0].cast(), t[1].cast()); - - /* Assign any additional state */ - s3_handler.SetLastMarker(t[2].cast()); - s3_handler.SetMultiPartDownload(t[3].cast()); - s3_handler.SetBufferSize(t[4].cast()); - - return s3_handler; - })); -} diff --git a/torchdata/dataloader2/README.md b/torchdata/dataloader2/README.md deleted file mode 100644 index f9af43dee..000000000 --- a/torchdata/dataloader2/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# DataLoader2 (Prototype) - -Please check out our [full DataLoader2 documentation](https://pytorch.org/data/main/dataloader2.html#dataloader2). - -## DataLoader2 Prototype Usage and Feedback - -`DataLoader2` is stable in terms of API, but functionally not complete yet. We welcome early adopters and feedback, as -well as potential contributors. If you are interested in trying it out, we encourage you to install the nightly version -of this library. diff --git a/torchdata/dataloader2/__init__.py b/torchdata/dataloader2/__init__.py deleted file mode 100644 index bc91544a9..000000000 --- a/torchdata/dataloader2/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
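The py::pickle binding above reduces an S3Handler to a five-element tuple of its configuration and rebuilds a handler from that tuple, which is what lets handler objects cross process boundaries. A hedged pure-Python analogue of the same pattern follows; the class and its field defaults are illustrative only, not the real S3Handler.

import pickle

class ToyS3Handler:
    # Illustrative stand-in; the default field values are placeholders.
    def __init__(self, request_timeout_ms: int, region: str):
        self.request_timeout_ms = request_timeout_ms
        self.region = region
        self.last_marker = ""
        self.use_multi_part_download = True
        self.buffer_size = 1024 * 1024

    def __getstate__(self):
        # Mirrors the five-element tuple produced by the C++ __getstate__ lambda.
        return (
            self.request_timeout_ms,
            self.region,
            self.last_marker,
            self.use_multi_part_download,
            self.buffer_size,
        )

    def __setstate__(self, state):
        if len(state) != 5:
            raise RuntimeError("Invalid state!")
        (
            self.request_timeout_ms,
            self.region,
            self.last_marker,
            self.use_multi_part_download,
            self.buffer_size,
        ) = state

handler = ToyS3Handler(5000, "us-west-2")
restored = pickle.loads(pickle.dumps(handler))
assert restored.region == "us-west-2" and restored.buffer_size == 1024 * 1024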
- - -from torchdata import _extension # noqa: F401 -from torchdata.dataloader2.dataloader2 import DataLoader2, DataLoader2Iterator -from torchdata.dataloader2.error import PauseIteration -from torchdata.dataloader2.reading_service import ( - CheckpointableReadingServiceInterface, - DistributedReadingService, - InProcessReadingService, - MultiProcessingReadingService, - PrototypeMultiProcessingReadingService, - ReadingServiceInterface, - SequentialReadingService, -) -from torchdata.dataloader2.shuffle_spec import ShuffleSpec - -__all__ = [ - "CheckpointableReadingServiceInterface", - "DataLoader2", - "DataLoader2Iterator", - "DistributedReadingService", - "InProcessReadingService", - "MultiProcessingReadingService", - "PauseIteration", - "PrototypeMultiProcessingReadingService", - "ReadingServiceInterface", - "SequentialReadingService", - "ShuffleSpec", -] - -assert __all__ == sorted(__all__) - - -from torchdata import deprecation_warning - -deprecation_warning() diff --git a/torchdata/dataloader2/adapter.py b/torchdata/dataloader2/adapter.py deleted file mode 100644 index 8f7cb2800..000000000 --- a/torchdata/dataloader2/adapter.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import abstractmethod - -import torch - -from torchdata.dataloader2.graph import DataPipe, traverse_dps -from torchdata.datapipes.iter.util.cacheholder import _WaitPendingCacheItemIterDataPipe - - -__all__ = [ - "Adapter", - "CacheTimeout", - "Shuffle", -] - -assert __all__ == sorted(__all__) - - -class Adapter: - r""" - Adapter Base Class that follows python Callable protocol. - """ - - @abstractmethod - def __call__(self, datapipe: DataPipe) -> DataPipe: - r""" - Callable function that either runs in-place modification of - the ``DataPipe`` graph, or returns a new ``DataPipe`` graph. - - Args: - datapipe: ``DataPipe`` that needs to be adapted. - - Returns: - Adapted ``DataPipe`` or new ``DataPipe``. - """ - pass - - -class Shuffle(Adapter): - r""" - Shuffle DataPipes adapter allows control over all existing Shuffler (``shuffle``) DataPipes in the graph. - - Args: - enable: Optional boolean argument to enable/disable shuffling in the ``DataPipe`` graph. True by default. - - - True: Enables all previously disabled ``ShufflerDataPipes``. If none exists, it will add a new ``shuffle`` at the end of the graph. - - False: Disables all ``ShufflerDataPipes`` in the graph. - - None: No-op. Introduced for backward compatibility. - - Example: - - .. testsetup:: - - from torchdata.datapipes.iter import IterableWrapper - from torchdata.dataloader2 import DataLoader2 - from torchdata.dataloader2.adapter import Shuffle - - size = 12 - - .. testcode:: - - dp = IterableWrapper(range(size)).shuffle() - dl = DataLoader2(dp, [Shuffle(False)]) - assert list(range(size)) == list(dl) - """ - - def __init__(self, enable=True): - self.enable = enable - - def __call__(self, datapipe: DataPipe) -> DataPipe: - return torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=self.enable) - - -class CacheTimeout(Adapter): - r""" - CacheTimeout DataPipes adapter allows control over timeouts of all existing EndOnDiskCacheHolder (``end_caching``) - in the graph. Useful when cached pipeline takes too long to execute (ex. slow file downloading). 
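The Adapter base class above only requires a callable that accepts a DataPipe graph and returns a (possibly modified) graph. Below is a hedged sketch of a user-defined adapter following that protocol; AddBatching is hypothetical, and it uses the core torch.utils.data datapipes so the snippet can run without the modules this PR removes.

from torch.utils.data.datapipes.iter import IterableWrapper

class AddBatching:
    """Hypothetical adapter: appends a Batcher to the end of the graph."""

    def __init__(self, batch_size: int):
        self.batch_size = batch_size

    def __call__(self, datapipe):
        # Runs once when DataLoader2 is constructed; the returned graph is
        # what the reading service subsequently executes.
        return datapipe.batch(self.batch_size)

dp = IterableWrapper(range(10))
adapted = AddBatching(4)(dp)
assert [len(batch) for batch in adapted] == [4, 4, 2]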
- - Args: - timeout: int - amount of seconds parallel processes will wait for cached files to appear. - - Example: - - .. testsetup:: - - from torchdata.datapipes.iter import IterableWrapper - from torchdata.dataloader2 import DataLoader2 - from torchdata.dataloader2.adapter import CacheTimeout - - size = 12 - - .. testcode:: - - dp = IterableWrapper(range(size)).shuffle() - dl = DataLoader2(dp, [CacheTimeout(600)]) - """ - - def __init__(self, timeout=None): - if timeout is None: - raise ValueError("timeout should be integer") - self.timeout = timeout - - def __call__(self, datapipe: DataPipe) -> DataPipe: - graph = traverse_dps(datapipe) - all_pipes = torch.utils.data.graph_settings.get_all_graph_pipes(graph) - cache_locks = {pipe for pipe in all_pipes if isinstance(pipe, _WaitPendingCacheItemIterDataPipe)} - - for cache_lock in cache_locks: - cache_lock.set_timeout(self.timeout) - - return datapipe diff --git a/torchdata/dataloader2/communication/__init__.py b/torchdata/dataloader2/communication/__init__.py deleted file mode 100644 index ac35ef87c..000000000 --- a/torchdata/dataloader2/communication/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from . import eventloop, iter, map, messages, protocol, queue diff --git a/torchdata/dataloader2/communication/eventloop.py b/torchdata/dataloader2/communication/eventloop.py deleted file mode 100644 index 808b5effa..000000000 --- a/torchdata/dataloader2/communication/eventloop.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import time - -from itertools import zip_longest -from typing import Dict, List - -import torch - -from torch.utils.data import IterDataPipe, MapDataPipe -from torchdata.dataloader2 import communication -from torchdata.dataloader2.graph._serialization import extract_wrapper - -try: - import dill - - # XXX: By default, dill writes the Pickler dispatch table to inject its - # own logic there. This globally affects the behavior of the standard library - # pickler for any user who transitively depends on this module! - # Undo this extension to avoid altering the behavior of the pickler globally. - dill.extend(use_dill=False) - HAS_DILL = True -except ImportError: - HAS_DILL = False - -__all__ = [ - "DataPipeToQueuesLoop", - "CreateProcessForDataPipeline", - "CreateProcessForMultipleDataPipelines", -] - - -class _RequestCounter: - r""" - _RequestCounter is used to synchronize between eventloops within the dispatching - process. It guarantees to only handle the limit/pause/reset_epoch/resume request - util all loops have received the same message. 
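CacheTimeout's __call__ above is an instance of a general pattern: traverse the DataPipe graph, collect nodes of one type, and mutate them in place. A hedged, generic sketch of that pattern using the same core-PyTorch helpers is below; apply_to_pipes_of_type is an illustrative name, not an API of this repository.

from torch.utils.data.datapipes.iter import IterableWrapper
from torch.utils.data.graph import traverse_dps
from torch.utils.data.graph_settings import get_all_graph_pipes

def apply_to_pipes_of_type(datapipe, pipe_type, fn):
    # Walk the graph, then call `fn` on every pipe matching `pipe_type`.
    graph = traverse_dps(datapipe)
    for pipe in get_all_graph_pipes(graph):
        if isinstance(pipe, pipe_type):
            fn(pipe)
    return datapipe

dp = IterableWrapper(range(4)).shuffle()
matched = []
apply_to_pipes_of_type(dp, type(dp), matched.append)  # collects the single Shuffler
assert len(matched) == 1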
- """ - exp_cnt: int - _keys: List[str] = ["limit", "pause", "reset_epoch", "resume"] - _cnt: Dict[str, int] - _reached: Dict[str, bool] - - def __init__(self, exp_cnt: int): - self.exp_cnt = exp_cnt - self._cnt = {k: 0 for k in self._keys} - self._reached = {k: False for k in self._keys} - - def increment(self, key: str) -> None: - assert key in self._reached - self._cnt[key] += 1 - assert self._cnt[key] <= self.exp_cnt - if self._cnt[key] == self.exp_cnt: - self._reached[key] = True - - def is_reached(self, key: str) -> bool: - assert key in self._reached - return self._reached[key] - - def reset(self, key: str) -> None: - assert key in self._reached and self._reached[key] - assert self._cnt[key] >= 1 - self._cnt[key] -= 1 - if self._cnt[key] == 0: - self._reached[key] = False - - -def MultipleDataPipesToQueuesLoop( - source_datapipes, req_queues, res_queues, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None -): - r""" - Set the appropriate pipes and protocol server type, and create a loop over multiple datapipes - with the protocol server in a non-blocking manner. - - Args: - source_datapipe: DataPipe being iterated in the dispatching process - req_queue: Multiprocessing queue providing requests from the worker process - res_queue: Multiprocessing queue sending results to the worker process - process_name: The name of process (used for logging and exception handling) - worker_info: Worker information (worker id and number of workers) - call_on_process_init: Not allowed by dispatching process for now. - custom_reset_fn: Optional callable function to reset the DataPipe. - """ - assert call_on_process_init is None, "``MultipleDataPipesToQueuesLoop`` does not support call_on_process_init" - num_loops = len(source_datapipes) - assert num_loops == len(req_queues) and num_loops == len( - res_queues - ), "``MultipleDataPipesToQueuesLoop`` requires the same number of datapipes, request queues and response queues" - - torch.set_num_threads(1) - - loops = [] - request_counter = _RequestCounter(num_loops) - - loop_id = 0 - for source_datapipe, req_queue, res_queue in zip(source_datapipes, req_queues, res_queues): - loops.append( - _create_datapipe_queue_loop( - source_datapipe, - req_queue, - res_queue, - process_name, - loop_id, - worker_info, - custom_reset_fn, - blocking_request_get=False, - request_counter=request_counter, - ) - ) # Non-blocking request with reset counters - loop_id += 1 - - # Using `zip_longest` to guarantee the process is terminated only when - # all loops have received `TerminateRequest` - for _ in zip_longest(*loops): - # time.sleep to make Python switch context to get/send message in mp.Queue - # TODO(ejguan): Microbenchmarked a synthetic non-replicable case that sleep perform similar to pass. - # A more comprehensive benchmarking in real-world scneario is needed. - time.sleep(0) - - -def DataPipeToQueuesLoop( - source_datapipe, req_queue, res_queue, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None -): - r""" - Initialize with the given init function, set the appropriate pipe and protocol server type, and - create a loop with the protocol server. 
- - Args: - source_datapipe: DataPipe being iterated in the worker process - req_queue: Multiprocessing queue providing requests from the main process - res_queue: Multiprocessing queue sending results to the main process - process_name: The name of process (used for logging and exception handling) - worker_info: Worker information (worker id and number of workers) - call_on_process_init: Callable function will be called at the time of worker process initialization. - Users can provide it to modify the DataPipe grpah in the worker process. - custom_reset_fn: Optional callable function to reset the DataPipe. - """ - # Extract Serialization Wrapper - source_datapipe = extract_wrapper(source_datapipe) - - if call_on_process_init is not None: - source_datapipe = call_on_process_init(source_datapipe) - - torch.set_num_threads(1) - - loop = _create_datapipe_queue_loop( - source_datapipe, - req_queue, - res_queue, - process_name, - worker_info.worker_id, - worker_info, - custom_reset_fn, - blocking_request_get=True, - ) - - for _ in loop: - pass - - -def _create_datapipe_queue_loop( - source_datapipe, - req_queue, - res_queue, - process_name, - loop_id, - worker_info, - custom_reset_fn=None, - blocking_request_get=True, - request_counter=None, -): - if isinstance(source_datapipe, IterDataPipe): - pipe_type = communication.iter - protocol_type = communication.protocol.IterDataPipeQueueProtocolServer - elif isinstance(source_datapipe, MapDataPipe): - pipe_type = communication.map # type: ignore[misc] - protocol_type = communication.protocol.MapDataPipeQueueProtocolServer # type: ignore[assignment] - else: - raise Exception("Only supports IterDataPipe or MapDataPipe, got", source_datapipe) - - return pipe_type.DataPipeBehindQueues( - source_datapipe, - protocol_type(req_queue, res_queue), - process_name=process_name, - loop_id=loop_id, - worker_info=worker_info, - custom_reset_fn=custom_reset_fn, - blocking_request_get=blocking_request_get, - request_counter=request_counter, - ) - - -def CreateProcessForDataPipeline( - multiprocessing_ctx, datapipe, process_name, worker_info, call_on_process_init=None, custom_reset_fn=None -): - r""" - Given a DataPipe, creates a new process with ``DataPipeToQueuesLoop`` as target, - and returns ``(process, req_queue, res_queue)``. - """ - req_queue = multiprocessing_ctx.Queue() - res_queue = multiprocessing_ctx.Queue() - process = multiprocessing_ctx.Process( - target=DataPipeToQueuesLoop, - args=(datapipe, req_queue, res_queue, process_name, worker_info, call_on_process_init, custom_reset_fn), - ) - return process, req_queue, res_queue - - -def CreateProcessForMultipleDataPipelines( - multiprocessing_ctx, datapipes, process_name, worker_info, custom_reset_fn=None -): - r""" - Given a DataPipe, creates a new process with ``MultipleDataPipesToQueuesLoop`` as target, - and returns ``(process, [req_queue_0, ...], [res_queue_0, ...])``. 
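CreateProcessForDataPipeline above boils down to "one worker process plus a request queue and a response queue, with the event loop as the process target". The hedged, self-contained sketch below shows that wiring with a toy loop standing in for DataPipeToQueuesLoop; toy_loop and its message strings are invented for illustration.

import multiprocessing as mp

def toy_loop(req_queue, res_queue):
    # Stand-in for DataPipeToQueuesLoop: serve requests until terminated.
    while True:
        request = req_queue.get()
        if request == "terminate":
            res_queue.put("terminated")
            break
        res_queue.put(("echo", request))

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    req_q, res_q = ctx.Queue(), ctx.Queue()
    process = ctx.Process(target=toy_loop, args=(req_q, res_q))
    process.start()
    req_q.put("GetNext")
    print(res_q.get())   # ('echo', 'GetNext')
    req_q.put("terminate")
    print(res_q.get())   # 'terminated'
    process.join()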
- """ - req_queues = [] - res_queues = [] - for _ in datapipes: - req_queues.append(multiprocessing_ctx.Queue()) - res_queues.append(multiprocessing_ctx.Queue()) - - process = multiprocessing_ctx.Process( - target=MultipleDataPipesToQueuesLoop, - args=(datapipes, req_queues, res_queues, process_name, worker_info, custom_reset_fn), - ) - return process, req_queues, res_queues diff --git a/torchdata/dataloader2/communication/iter.py b/torchdata/dataloader2/communication/iter.py deleted file mode 100644 index 73a54d1c9..000000000 --- a/torchdata/dataloader2/communication/iter.py +++ /dev/null @@ -1,502 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import time -import types -import warnings - -from collections import deque -from itertools import cycle -from typing import Callable, Deque, List, Optional - -from torch.utils.data import IterDataPipe -from torchdata._utils import ExceptionWrapper -from torchdata.dataloader2 import communication -from torchdata.dataloader2.graph import DataPipe, find_dps, list_dps, traverse_dps -from torchdata.dataloader2.random import SeedGenerator -from torchdata.dataloader2.utils import process_reset_fn - - -DEFAULT_NON_BLOCKING_SLEEP = 0.001 - -__all__ = [ - "DataPipeBehindQueues", - "EnsureNonBlockingDataPipe", - "InvalidStateResetRequired", - "NonBlocking", - "NotAvailable", - "QueueWrapper", - "default_not_available_hook", -] - - -def default_not_available_hook(): - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - - -class NotAvailable(Exception): - pass - - -class InvalidStateResetRequired(Exception): - """ - Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset. - """ - - pass - - -class TerminateRequired(Exception): - """ - Returned by DataPipe when it is expecting to get terminate request, - for example it got terminate request from other source and at the process - of stopping. 
- """ - - pass - - -class NonBlocking(IterDataPipe): - not_available_hook = default_not_available_hook - - def __iter__(self): - self.reset_iterator() - return self - - def __next__(self): - while True: - try: - return self.nonblocking_next() - except NotAvailable: - if NonBlocking.not_available_hook is not None: - NonBlocking.not_available_hook() - - def nonblocking_next(self): - raise NotImplementedError("nonblocking_next is not implemented for %s" % self.__class__) - - def reset_iterator(self): - raise NotImplementedError("reset_iterator is not implemented for %s" % self.__class__) - - @staticmethod - def register_not_available_hook(hook_function): - NonBlocking.not_available_hook = hook_function - - -def EnsureNonBlockingDataPipe(validated_datapipe): - if not isinstance(validated_datapipe, IterDataPipe): - raise Exception("Not Iterable DataPipe " + str(validated_datapipe.__class__)) - if isinstance(validated_datapipe, NonBlocking): - return validated_datapipe - if not hasattr(validated_datapipe, "_as_iterator"): - validated_datapipe._as_iterator = None # type: ignore[attr-defined] - if not hasattr(validated_datapipe, "nonblocking_next"): - - def nonblocking_next(self): - if self._as_iterator is None: - self._as_iterator = iter(self) - return next(self._as_iterator) - - validated_datapipe.nonblocking_next = types.MethodType( # type: ignore[attr-defined] - nonblocking_next, validated_datapipe - ) - if not hasattr(validated_datapipe, "reset_iterator"): - - def reset_iterator(self): - self._as_iterator = None - - validated_datapipe.reset_iterator = types.MethodType( # type: ignore[attr-defined] - reset_iterator, validated_datapipe - ) - return validated_datapipe - - -def _sync_recv(request_counter, msg): - if request_counter is not None: - request_counter.increment(msg) - # Make sure all loops have reached - while not request_counter.is_reached(msg): - yield True - - -def _sync_resp(request_counter, msg): - if request_counter is not None: - request_counter.reset(msg) - while request_counter.is_reached(msg): - yield True - - -def DataPipeBehindQueues( - source_datapipe, - protocol, - process_name, - loop_id, - worker_info, - custom_reset_fn, - blocking_request_get=False, - request_counter=None, -): - """ - Indefinitely iterates over ``req_queue`` and passing values from source_datapipe to ``res_queue``. - - Request Types: - `ResetEpoch` - Call the `reset_epoch_fn` on the protocol's DataPipe and reset DataPipe iterator - `Terminate` - exits the infinite while loop - `GetNext` - returns the value from the DataPipe, and handles exceptions such as `StopIteration` as appropriate - `Limit` - Set limit to the DataPipe graph - `Pause` - Pause - the DataPipe graph - `Resume` - Resume the DataPipe graph - - Args: - source_datapipe: DataPipe - protocol: ``IterDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue`` - process_name: Process name - loop_id: Loop ID - worker_info: Worker info include worker id and number of workers - custom_reset_fn: function to call after each request is received - blocking_request_get: determines if ``protocol.get_new_request`` will block - request_counter: Optional counter to synchronize all loops that have received requests for - reset/limit/pause/resume within the dispatching process. It would guarantee that - all loops starts to reset iterator and get next element at the same time. 
- """ - if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer): - raise Exception("Expecting IterDataPipeQueueProtocolServer, got", protocol) - source_datapipe = EnsureNonBlockingDataPipe(source_datapipe) - forever = True - while forever: - try: - # TODO: Non-blocking call is extremely slow here for python.mp, need to figure out a good workaround - request = protocol.get_new_request(block=blocking_request_get) - except communication.protocol.EmptyQueue: - yield True - continue - - # TODO: Handle Error caused by requests other than GetNext and send it to main process - if isinstance(request, communication.messages.ResetEpochRequest): - yield from _sync_recv(request_counter, "reset_epoch") - distributed_shared_seed = request_counter is not None - if request_counter is None or loop_id == 0: - seed_generator = request.seed_generator - iter_reset_fn = request.iter_reset_fn - dispatching_dps = find_dps(traverse_dps(source_datapipe), _IterateQueueDataPipes) - for dp in dispatching_dps: - dp.reset_epoch(seed_generator, iter_reset_fn) - source_datapipe = process_reset_fn( - source_datapipe, - worker_info, - seed_generator, - distributed_shared_seed, - iter_reset_fn, - custom_reset_fn, - ) - source_datapipe.reset_iterator() - yield from _sync_resp(request_counter, "reset_epoch") - protocol.response_reset_epoch() - yield True # Returns control - - elif isinstance(request, communication.messages.LimitRequest): - yield from _sync_recv(request_counter, "limit") - if request_counter is None or loop_id == 0: - num_batches = request.num_batches - limit_fn = request.limit_fn - worker_num_batches = num_batches if request.worker_num_batches is None else request.worker_num_batches - # Send limit to the worker/dispatching process - dispatching_dps = find_dps(traverse_dps(source_datapipe), _IterateQueueDataPipes) - for dp in dispatching_dps: - dp.request_limit(num_batches, limit_fn, worker_num_batches) - if limit_fn is not None: - # Set limit to the DataPipe graph in worker/dispatching process - source_datapipe = limit_fn(source_datapipe, worker_num_batches) - yield from _sync_resp(request_counter, "limit") - protocol.response_limit() - yield True # Returns control - - elif isinstance(request, communication.messages.PauseRequest): - yield from _sync_recv(request_counter, "pause") - if request_counter is None or loop_id == 0: - graph = traverse_dps(source_datapipe) - dp_list = list_dps(graph) - for dp in dp_list: - if hasattr(dp, "pause") and callable(dp.pause): - dp.pause() - dispatching_dps = find_dps(graph, _IterateQueueDataPipes) - for dp in dispatching_dps: - dp.request_pause(request.pause_fn) - if request.pause_fn is not None: - source_datapipe = request.pause_fn(source_datapipe) - yield from _sync_resp(request_counter, "pause") - protocol.response_pause() - yield True # Returns control - - elif isinstance(request, communication.messages.ResumeRequest): - yield from _sync_recv(request_counter, "resume") - if request_counter is None or loop_id == 0: - if request.resume_fn is not None: - source_datapipe = request.resume_fn(source_datapipe) - graph = traverse_dps(source_datapipe) - # Send resume to the dispatching process - dispatching_dps = find_dps(graph, _IterateQueueDataPipes) - for dp in dispatching_dps: - dp.request_resume(request.resume_fn) - for dp in reversed(list_dps(graph)): - if hasattr(dp, "resume") and callable(dp.resume): - dp.resume() - yield from _sync_resp(request_counter, "resume") - protocol.response_resume() - yield True # Returns control - - elif 
isinstance(request, communication.messages.TerminateRequest): - forever = False - dispatch_dps = find_dps(traverse_dps(source_datapipe), _IterateQueueDataPipes) - for dispatch_dp in dispatch_dps: - dispatch_dp.request_terminate() - protocol.response_terminate() - yield True # Returns control - - elif isinstance(request, communication.messages.GetNextRequest): - while forever: - if protocol.is_paused(): - protocol.response_stop_iteration() - warnings.warn( - "Cannot `GetNext` after `Pause` has been called. " - "`Resume` must be called first before additional elements can be yielded." - ) - yield True - break - try: - value = source_datapipe.nonblocking_next() - except NotAvailable: - yield True - continue - except StopIteration: - protocol.response_stop_iteration() - yield True - break - except InvalidStateResetRequired: - protocol.response_invalid_state() - yield True - break - except Exception: - exc = ExceptionWrapper(where=f"in {process_name} {loop_id}") - protocol.response_worker_exception(exc) - return - protocol.response_next(value) - yield True # Returns control - break - else: - raise Exception("Unrecognized type of request received", request) - - -class QueueWrapper(NonBlocking): - """ - Creates an IterDataPipe which sends requests and reads the response from the DataLoader.Queue. - The input is a ProtocolClient that contains request queue and response queue. - """ - - def __init__(self, protocol, response_wait_time=0.00001): - if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient): - raise Exception("Got", protocol) - self.protocol = protocol - self.counter = 0 - self._stop_iteration = False - self._response_wait_time = response_wait_time - - def request_reset_epoch(self, seed_generator, iter_reset_fn): - self._stop_iteration = False - self.counter = 0 - self.protocol.request_reset_epoch(seed_generator, iter_reset_fn) - - def _get_response(self, fn_name) -> None: - assert hasattr(self.protocol, fn_name) and callable(getattr(self.protocol, fn_name)) - get_response_fn = getattr(self.protocol, fn_name) - while True: - try: - get_response_fn() - break - except communication.protocol.EmptyQueue: - if NonBlocking.not_available_hook is not None: - NonBlocking.not_available_hook() - - def get_reset_epoch_response(self) -> None: - self._get_response("get_response_reset_epoch") - - def request_limit( - self, - num_batches: Optional[int], - limit_fn: Optional[Callable[[DataPipe, Optional[int]], DataPipe]] = None, - worker_num_batches: Optional[int] = None, - ) -> None: - self.protocol.request_limit(num_batches, limit_fn, worker_num_batches) - - def get_limit_response(self) -> None: - self._get_response("get_response_limit") - - def request_pause(self, pause_fn: Optional[Callable[[DataPipe], DataPipe]] = None) -> None: - self.protocol.request_pause(pause_fn) - - def get_pause_response(self) -> None: - self._get_response("get_response_pause") - - def request_resume(self, resume_fn: Optional[Callable[[DataPipe], DataPipe]] = None) -> None: - self.protocol.request_resume(resume_fn) - - def get_resume_response(self) -> None: - self._get_response("get_response_resume") - - def nonblocking_next(self): - if self._stop_iteration: - raise Exception("`next` or `nonblocking_next` called after receiving StopIteration") - if self.protocol.can_take_request(): - self.protocol.request_next() - try: - response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - if isinstance(response, 
communication.messages.StopIterationResponse): - self._stop_iteration = True - raise StopIteration - if isinstance(response, communication.messages.InvalidStateResponse): - raise NotAvailable - return response.value - - -class _IterateQueueDataPipes(IterDataPipe): - r""" - Takes in ``QueueWrapper``s and iterates through them in a round-robin manner to get batches one-by-one. - - Typically, each worker has one ``QueueWrapper``. - """ - - def __init__(self, datapipes): - # TODO(VitalyFedyunin): Consider combining _IterateQueueDataPipes and QueueWrapper - # into one class, which supports any number of queues. - for dp in datapipes: - if not isinstance(dp, communication.iter.QueueWrapper): - raise Exception("Source datapipes should be an instance of iter.QueueWrapper") - self.datapipes = datapipes - self._num_processes = len(datapipes) - self.res_buffers: List[Deque] = [deque() for _ in range(len(datapipes))] - self._terminated: bool = False - self._limit: Optional[int] = None - self._request_cnt: int = 0 - - def __iter__(self): - disabled_pipe = [False] * len(self.datapipes) - cnt_disabled_pipes = 0 - - total_req_cnt = 0 - req_idx_cycle = cycle(range(self._num_processes)) - req_idx = next(req_idx_cycle) - total_res_cnt = 0 - res_idx_cycle = cycle(range(self._num_processes)) - res_idx = next(res_idx_cycle) - - while cnt_disabled_pipes < self._num_processes and not self._terminated: - # Send a round of requests until limit is reached (limit is smaller than total pipes) - for _ in range(self._num_processes): - if not disabled_pipe[req_idx]: - self.datapipes[req_idx].protocol.request_next() - self._request_cnt += 1 - total_req_cnt += 1 - req_idx = next(req_idx_cycle) - if self._limit is not None and self._request_cnt == self._limit: - break - # Receive responses from each of the workers with pending requests - while total_res_cnt < total_req_cnt and cnt_disabled_pipes < self._num_processes: - disabled = disabled_pipe[res_idx] - if not disabled: - if len(self.res_buffers[res_idx]): - response = self.res_buffers[res_idx].popleft() - else: - while not self._terminated: - try: - # Using non-blocking next to make sure termination reached - response = self.datapipes[res_idx].protocol.get_response_next(block=False) - break - except communication.protocol.EmptyQueue: - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - if isinstance(response, communication.messages.InvalidStateResponse): - raise communication.iter.InvalidStateResetRequired - if isinstance(response, communication.messages.TerminateResponse): - raise communication.iter.TerminateRequired - if isinstance(response, communication.messages.WorkerExceptionResponse): - response.exc.reraise() - if self._terminated: - break - if isinstance(response, communication.messages.StopIterationResponse): - disabled_pipe[res_idx] = True - cnt_disabled_pipes += 1 - disabled = True - req_idx = next(req_idx_cycle) - else: - # Only request if buffer is empty and has not reached the limit - if len(self.res_buffers[res_idx]) == 0 and ( - self._limit is None or self._request_cnt < self._limit - ): - self.datapipes[req_idx].protocol.request_next() - req_idx = next(req_idx_cycle) - self._request_cnt += 1 - total_req_cnt += 1 - total_res_cnt += 1 - res_idx = next(res_idx_cycle) - if not disabled: - yield response.value - - def reset_epoch( - self, - seed_generator: SeedGenerator, - iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]], - ): - self._request_cnt = 0 - for dp in self.datapipes: - dp.protocol.discard_existing_request() - for worker_id, dp in 
enumerate(self.datapipes): - worker_seed_generator = seed_generator.spawn(worker_id) - dp.request_reset_epoch(worker_seed_generator, iter_reset_fn) - for dp in self.datapipes: - dp.get_reset_epoch_response() - - def request_pause(self, pause_fn: Optional[Callable[[DataPipe], DataPipe]] = None) -> None: - # Store results of pending requests - for idx, dp in enumerate(self.datapipes): - if dp.protocol.waiting_for_response(): - res = dp.protocol.get_response_next(block=True) - self.res_buffers[idx].append(res) - for dp in self.datapipes: - dp.request_pause(pause_fn) - for dp in self.datapipes: - dp.get_pause_response() - - def request_resume(self, resume_fn: Optional[Callable[[DataPipe], DataPipe]] = None) -> None: - for dp in self.datapipes: - dp.request_resume(resume_fn) - for dp in self.datapipes: - dp.get_resume_response() - self._request_cnt = 0 - - def request_limit( - self, - num_batches: Optional[int], - limit_fn: Optional[Callable[[DataPipe, Optional[int]], DataPipe]] = None, - worker_num_batches: Optional[int] = None, - ) -> None: - self._limit = num_batches if worker_num_batches is None else worker_num_batches - avg_num_batches = num_batches if num_batches is None else num_batches // self._num_processes - batch_remainder = 0 if num_batches is None else num_batches % self._num_processes - for idx, dp in enumerate(self.datapipes): - ext_batch = 1 if batch_remainder > idx else 0 - wnb = None if avg_num_batches is None or worker_num_batches is not None else avg_num_batches + ext_batch - dp.request_limit(num_batches, limit_fn, wnb) - for dp in self.datapipes: - dp.get_limit_response() - - def request_terminate(self): - self._terminated = True - for dp in self.datapipes: - dp.protocol.discard_existing_request() - for dp in self.datapipes: - dp.protocol.request_terminate() diff --git a/torchdata/dataloader2/communication/map.py b/torchdata/dataloader2/communication/map.py deleted file mode 100644 index 53f5bc13e..000000000 --- a/torchdata/dataloader2/communication/map.py +++ /dev/null @@ -1,204 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
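request_limit above splits a global batch limit across the worker loops: each of the N workers receives num_batches // N, and the first num_batches % N workers take one extra batch so the totals add up. A hedged restatement of just that arithmetic:

def split_limit(num_batches: int, num_workers: int):
    # Same distribution as `ext_batch = 1 if batch_remainder > idx else 0` above.
    avg, remainder = divmod(num_batches, num_workers)
    return [avg + (1 if idx < remainder else 0) for idx in range(num_workers)]

assert split_limit(10, 4) == [3, 3, 2, 2]
assert sum(split_limit(10, 4)) == 10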
- -import time -import types - -from torch.utils.data import MapDataPipe -from torchdata._utils import ExceptionWrapper -from torchdata.dataloader2 import communication -from torchdata.dataloader2.utils import process_reset_fn - -DEFAULT_NON_BLOCKING_SLEEP = 0.001 - -__all__ = [ - "DataPipeBehindQueues", - "EnsureNonBlockingMapDataPipe", - "NonBlockingMap", - "NotAvailable", - "QueueWrapperForMap", - "default_not_available_hook", -] - - -def default_not_available_hook(): - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - - -class NotAvailable(Exception): - pass - - -class NonBlockingMap(MapDataPipe): - not_available_hook = default_not_available_hook - - def __getitem__(self, index): - while True: - try: - return self.nonblocking_getitem(index) - except NotAvailable: - if NonBlockingMap.not_available_hook is not None: - NonBlockingMap.not_available_hook() - - def __len__(self): - try: - return self.nonblocking_len() - except NotAvailable: - if NonBlockingMap.not_available_hook is not None: - NonBlockingMap.not_available_hook() - - def nonblocking_len(self): - raise NotImplementedError("nonblocking_len is not implemented for %s" % self.__class__) - - def nonblocking_getitem(self, index): - raise NotImplementedError("nonblocking_getitem is not implemented for %s" % self.__class__) - - @staticmethod - def register_not_available_hook(hook_function): - NonBlockingMap.not_available_hook = hook_function - - -def EnsureNonBlockingMapDataPipe(validated_datapipe): - if not isinstance(validated_datapipe, MapDataPipe): - raise Exception(f"Not Map DataPipe - got {validated_datapipe.__class__}") - if isinstance(validated_datapipe, NonBlockingMap): - return validated_datapipe - if not hasattr(validated_datapipe, "nonblocking_len"): - - def nonblocking_len(self): - return self.__len__() - - validated_datapipe.nonblocking_len = types.MethodType( # type: ignore[attr-defined] - nonblocking_len, validated_datapipe - ) - if not hasattr(validated_datapipe, "nonblocking_getitem"): - - def nonblocking_getitem(self, index): - return self.__getitem__(index) - - validated_datapipe.nonblocking_getitem = types.MethodType( # type: ignore[attr-defined] - nonblocking_getitem, validated_datapipe - ) - return validated_datapipe - - -def DataPipeBehindQueues( - source_datapipe, - protocol, - process_name, - loop_id, - worker_info, - custom_reset_fn, - blocking_request_get=False, - request_counter=None, -): - """ - Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue. 
- - Args: - source_datapipe: DataPipe - protocol: ``MapDataPipeQueueProtocolServer`` that contains ``req_queue`` and ``res_queue`` - process_name: Process name - loop_id: Loop ID - worker_info: Worker info include worker id and number of workers - custom_reset_fn: function to call after each request is received - blocking_request_get: determines if ``protocol.get_new_request`` will block - """ - if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer): - raise Exception("Expecting MapDataPipeQueueProtocolServer, got", protocol) - source_datapipe = EnsureNonBlockingMapDataPipe(source_datapipe) - forever = True - while forever: - try: - # TODO: non-blocking call is extremely slow here for python.mp, need to figure out a good workaround - request = protocol.get_new_request(block=blocking_request_get) - except communication.protocol.EmptyQueue: - yield True - continue - - if isinstance(request, communication.messages.ResetEpochRequest): - distributed_shared_seed = request_counter is not None - source_datapipe = process_reset_fn( - source_datapipe, - worker_info, - request.seed_generator, - distributed_shared_seed, - request.iter_reset_fn, - custom_reset_fn, - ) - protocol.response_reset_epoch() - - elif isinstance(request, communication.messages.TerminateRequest): - forever = False - protocol.response_terminate() - - elif isinstance(request, communication.messages.LenRequest): - size = source_datapipe.nonblocking_len() - protocol.response_len(size) - - elif isinstance(request, communication.messages.GetItemRequest): - while forever: - try: - value = source_datapipe.nonblocking_getitem(request.key) - except NotAvailable: - yield True - continue - except IndexError: - # Alternatively, we can just allow the underlying DataPipe to throw an exception? 
- protocol.response_index_out_of_bound() - yield True - break - except Exception: - exc = ExceptionWrapper(where=f"in {process_name} {loop_id}") - protocol.response_worker_exception(exc) - break - protocol.response_item(request.key, value) - yield True # Returns control - break - else: - raise Exception("Unrecognized type of request received", request) - - -class QueueWrapperForMap(NonBlockingMap): - """ - Creates map.DataPipe which reads data from the DataLoader.Queue - """ - - def __init__(self, protocol, response_wait_time=0.00001): - if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolClient): - raise Exception("Got", protocol) - self.protocol = protocol - self.counter = 0 - self._stop_iteration = False - self._response_wait_time = response_wait_time - - def nonblocking_getitem(self, index): - if self._stop_iteration: - raise Exception("`getitem` or `nonblocking_getitem` called after receiving StopIteration") - if self.protocol.can_take_request(): - self.protocol.request_item(index) - try: - response = self.protocol.get_response_item(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - if isinstance(response, communication.messages.StopIterationResponse): - self._stop_iteration = True - raise IndexError(f"Index {index} is out of bound.") - if isinstance(response, communication.messages.WorkerExceptionResponse): - self._stop_iteration = True - response.exc.reraise() - return response.key, response.value - - def nonblocking_len(self): - if self._stop_iteration: - raise Exception("`len` or `nonblocking_len` called after receiving StopIteration") - if self.protocol.can_take_request(): - self.protocol.request_len() - try: - response = self.protocol.get_response_len(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - return response.len diff --git a/torchdata/dataloader2/communication/messages.py b/torchdata/dataloader2/communication/messages.py deleted file mode 100644 index 86b9ee387..000000000 --- a/torchdata/dataloader2/communication/messages.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
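QueueWrapperForMap above follows the same client-side convention as its iterable counterpart: issue one request, then poll the response queue with a very short timeout and convert an empty queue into a "not available yet" signal instead of blocking indefinitely. A hedged sketch of that polling step using a standard-library queue:

import queue

class NotAvailable(Exception):
    """Raised when the worker has not produced a response yet."""

def poll_response(response_queue: "queue.Queue", wait_time: float = 0.00001):
    try:
        return response_queue.get(block=True, timeout=wait_time)
    except queue.Empty:
        raise NotAvailable

responses = queue.Queue()
try:
    poll_response(responses)
except NotAvailable:
    pass                      # caller may sleep via a not_available_hook and retry
responses.put(("key", "value"))
assert poll_response(responses) == ("key", "value")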
- -from torchdata._utils import ExceptionWrapper - - -class DataLoaderQueueMessage: - pass - - -class Request(DataLoaderQueueMessage): - pass - - -class Response(DataLoaderQueueMessage): - pass - - -class ResetEpochRequest(Request): - __slots__ = ("seed_generator", "iter_reset_fn") - - def __init__(self, seed_generator, iter_reset_fn): - self.seed_generator = seed_generator - self.iter_reset_fn = iter_reset_fn - - -class ResetEpochResponse(Response): - pass - - -class LimitRequest(Request): - __slots__ = ("num_batches", "limit_fn", "worker_num_batches") - - def __init__(self, num_batches, limit_fn, worker_num_batches=None): - self.num_batches = num_batches - self.limit_fn = limit_fn - self.worker_num_batches = worker_num_batches - - -class LimitResponse(Response): - pass - - -class PauseRequest(Request): - __slots__ = "pause_fn" - - def __init__(self, pause_fn): - self.pause_fn = pause_fn - - -class PauseResponse(Response): - pass - - -class ResumeRequest(Request): - __slots__ = "resume_fn" - - def __init__(self, resume_fn): - self.resume_fn = resume_fn - - -class ResumeResponse(Response): - pass - - -class TerminateRequest(Request): - pass - - -class TerminateResponse(Response): - pass - - -class LenRequest(Request): - pass - - -class LenResponse(Response): - __slots__ = "len" - - def __init__(self, len): - self.len = len - - -class GetItemRequest(Request): - __slots__ = "key" - - def __init__(self, key): - self.key = key - - -class GetItemResponse(Response): - __slots__ = ("key", "value") - - def __init__(self, key, value): - self.key = key - self.value = value - - -class GetNextRequest(Request): - pass - - -class GetNextResponse(Response): - __slots__ = "value" - - def __init__(self, value): - self.value = value - - -class StopIterationResponse(Response): - pass - - -class InvalidStateResponse(Response): - """ - Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset' - """ - - pass - - -class WorkerExceptionResponse(Response): - __slots__ = "exc" - - def __init__(self, exc: ExceptionWrapper): - self.exc: ExceptionWrapper = exc diff --git a/torchdata/dataloader2/communication/protocol.py b/torchdata/dataloader2/communication/protocol.py deleted file mode 100644 index bf7aa6570..000000000 --- a/torchdata/dataloader2/communication/protocol.py +++ /dev/null @@ -1,323 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from queue import Empty as EmptyException - -from torchdata.dataloader2 import communication - - -class Protocol: - __slots__ = ("request_queue", "response_queue") - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - - -class ProtocolClient(Protocol): - """ - ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue. 
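A small aside on the message classes above: __slots__ is given either as a single string (one slot) or as a tuple of names. In the standalone form below this also keeps each queue message free of a per-instance __dict__; the classes are illustrative copies, not the originals.

class GetItemRequest:
    __slots__ = "key"             # a bare string declares exactly one slot

    def __init__(self, key):
        self.key = key

class GetItemResponse:
    __slots__ = ("key", "value")  # a tuple declares several slots

    def __init__(self, key, value):
        self.key = key
        self.value = value

message = GetItemRequest(3)
assert not hasattr(message, "__dict__")  # holds here because object is the only base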
- """ - - _req_sent = None - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - self._req_sent = None - - def can_take_request(self): - return self._req_sent is None - - def waiting_for_response(self): - return self._req_sent is not None - - def request_sent(self, request=True): - if not self.can_take_request(): - raise Exception("Protocol only supports one request in the Queue") - self._req_sent = request - - def request_served(self, result=None): - if not self.waiting_for_response(): - raise Exception("Expected no pending requests, but something got served", result) - self._req_sent = None - - def discard_existing_request(self): - if self.waiting_for_response(): - response = self.response_queue.get(block=True) - self.request_served(response) - - def request_limit(self, num_batches, limit_fn=None, worker_num_batches=None): - if not self.can_take_request(): - raise Exception("Can not `limit` while we are still waiting response for previous request") - request = communication.messages.LimitRequest(num_batches, limit_fn, worker_num_batches) - self.request_queue.put(request) - self.request_sent(request) - - def request_pause(self, pause_fn=None): - if not self.can_take_request(): - raise Exception("Can not `pause` while we are still waiting response for previous request") - request = communication.messages.PauseRequest(pause_fn) - self.request_queue.put(request) - self.request_sent(request) - - def request_resume(self, resume_fn=None): - if not self.can_take_request(): - raise Exception("Can not `resume` while we are still waiting response for previous request") - request = communication.messages.ResumeRequest(resume_fn) - self.request_queue.put(request) - self.request_sent(request) - - def request_terminate(self): - r""" - Drop the existing request and send TerminateRequest directly - """ - if not self.can_take_request(): - self._req_sent = None - request = communication.messages.TerminateRequest() - self.request_queue.put(request) - self.request_sent(request) - - -class ProtocolServer(Protocol): - """ - ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe. - """ - - # TODO(966): Update the exceptions raised in this class to be more specific - - _req_received = None - _paused = False # When `True`, prevents `GetNext` in `DataPipeBehindQueues`. 
- - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - self._req_received = None - self._paused = False - - def is_paused(self): - return self._paused - - def have_pending_request(self): - return self._req_received is not None - - def get_new_request(self, block=False): - if self.have_pending_request(): - raise Exception("Trying to get next request, while having one un-served") - try: - response = self.request_queue.get(block=block) - except EmptyException: - raise EmptyQueue("queue is empty") - self._req_received = response - return response - # TODO(626): Validate supported requests - - def response_terminate(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.TerminateRequest): - raise Exception("Replaying with `terminate` status to other type of message") - self.response_queue.put(communication.messages.TerminateResponse()) - self._req_received = None - - def response_reset_epoch(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.ResetEpochRequest): - raise Exception("Replaying with `reset_epoch` status to other type of message") - self.response_queue.put(communication.messages.ResetEpochResponse()) - self._req_received = None - - def response_limit(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.LimitRequest): - raise Exception("Replaying with `limit` status to other type of message") - self.response_queue.put(communication.messages.LimitResponse()) - self._req_received = None - - def response_pause(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.PauseRequest): - raise Exception("Replaying with `pause` status to other type of message") - self._paused = True - self.response_queue.put(communication.messages.PauseResponse()) - self._req_received = None - - def response_resume(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.ResumeRequest): - raise Exception("Replaying with `resume` status to other type of message") - self._paused = False - self.response_queue.put(communication.messages.ResumeResponse()) - self._req_received = None - - def response_worker_exception(self, exception): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.WorkerExceptionResponse(exception)) - self._req_received = None - - -class MapDataPipeQueueProtocolServer(ProtocolServer): - def response_item(self, key, value): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.GetItemResponse(key, value)) - self._req_received = None - - def response_len(self, size): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.LenResponse(size)) - self._req_received = None - - def response_index_out_of_bound(self): - if not self.have_pending_request(): - raise 
Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - - -class MapDataPipeQueueProtocolClient(ProtocolClient): - def request_len(self): - if not self.can_take_request(): - raise Exception("Can not request len while we are still waiting response for previous request") - request = communication.messages.LenRequest() - self.request_queue.put(request) - self.request_sent(request) - - def request_reset_epoch(self, seed_generator, iter_reset_fn): - if not self.can_take_request(): - raise Exception("Can not reset while we are still waiting response for previous request") - request = communication.messages.ResetEpochRequest(seed_generator, iter_reset_fn) - self.request_queue.put(request) - self.request_sent(request) - - def request_item(self, index): - if not self.can_take_request(): - raise Exception("Can not request item while we are still waiting response for previous request") - request = communication.messages.GetItemRequest(index) - self.request_queue.put(request) - self.request_sent(request) - - def get_response_len(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception("Can not expect any response without submitted request") - try: - response = self.response_queue.get(block=block, timeout=timeout) - except TimeoutError: - raise EmptyQueue("queue is empty") - self.request_served(response) - if not isinstance(response, communication.messages.LenResponse): - raise Exception("Invalid response received") - return response - - def get_response_item(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception("Can not expect any response without submitted request") - try: - response = self.response_queue.get(block=block, timeout=timeout) - except TimeoutError: - raise EmptyQueue("queue is empty") - self.request_served(response) - # if not isinstance(response, communication.messages.GetItemResponse): - # raise Exception('Invalid response received') - return response - - -class EmptyQueue(Exception): - pass - - -class IterDataPipeQueueProtocolServer(ProtocolServer): - def response_next(self, value): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.GetNextResponse(value)) - self._req_received = None - - def response_stop_iteration(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - - def response_invalid_state(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.InvalidStateResponse()) - self._req_received = None - - -class IterDataPipeQueueProtocolClient(ProtocolClient): - def request_reset_epoch(self, seed_generator, iter_reset_fn): - if not self.can_take_request(): - raise Exception("Can not reset while we are still waiting response for previous request") - request = communication.messages.ResetEpochRequest(seed_generator, iter_reset_fn) - self.request_queue.put(request) - self.request_sent(request) - - def request_next(self): - if not self.can_take_request(): - raise Exception("Can not request next item while we are still waiting response for previous request") - request = communication.messages.GetNextRequest() - self.request_queue.put(request) - self.request_sent(request) - 
- def get_response_reset_epoch(self, block=False): - try: - response = self.response_queue.get(block=block) - except EmptyException: - raise EmptyQueue("queue is empty") - self.request_served(response) - - if not isinstance(response, communication.messages.ResetEpochResponse): - raise Exception("Invalid response received") - - def get_response_limit(self, block=False): - try: - response = self.response_queue.get(block=block) - except EmptyException: - raise EmptyQueue("queue is empty") - self.request_served(response) - - if not isinstance(response, communication.messages.LimitResponse): - raise Exception("Invalid response received when expecting `LimitResponse`") - - def get_response_pause(self, block=False): - try: - response = self.response_queue.get(block=block) - except EmptyException: - raise EmptyQueue("queue is empty") - self.request_served(response) - - if not isinstance(response, communication.messages.PauseResponse): - raise Exception("Invalid response received when expecting `PauseResponse`") - - def get_response_resume(self, block=False): - try: - response = self.response_queue.get(block=block) - except EmptyException: - raise EmptyQueue("queue is empty") - self.request_served(response) - - if not isinstance(response, communication.messages.ResumeResponse): - raise Exception("Invalid response received when expecting `ResumeResponse`") - - def get_response_next(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception("Can not expect any response without submitted request") - try: - response = self.response_queue.get(block=block, timeout=timeout) - except EmptyException: - raise EmptyQueue("queue is empty") - self.request_served(response) - - # TODO(629): Add possible response types validation here - return response diff --git a/torchdata/dataloader2/communication/queue.py b/torchdata/dataloader2/communication/queue.py deleted file mode 100644 index 64b51dd92..000000000 --- a/torchdata/dataloader2/communication/queue.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import threading -import time - - -class LocalQueue: - ops = 0 - stored = 0 - uid = 0 - empty = 0 - - def __init__(self, name="unnamed"): - self.items = [] - self.name = name - self.uid = LocalQueue.uid - LocalQueue.uid += 1 - - def put(self, item, block=True): - LocalQueue.ops += 1 - LocalQueue.stored += 1 - self.items.append(item) - - def get(self, block=True, timeout=0): - # TODO(622): Add support of block and timeout arguments - LocalQueue.ops += 1 - if not len(self.items): - LocalQueue.empty += 1 - raise Exception("LocalQueue is empty") - LocalQueue.stored -= 1 - return self.items.pop() - - -class ThreadingQueue: - def __init__(self, name="unnamed"): - self.lock = threading.Lock() - self.items = [] - self.name = name - - def put(self, item, block=True): - with self.lock: - self.items.append(item) - - def get(self, block=True, timeout=0): - # TODO(623): Add support of block and timeout arguments - while True: - with self.lock: - if len(self.items) > 0: - return self.items.pop() - if not block: - raise Exception("Not available") - # TODO(624): Figure out what to do if nothing in the queue - time.sleep(0.000001) diff --git a/torchdata/dataloader2/dataloader2.py b/torchdata/dataloader2/dataloader2.py deleted file mode 100644 index a8767a093..000000000 --- a/torchdata/dataloader2/dataloader2.py +++ /dev/null @@ -1,421 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import pickle -import warnings - -from typing import Any, Dict, Generic, Iterable, Iterator, Optional, TypeVar, Union - -from torchdata.dataloader2.adapter import Adapter -from torchdata.dataloader2.error import PauseIteration -from torchdata.dataloader2.graph._serialization import ( - clone, - DataPipe, - deserialize_datapipe, - MapDataPipe, - serialize_datapipe, -) -from torchdata.dataloader2.random import SeedGenerator -from torchdata.dataloader2.random.seed_generator import _UINT64_UPPER_BOUND -from torchdata.dataloader2.reading_service import CheckpointableReadingServiceInterface, ReadingServiceInterface - -T_co = TypeVar("T_co", covariant=True) -SERIALIZED_DATAPIPE_KEY_NAME = "serialized_datapipe" -READING_SERVICE_STATE_KEY_NAME = "reading_service_state" -RANDOMNESS_STATE_KEY_NAME = "randomness_state" - - -class DataLoader2Iterator(Iterator[T_co]): - r""" - An iterator wrapper returned by ``DataLoader2``'s ``__iter__` method. It delegates method/attribute calls - to the DataPipe iterator object. - - The purpose of this wrapper object is to track the validity of an iterator to enforce the single iterator per - ``DataLoader2`` constraint, and to finalize iteration/shutdown when necessary. - """ - - def __init__(self, dataloader: "DataLoader2", iterator_id: int): - self.dataloader = dataloader - self.iterator_id = iterator_id - self.limit_counter: Optional[int] = None - self.limit_threshold: Optional[int] = None - - def __next__(self) -> T_co: - if self.iterator_id == self.dataloader.valid_iterator_id: - self.dataloader._reset_iter = True - try: - if self.dataloader._is_paused: - raise PauseIteration("DataLoader2 has been paused. 
`resume` must be called before continuing.") - else: - next_val = next(self.dataloader._datapipe_iter) # type: ignore[arg-type] - if self.limit_threshold is not None: - self.limit_counter = self.limit_counter + 1 # type: ignore[operator] - return next_val - except PauseIteration: # This can be used for raising `StopIteration` without `finalize_iteration` - raise StopIteration - except StopIteration: - if self.dataloader.reading_service is not None: - self.dataloader.reading_service.finalize_iteration() - raise - except Exception: - if self.dataloader: - self.dataloader.shutdown() - raise - finally: - # Call `pause` if threshold is reached - if ( - not self.dataloader._is_paused - and self.limit_threshold is not None - and self.limit_counter >= self.limit_threshold # type: ignore[operator] - ): - self._pause() - else: # `iterator_id` is not valid - if self.dataloader.reading_service is not None: - self.dataloader.reading_service.finalize_iteration() - raise RuntimeError( - "This iterator has been invalidated because another iterator has been created " - "from the same DataLoader2.\n" - "This may be caused multiple references to the same DataLoader2. " - "For feedback regarding this single iterator per DataLoader2 constraint, feel free " - "to comment on this issue: https://github.com/pytorch/data/issues/45." - ) - - def _pause(self) -> None: - r""" - Pauses ``DataLoader2`` by halting its threads and ensure that its state remains unchanged, - allowing ``DataLoader2`` to safely perform snapshotting and similar operations afterwards. - - The ``limit_counter`` is also reset to ``0``. - """ - self.dataloader._pause() - self.limit_counter = 0 - - def resume(self) -> None: - r""" - Restarts the threads within ``DataLoader2`` and allows it to yield additional batches. - """ - self.dataloader._resume() - - def limit(self, num_batches: Optional[int]) -> None: - """ - Pauses ``DataLoader2`` from yielding additional batches after ``num_batches`` has been yielded. The count - begins after this method is invoked (i.e. previously yielded batches do not count towards the threshold). - - While paused, ``DataLoader2``'s threads are halted and its state remains unchanged, - allowing ``DataLoader2`` to safely perform snapshotting and similar operations. - After ``DataLoader2`` is paused, ``resume()`` must be called before it can start yielding again. - - Note: - - ``limit_threshold`` persists after ``pause`` and ``resume``. Use ``.limit(None)`` to remove it. - - If dispatching process is present, in order to make sure limit is in sync across processes, - please place 1-to-N ``DataPipes`` in the dispatching process (before ``sharding_round_robin_dispatch``) - - Args: - num_batches: Number of batches after which the DataLoader2 will pause, use ``None`` to remove the limit - """ - self.limit_counter = 0 - self.limit_threshold = num_batches - self.dataloader._limit(num_batches) - - def __getattr__(self, name): - """ - To delegate operations to ``dataloader._datapipe_iter``. 
- """ - if "dataloader" not in self.__dict__ or self.dataloader._datapipe_iter is None: - raise AttributeError - return getattr(self.dataloader._datapipe_iter, name) - - -class DataLoader2(Generic[T_co]): - r""" - ``DataLoader2`` is used to optimize and execute the given ``DataPipe`` graph - based on ``ReadingService`` and ``Adapter`` functions, with support for - - - Dynamic sharding for multiprocess and distributed data loading - - Multiple backend ``ReadingServices`` - - ``DataPipe`` graph in-place modification like shuffle control, memory pinning, etc. - - Snapshot the state of data-preprocessing pipeline (WIP) - - Args: - datapipe (``IterDataPipe`` or ``MapDataPipe``): ``DataPipe`` from which to load the data. A deepcopy of this - datapipe will be made during initialization, allowing the input to be re-used in a different ``DataLoader2`` - without sharing states. Input ``None`` can only be used if ``load_state_dict`` is called - right after the creation of the DataLoader. - datapipe_adapter_fn (``Iterable[Adapter]`` or ``Adapter``, optional): ``Adapter`` function(s) that - will be applied to the DataPipe (default: ``None``). - reading_service (ReadingServiceInterface, optional): defines how ``DataLoader2`` should execute operations over - the ``DataPipe``, e.g. multiprocessing/distributed (default: ``None``). A deepcopy of this will be - created during initialization, allowing the ReadingService to be re-used in a different - ``DataLoader2`` without sharing states. - - Note: - When a ``MapDataPipe`` is passed into ``DataLoader2``, in order to iterate through - the data, ``DataLoader2`` will attempt to create an iterator via ``iter(datapipe)``. - If the object has a non-zero-indexed indices, this may fail. - Consider using ``.shuffle()`` (which converts ``MapDataPipe`` to ``IterDataPipe``) - or ``datapipe.to_iter_datapipe(custom_indices)``. 
- """ - - def __init__( - self, - datapipe: Optional[DataPipe], - datapipe_adapter_fn: Optional[Union[Iterable[Adapter], Adapter]] = None, - reading_service: Optional[ReadingServiceInterface] = None, - ) -> None: - if isinstance(datapipe, MapDataPipe): - datapipe = datapipe.to_iter_datapipe() - self.datapipe = clone(datapipe) if datapipe is not None else None - self._adapted: bool = False - self._datapipe_iter: Optional[Iterator[T_co]] = None - self._reset_iter: bool = True # Sets to `False` when `__iter__` runs, and `True` when `__next__` is called - # TODO(630): Some ReadingServices might want to validate adapters, we can add this feature - if datapipe_adapter_fn is None: - self.datapipe_adapter_fns = None - elif isinstance(datapipe_adapter_fn, Iterable): - self.datapipe_adapter_fns = datapipe_adapter_fn - else: - self.datapipe_adapter_fns = [datapipe_adapter_fn] - self.reading_service = clone(reading_service) - self.reading_service_state: Optional[bytes] = None # is not `None` when `load_state_dict` is called - self._terminated: bool = False - self.valid_iterator_id: Optional[int] = None - self._is_paused = False - - if self.datapipe is not None and self.datapipe_adapter_fns is not None: - for adapter_fn in self.datapipe_adapter_fns: - self.datapipe = adapter_fn(self.datapipe) - self._datapipe_before_reading_service_adapt: DataPipe = clone(self.datapipe) - self._seed_generator: SeedGenerator = SeedGenerator() - self._seed: Optional[int] = None - self._reset_seed: bool = True - # Seed generator as of beginning of each epoch - self._initial_seed_generator: SeedGenerator = clone(self._seed_generator) - self._state_dict: Optional[Dict[str, Any]] = None - - def __iter__(self) -> DataLoader2Iterator[T_co]: - r""" - Return a singleton iterator from the ``DataPipe`` graph adapted by ``ReadingService``. - ``DataPipe`` will be restored if the serialized state is provided to construct - ``DataLoader2``. And, ``initialize_iteration`` and ``finalize_iterator`` will be - invoked at the beginning and end of the iteration correspondingly. 
- """ - if self.datapipe is None: - raise RuntimeError("Please provide datapipe or use load_state_dict to load datapipe from state") - - if self._terminated: - raise RuntimeError("Cannot iterate over the DataLoader as it has already been shut down") - - if self._reset_iter: - if self._seed is not None: - if self._reset_seed: - self._seed_generator.seed(self._seed) - self._reset_seed = False - else: - self._seed_generator.seed() - - # Saving initial seed generator state - self._initial_seed_generator = clone(self._seed_generator) - - if not self._adapted and self.reading_service is not None: - if self.reading_service_state is None: - self.datapipe = self.reading_service.initialize(self.datapipe) - else: - if not isinstance(self.reading_service, CheckpointableReadingServiceInterface): - raise TypeError("Cannot restore from non-checkpointable reading service") - self.datapipe = self.reading_service.restore(self.datapipe, self.reading_service_state) - self._adapted = True - - if self.reading_service is not None: - iter_reset_fn = self.reading_service.initialize_iteration(self._seed_generator) - if iter_reset_fn: - self.datapipe = iter_reset_fn(self.datapipe) - - self._datapipe_iter = iter(self.datapipe) - self._reset_iter = False - - self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1 - return DataLoader2Iterator(self, self.valid_iterator_id) - - def seed(self, seed: int) -> None: - r""" - Set random seed for DataLoader2 to control determinism. - - Args: - seed: Random uint64 seed - """ - if seed >= _UINT64_UPPER_BOUND: - raise ValueError(f"Expected an uint64 seed, but got {seed}.") - self._seed = seed - self._reset_seed = True - - def __del__(self) -> None: - self.shutdown() - - def shutdown(self) -> None: - r""" - Shuts down ``ReadingService`` and clean up iterator. - """ - try: - if not self._terminated: - self._terminated = True - if self.reading_service is not None: - self.reading_service.finalize_iteration() - self.reading_service.finalize() - if not self._reset_iter: - self._reset_iter = True - self._datapipe_iter = None - # Ignore AttributeError in case any attribute has been removed before `__del__` - except AttributeError: - pass - - def __enter__(self) -> "DataLoader2[T_co]": - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.shutdown() - - def state_dict(self) -> Dict[str, Any]: - r""" - Return a dictionary to represent the state of data-processing pipeline with keys: - - - ``serialized_datapipe``:Serialized ``DataPipe`` before ``ReadingService`` adaption. - - ``reading_service_state``: The state of ``ReadingService`` and adapted ``DataPipe``. 
- """ - - # If state_dict is called right after load_state_dict calls, without iterator created in the middle, - # we should directly return the original state dict without triggering reading_service.checkpoint - # because the states are unchanged - if self.valid_iterator_id is None and self._state_dict is not None: - return self._state_dict - - reading_service_state = None - if self.reading_service is not None and isinstance(self.reading_service, CheckpointableReadingServiceInterface): - reading_service_state = self.reading_service.checkpoint() - - # Serialize datapipe after applying adapters and before reading service adaption - serialized_datapipe = serialize_datapipe(self._datapipe_before_reading_service_adapt) - serialized_randomness_state = ( - self._seed, - self._reset_seed, - pickle.dumps(self._seed_generator), - pickle.dumps(self._initial_seed_generator), - ) - - return { - SERIALIZED_DATAPIPE_KEY_NAME: serialized_datapipe, - READING_SERVICE_STATE_KEY_NAME: reading_service_state, - RANDOMNESS_STATE_KEY_NAME: serialized_randomness_state, - } - - @classmethod - def from_state( - cls, - state: Dict[str, Any], - reading_service: CheckpointableReadingServiceInterface, - ) -> "DataLoader2[T_co]": - """ - Create new ``DataLoader2`` with ``DataPipe`` graph and ``ReadingService`` restored - from the serialized state. - """ - serialized_datapipe = state[SERIALIZED_DATAPIPE_KEY_NAME] - reading_service_state = state[READING_SERVICE_STATE_KEY_NAME] - - data_loader: "DataLoader2[T_co]" = DataLoader2( - datapipe=deserialize_datapipe(serialized_datapipe), - datapipe_adapter_fn=None, - reading_service=reading_service, - ) - data_loader.reading_service_state = reading_service_state - - # This check is needed for backward compatibility of `state_dict` for users loading from older version - if RANDOMNESS_STATE_KEY_NAME in state: - randomness_state = state[RANDOMNESS_STATE_KEY_NAME] - data_loader._seed, data_loader._reset_seed = randomness_state[0], randomness_state[1] - data_loader._seed_generator = pickle.loads(randomness_state[2]) - data_loader._initial_seed_generator = pickle.loads(randomness_state[3]) - - data_loader._state_dict = state - - return data_loader - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - """ - For the existing ``DataLoader2``, load serialized state to restore ``DataPipe`` graph - and reset the internal state of ``ReadingService``. - """ - # edge case checking - # iterator has already been created: 1) iterator is just created 2) iterator is created and iter is exhausted - if self._datapipe_iter is not None: - raise RuntimeError( - "DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. " - "Please create a new dataloader in order to use load state dict." 
- ) - - self._state_dict = state_dict - - serialized_datapipe = state_dict[SERIALIZED_DATAPIPE_KEY_NAME] - reading_service_state = state_dict[READING_SERVICE_STATE_KEY_NAME] - - # deserialize datapipe - deserialized_datapipe = deserialize_datapipe(serialized_datapipe) - assert deserialized_datapipe is not None - - # override existing datapipe and reading service state - self.datapipe = deserialized_datapipe - self.reading_service_state = reading_service_state - - # This check is needed for backward compatibility of `state_dict` for users loading from older version - if RANDOMNESS_STATE_KEY_NAME in state_dict: - randomness_state = state_dict[RANDOMNESS_STATE_KEY_NAME] - self._seed, self._reset_seed = randomness_state[0], randomness_state[1] - self._seed_generator = pickle.loads(randomness_state[2]) - self._initial_seed_generator = pickle.loads(randomness_state[3]) - - # re-initialize datapipe_adapter_fn and _datapipe_before_reading_service_adapt - if self.datapipe_adapter_fns is not None: - for adapter_fn in self.datapipe_adapter_fns: - self.datapipe = adapter_fn(self.datapipe) - self._datapipe_before_reading_service_adapt = clone(self.datapipe) - - def _restore_checkpoint_beginning_of_epoch(self) -> None: - r""" - At the beginning of each iteration (epoch), the initial state of randomness is automatically saved. - That state is also saved as part of ``state_dict``. This method restores the current DataLoader2 RNG state - to that initial state. - - The common use case is to invoke this method after ``DataLoader2``'s state is restored (through - ``.from_state(...)`` or ``load_state_dict(...)``) in order to resume from the beginning of the last-ran epoch. - """ - self._seed_generator = self._initial_seed_generator - - def _pause(self) -> None: - if hasattr(self.reading_service, "_pause"): - self._is_paused = True - pause_fn = self.reading_service._pause() - if pause_fn is not None: - self.datapipe = pause_fn(self.datapipe) - else: - warnings.warn("ReadingService doesn't support `pause`.") - - def _resume(self) -> None: - if hasattr(self.reading_service, "_resume"): - if not self._is_paused: - warnings.warn("Resume is called when `DataLoader2` is not paused. No operation is performed.") - else: - resume_fn = self.reading_service._resume() - if resume_fn is not None: - self.datapipe = resume_fn(self.datapipe) - self._is_paused = False - else: - warnings.warn("ReadingService doesn't support `resume`.") - - def _limit(self, num_batches: Optional[int]) -> None: - if hasattr(self.reading_service, "_limit"): - limit_fn = self.reading_service._limit(num_batches) - if limit_fn is not None: - self.datapipe = limit_fn(self.datapipe, num_batches) - else: - warnings.warn("ReadingService doesn't support `limit`.") diff --git a/torchdata/dataloader2/error.py b/torchdata/dataloader2/error.py deleted file mode 100644 index 4a7122b58..000000000 --- a/torchdata/dataloader2/error.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -class PauseIteration(StopIteration): - pass diff --git a/torchdata/dataloader2/graph/__init__.py b/torchdata/dataloader2/graph/__init__.py deleted file mode 100644 index dba249dc6..000000000 --- a/torchdata/dataloader2/graph/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
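dataloader2.py above ends with the checkpointing hooks (state_dict, from_state, load_state_dict). A sketch of the round trip they supported, assuming a pre-removal torchdata install; this is illustrative usage, not a tested recipe.

from torchdata.dataloader2 import DataLoader2
from torchdata.datapipes.iter import IterableWrapper

dl = DataLoader2(IterableWrapper(range(10)).shuffle())
state = dl.state_dict()   # serialized datapipe + reading-service state + randomness state
dl.shutdown()

# `datapipe=None` is only legal because load_state_dict is called right after construction.
restored = DataLoader2(None)
restored.load_state_dict(state)
print(list(restored))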
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps -from torchdata.dataloader2.graph.settings import set_datapipes_seed, set_graph_random_seed -from torchdata.dataloader2.graph.utils import find_dps, list_dps, remove_dp, replace_dp - - -__all__ = [ - "DataPipe", - "DataPipeGraph", - "find_dps", - "list_dps", - "remove_dp", - "replace_dp", - "set_datapipes_seed", - "set_graph_random_seed", - "traverse_dps", -] - - -assert __all__ == sorted(__all__) diff --git a/torchdata/dataloader2/graph/_serialization.py b/torchdata/dataloader2/graph/_serialization.py deleted file mode 100644 index e317531dc..000000000 --- a/torchdata/dataloader2/graph/_serialization.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import pickle - -from torch.utils.data.datapipes.datapipe import ( - _DataPipeSerializationWrapper, - _IterDataPipeSerializationWrapper, - _MapDataPipeSerializationWrapper, -) - -from torchdata.dataloader2.graph import DataPipe -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.map import MapDataPipe - -try: - import dill - - # XXX: By default, dill writes the Pickler dispatch table to inject its - # own logic there. This globally affects the behavior of the standard library - # pickler for any user who transitively depends on this module! - # Undo this extension to avoid altering the behavior of the pickler globally. - dill.extend(use_dill=False) - HAS_DILL = True -except ImportError: - HAS_DILL = False - -__all__ = [ - "attach_wrapper", - "clone", - "deserialize_datapipe", - "extract_wrapper", - "serialize_datapipe", -] - - -def serialize_datapipe(datapipe: DataPipe) -> bytes: - datapipe = attach_wrapper(datapipe) - try: - return pickle.dumps(datapipe) - except pickle.PickleError as e: - raise NotImplementedError(f"Prototype only support pickle-able datapipes for checkpoint: {e}") - - -def deserialize_datapipe(serialized_state: bytes) -> DataPipe: - try: - datapipe = pickle.loads(serialized_state) - except pickle.PickleError as e: - raise NotImplementedError(f"Prototype only support pickle-able datapipes for checkpoint: {e}") - return extract_wrapper(datapipe) - - -def attach_wrapper(datapipe: DataPipe) -> DataPipe: - r""" - Wraps the ``DataPipe`` with the corresponding serialization wrapper. - """ - wrapped_dp: DataPipe = datapipe - if not isinstance(datapipe, _DataPipeSerializationWrapper): - if isinstance(datapipe, IterDataPipe): - wrapped_dp = _IterDataPipeSerializationWrapper(datapipe) - elif isinstance(datapipe, MapDataPipe): - wrapped_dp = _MapDataPipeSerializationWrapper(datapipe) - return wrapped_dp - - -def extract_wrapper(datapipe: DataPipe) -> DataPipe: - r""" - Extracts the ``DataPipe`` from the serialization wrapper. - """ - if isinstance(datapipe, _DataPipeSerializationWrapper): - datapipe = datapipe._datapipe - return datapipe - - -def clone(obj): - r""" - Standardized way to copy an object when needed, such as for DataPipe/ReadingService. - This uses `pickle` to serialize/deserialize to create the copy. 
- """ - use_dill = False - try: - states = pickle.dumps(obj) - except Exception: - if HAS_DILL: - states = dill.dumps(obj) - use_dill = True - else: - raise - if use_dill: - return dill.loads(states) - else: - return pickle.loads(states) diff --git a/torchdata/dataloader2/graph/settings.py b/torchdata/dataloader2/graph/settings.py deleted file mode 100644 index 965fbb73d..000000000 --- a/torchdata/dataloader2/graph/settings.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import inspect - -from typing import List - -from torchdata.dataloader2.graph.utils import DataPipe, find_dps, list_dps, traverse_dps -from torchdata.dataloader2.random import SeedGenerator -from torchdata.datapipes.iter import ShardingFilter - - -def _is_random_datapipe(datapipe: DataPipe) -> bool: - if hasattr(datapipe, "set_seed") and inspect.ismethod(datapipe.set_seed): - return True - return False - - -def set_datapipes_seed(datapipes: List[DataPipe], seed_generator: SeedGenerator, distributed_shared: bool) -> None: - for dp in datapipes: - if _is_random_datapipe(dp): - if distributed_shared: - dp.set_seed(seed_generator.generate_shared_seed()) - else: - dp.set_seed(seed_generator.generate_seed()) - - -def set_graph_random_seed(datapipe: DataPipe, seed_generator: SeedGenerator) -> DataPipe: - r""" - Set seeds to the graph of ``DataPipes`` based on a Seed Generator. All random ``DataPipes`` prior to - ``ShardingFilter`` will be set seeds by the same Seed Generator to preserve the same random state - across distributed/non-distributed workers. And, the random ``DataPipes`` after ``ShardingFilter`` - will be set seeds by the worker-local Seed Generator deterministically created based on ``worker_id``. - - Args: - datapipe: - seed_generator: - """ - graph = traverse_dps(datapipe) - sharding_filter_dps = find_dps(graph, ShardingFilter) - - # Set the same seed before sharding_filter - # Using cache to exclude potential duplciate DataPipe - cache = set() - dps_before_sharding = [] - for sf_dp in sharding_filter_dps: - dps = list_dps(traverse_dps(sf_dp)) - for dp in dps: - if id(dp) not in cache: - cache.add(id(dp)) - dps_before_sharding.append(dp) - - set_datapipes_seed(dps_before_sharding, seed_generator, distributed_shared=True) - - # Set different seeds after sharding_filter - dps_after_sharding = list_dps(graph, exclude_dps=sharding_filter_dps) - set_datapipes_seed(dps_after_sharding, seed_generator, distributed_shared=False) - - return datapipe diff --git a/torchdata/dataloader2/graph/utils.py b/torchdata/dataloader2/graph/utils.py deleted file mode 100644 index 6ac03dc6a..000000000 --- a/torchdata/dataloader2/graph/utils.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- - -from collections import deque -from typing import Deque, Dict, List, Optional, Set, Type, Union - -from torchdata.dataloader2.graph import DataPipe, DataPipeGraph, traverse_dps -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.map import MapDataPipe - - -def find_dps(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> List[DataPipe]: - r""" - Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe - instances with the provided DataPipe type. - """ - dps: List[DataPipe] = [] - cache: Set[int] = set() - - def helper(g) -> None: # pyre-ignore - for dp_id, (dp, src_graph) in g.items(): - if dp_id in cache: - continue - cache.add(dp_id) - if type(dp) is dp_type: # Please not use `isinstance`, there is a bug. - dps.append(dp) - helper(src_graph) - - helper(graph) - - return dps - - -def list_dps(graph: DataPipeGraph, exclude_dps: Optional[Union[DataPipe, List[DataPipe]]] = None) -> List[DataPipe]: - r""" - Given the graph of DataPipe generated by ``traverse_dps`` function, return a list - of all DataPipe instances without duplication. If ``exclude_dps`` is provided, - the provided ``DataPipes`` and their predecessors will be ignored. - - Note: - - The returned list is in the order of breadth first search of the graph - """ - dps: List[DataPipe] = [] - cache: Set[int] = set() - - if exclude_dps is not None: - if isinstance(exclude_dps, (IterDataPipe, MapDataPipe)): - exclude_dps = [ - exclude_dps, - ] - for exclude_dp in exclude_dps: # type: ignore[union-attr] - assert isinstance(exclude_dp, (IterDataPipe, MapDataPipe)) - # Skip DataPipe that has already been excluded - if id(exclude_dp) in cache: - continue - for dp in list_dps(traverse_dps(exclude_dp)): # type: ignore[arg-type] - cache.add(id(dp)) - - q: Deque = deque() - # Initialization - for dp_id, (dp, subgraph) in graph.items(): - if dp_id not in cache: - q.append((dp_id, dp, subgraph)) - cache.add(dp_id) - - while len(q) > 0: - dp_id, dp, subgraph = q.popleft() - dps.append(dp) - for parent_dp_id, (parent_dp, parent_subgraph) in subgraph.items(): - if parent_dp_id not in cache: - q.append((parent_dp_id, parent_dp, parent_subgraph)) - cache.add(parent_dp_id) - - return dps - - -# Given the DataPipe needs to be replaced and the expected DataPipe, return a new graph -def replace_dp(graph: DataPipeGraph, old_datapipe: DataPipe, new_datapipe: DataPipe) -> DataPipeGraph: - r""" - Given the graph of DataPipe generated by ``traverse_dps`` function and the DataPipe to be replaced and - the new DataPipe, return the new graph of DataPipe. - """ - assert len(graph) == 1 - - if id(old_datapipe) in graph: - graph = traverse_dps(new_datapipe) - - final_datapipe = list(graph.values())[0][0] - - for recv_dp, send_graph in graph.values(): - _replace_dp(recv_dp, send_graph, old_datapipe, new_datapipe) - - return traverse_dps(final_datapipe) - - -def remove_dp(graph: DataPipeGraph, datapipe: DataPipe) -> DataPipeGraph: - r""" - Given the graph of DataPipe generated by ``traverse_dps`` function and the DataPipe to be removed, - return the new graph of DataPipe. - - Note: - - This function can not remove DataPipe that takes multiple DataPipes as the input. 
- """ - assert len(graph) == 1 - - dp_graph = traverse_dps(datapipe) - dp_id = id(datapipe) - if len(dp_graph[dp_id][1]) == 0: - raise RuntimeError("Cannot remove the source DataPipe from the graph of DataPipe") - if len(dp_graph[dp_id][1]) > 1: - raise RuntimeError("Cannot remove the receiving DataPipe having multiple sending DataPipes") - - if dp_id in graph: - graph = graph[dp_id][1] - - for recv_dp, send_graph in graph.values(): - _remove_dp(recv_dp, send_graph, datapipe) - - # Get the last DataPipe in graph - assert len(graph) == 1 - datapipe = list(graph.values())[0][0] - - return traverse_dps(datapipe) - - -def _find_replicable_branches(graph: DataPipeGraph) -> List[DataPipe]: - r""" - Given the graph of DataPipe generated by ``traverse_dps`` function, return DataPipe - instances of which all of prior DataPipes are replicable (``dp.is_replicable() == True``). - """ - assert len(graph) == 1, "DataPipeGraph should only contain a single output DataPipe" - - dps: List[DataPipe] = [] - dp_ids: Set[int] = set() - branch_is_replicable: Dict[int, bool] = {} - - root_dp_id = list(graph.keys())[0] - root_dp, root_graph = graph[root_dp_id] - - def _is_replicable(root_dp_id, root_dp, root_graph) -> bool: # pyre-ignore - if root_dp_id in branch_is_replicable: - return branch_is_replicable[root_dp_id] - # Temporarily set to True - branch_is_replicable[root_dp_id] = True - if hasattr(root_dp, "is_replicable") and not root_dp.is_replicable(): - branch_is_replicable[root_dp_id] = False - for dp_id, (dp, src_graph) in root_graph.items(): - if not _is_replicable(dp_id, dp, src_graph): - branch_is_replicable[root_dp_id] = False - # Do not break to go through all children - if not branch_is_replicable[root_dp_id]: - # All children should have been added to branch_is_replicable already - for dp_id, (dp, _) in root_graph.items(): - if dp_id in dp_ids: - continue - if branch_is_replicable[dp_id]: - # Guarantee returning the frontmost replicable DataPipe - prior_dps = list_dps(traverse_dps(dp)) - if all(id(p_dp) not in dp_ids for p_dp in prior_dps): - dps.append(dp) - dp_ids.add(dp_id) - return branch_is_replicable[root_dp_id] - - if _is_replicable(root_dp_id, root_dp, root_graph): - if root_dp_id not in dp_ids: - # Guarantee returning the frontmost replicable DataPipe - prior_dps = list_dps(traverse_dps(root_dp)) - if all(id(p_dp) not in dp_ids for p_dp in prior_dps): - dps.append(root_dp) - dp_ids.add(root_dp_id) - - return dps - - -# For each `recv_dp`, find if the source_datapipe needs to be replaced by the new one. -# If found, find where the `old_dp` is located in `recv_dp` and switch it to the `new_dp` -def _replace_dp(recv_dp, send_graph: DataPipeGraph, old_dp: DataPipe, new_dp: DataPipe) -> None: - old_dp_id = id(old_dp) - for send_id in send_graph: - if send_id == old_dp_id: - _assign_attr(recv_dp, old_dp, new_dp, inner_dp=True) - else: - send_dp, sub_send_graph = send_graph[send_id] - _replace_dp(send_dp, sub_send_graph, old_dp, new_dp) - - -# For each `recv_dp`, find if the source_datapipe needs to be replaced by the new one. 
-# If found, find where the `old_dp` is located in `dp` and switch it to the `new_dp` -def _remove_dp(recv_dp, send_graph: DataPipeGraph, datapipe: DataPipe) -> None: - dp_id = id(datapipe) - for send_dp_id in send_graph: - if send_dp_id == dp_id: - send_dp, sub_send_graph = send_graph[send_dp_id] - # if len(sub_send_graph) == 0: - # raise RuntimeError("Cannot remove the source DataPipe from the graph of DataPipe") - # if len(sub_send_graph) > 1: - # raise RuntimeError("Cannot remove the receiving DataPipe having multiple sending DataPipes") - src_dp = list(sub_send_graph.values())[0][0] - _assign_attr(recv_dp, send_dp, src_dp, inner_dp=True) - else: - send_dp, sub_send_graph = send_graph[send_dp_id] - _remove_dp(send_dp, sub_send_graph, datapipe) - - -# Recursively re-assign datapipe for the sake of nested data structure -# `inner_dp` is used to prevent recursive call if we have already met a `DataPipe` -def _assign_attr(obj, old_dp, new_dp, inner_dp: bool = False): - if obj is old_dp: - return new_dp - elif isinstance(obj, (IterDataPipe, MapDataPipe)): - # Prevent recursive call for DataPipe - if not inner_dp: - return None - for k in list(obj.__dict__.keys()): - new_obj = _assign_attr(obj.__dict__[k], old_dp, new_dp) - if new_obj is not None: - obj.__dict__[k] = new_obj - break - return None - elif isinstance(obj, dict): - for k in list(obj.keys()): - new_obj = _assign_attr(obj[k], old_dp, new_dp) - if new_obj is not None: - obj[k] = new_obj - break - return None - # Tuple is immutable, has to re-create a tuple - elif isinstance(obj, tuple): - temp_list = [] - flag = False - for o in obj: - new_obj = _assign_attr(o, old_dp, new_dp, inner_dp) - if new_obj is not None: - flag = True - temp_list.append(new_dp) - else: - temp_list.append(o) - if flag: - return tuple(temp_list) # Special case - else: - return None - elif isinstance(obj, list): - for i in range(len(obj)): - new_obj = _assign_attr(obj[i], old_dp, new_dp, inner_dp) - if new_obj is not None: - obj[i] = new_obj - break - return None - elif isinstance(obj, set): - new_obj = None - for o in obj: - if _assign_attr(o, old_dp, new_dp, inner_dp) is not None: - new_obj = new_dp - break - if new_obj is not None: - obj.remove(old_dp) - obj.add(new_dp) - return None - else: - return None diff --git a/torchdata/dataloader2/linter.py b/torchdata/dataloader2/linter.py deleted file mode 100644 index c09fb864d..000000000 --- a/torchdata/dataloader2/linter.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torchdata.dataloader2.graph import DataPipe, DataPipeGraph, traverse_dps - -from torchdata.datapipes.iter import ShardingFilter, Shuffler - - -def _check_shuffle_before_sharding(datapipe: DataPipe) -> bool: - """ - This function will check if a ``shuffle`` operation is presented before each - ``sharding_filter`` operation for every single path in the ``DataPipe`` graph. 
- """ - graph: DataPipeGraph = traverse_dps(datapipe) # type: ignore[arg-type] - return _check_shuffler_before_sharding_helper(graph) - - -def _check_shuffler_before_sharding_helper(graph: DataPipeGraph) -> bool: - if not graph: - return True - - if len(graph) > 1: - for dp, sub_graph in graph.values(): - if isinstance(dp, ShardingFilter): - if not _has_shuffler(sub_graph): - return False - else: - if not _check_shuffler_before_sharding_helper(sub_graph): - return False - return True - - dp, dp_graph = list(graph.values())[0] - if isinstance(dp, ShardingFilter): - return _has_shuffler(dp_graph) - - return _check_shuffler_before_sharding_helper(dp_graph) - - -def _has_shuffler(graph: DataPipeGraph) -> bool: - if not graph: - return False - - if len(graph) > 1: - for dp, sub_graph in graph.values(): - if not (isinstance(dp, Shuffler) or _has_shuffler(sub_graph)): - return False - return True - - dp, dp_graph = list(graph.values())[0] - if isinstance(dp, Shuffler): - return True - - return _has_shuffler(dp_graph) diff --git a/torchdata/dataloader2/random/__init__.py b/torchdata/dataloader2/random/__init__.py deleted file mode 100644 index 7f4de2351..000000000 --- a/torchdata/dataloader2/random/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torchdata.dataloader2.random.distributed import dist_share_seed -from torchdata.dataloader2.random.seed_generator import SeedGenerator - - -__all__ = ["SeedGenerator", "dist_share_seed"] diff --git a/torchdata/dataloader2/random/_philox.py b/torchdata/dataloader2/random/_philox.py deleted file mode 100644 index f45fc468c..000000000 --- a/torchdata/dataloader2/random/_philox.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import List, Optional, Tuple - -# Note [Philox Engine implementation] -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -# Refer to: http://www.thesalmons.org/john/random123/papers/random123sc11.pdf for details regarding the engine. -# Using Philox4×32-10 for the sake of performance, randomness and crush-resistance. 
-# The following code could be optimized into C++ bindings - -# Philox Constants -kPhilox10A = 0x9E3779B9 -kPhilox10B = 0xBB67AE85 -kPhiloxSA = 0xD2511F53 -kPhiloxSB = 0xCD9E8D57 - -MASK_32b = 0xFFFFFFFF -MASK_64b = 0xFFFFFFFFFFFFFFFF -HALF_UINT64 = 0x8000000000000000 - - -def mulhilo32(a: int, b: int) -> Tuple[int, int]: - product = a * b - return product & MASK_32b, (product >> 32) & MASK_32b - - -def single_round(key: List[int], ctr: List[int]) -> List[int]: - lo0, hi0 = mulhilo32(kPhiloxSA, ctr[0]) - lo1, hi1 = mulhilo32(kPhiloxSB, ctr[2]) - res = [0] * 4 - res[0] = hi1 ^ ctr[1] ^ key[0] - res[1] = lo1 - res[2] = hi0 ^ ctr[3] ^ key[1] - res[3] = lo0 - return res - - -def philox_10_round(key: Tuple[int, int], ctr: List[int]) -> List[int]: - _key = list(key) - _ctr = list(ctr) - for _ in range(9): - _ctr = single_round(_key, _ctr) - _key[0] = (_key[0] + kPhilox10A) & MASK_32b - _key[1] = (_key[1] + kPhilox10B) & MASK_32b - return single_round(_key, _ctr) - - -class PhiloxEngine: - r""" - Philox is a counter-based RNG with a certain properties: - - High performance - - Statistiacl random - - Crush-resistance Bijection - - Generate new seeds or spawn parallel seeds for worker processes. - """ - - def __init__(self, seed: Optional[int] = None) -> None: - self._seed: Tuple[int, int] = (-1, -1) - self._ctr: List[int] = [0] * 4 - self._generated_seeds: Optional[List[int]] = None - self._spawn_seed: Tuple[int, int] = (-1, -1) - if seed is not None: - self.seed(seed) - - def _incr_ctr(self) -> None: - for i in range(3): - self._ctr[i] += 1 - if self._ctr[i] <= MASK_32b: - return - self._ctr[i] = 0 - self._ctr[3] += 1 - # if overflow (2^128) has occurred during addition, back to the initial counter - if self._ctr[3] > MASK_32b: - self._ctr[3] = 0 - self._incr_ctr() - - def seed(self, seed: int) -> "PhiloxEngine": - seed = seed & MASK_64b - # Convert seed from int64 to uint64 - if seed < 0: - seed = seed + HALF_UINT64 - lo = seed & MASK_32b - hi = (seed >> 32) & MASK_32b - self._seed = (lo, hi) - # Reset counter and cached seed - self._ctr = [0] * 4 - self._generated_seeds = None - # Generate the spawn seed - self._spawn_seed = tuple(philox_10_round(self._seed, self._ctr)[:2]) # type: ignore[assignment] - self._incr_ctr() - return self - - def generate(self) -> int: - assert self._seed != (-1, -1), "Please provide seed to PhiloxEngine" - - if self._generated_seeds is None: - self._generated_seeds = philox_10_round(self._seed, self._ctr) - self._incr_ctr() - res = self._generated_seeds[:2] - else: - res = self._generated_seeds[2:] - self._generated_seeds = None - return (res[1] << 32) + res[0] - - def clone(self) -> "PhiloxEngine": - new_engine = PhiloxEngine(None) - new_engine._seed = self._seed # immutable tuple - new_engine._ctr = self._ctr.copy() - new_engine._generated_seeds = None if self._generated_seeds is None else self._generated_seeds.copy() - new_engine._spawn_seed = self._spawn_seed # immutable tuple - return new_engine - - def spawn(self, index: int) -> "PhiloxEngine": - assert index >= 0, f"Expected a non-negative value for spawn, but found {index}" - assert self._spawn_seed != (-1, -1), "Please provide seed to PhiloxEngine" - - offset = index % 2 - val = index if offset == 0 else index - 1 - - ctr = [] - for _ in range(4): - ctr.append(val & MASK_32b) - val = val >> 32 - - res = philox_10_round(self._spawn_seed, ctr)[offset * 2 : offset * 2 + 2] - sub_seed = (res[1] << 32) + res[0] - return PhiloxEngine(sub_seed) diff --git a/torchdata/dataloader2/random/distributed.py 
b/torchdata/dataloader2/random/distributed.py deleted file mode 100644 index fe4b491af..000000000 --- a/torchdata/dataloader2/random/distributed.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Optional - -import torch -import torch.distributed as dist - - -_HALF_UINT64 = 0x8000000000000000 - - -def dist_share_seed(seed: int, process_group: Optional[dist.ProcessGroup] = None) -> int: - # Convert uint64 to int64 to prevent overflow for integer Tensor - seed -= _HALF_UINT64 - shared_seed = torch.tensor(seed, dtype=torch.int64) - dist.broadcast(shared_seed, src=0, group=process_group) - # Revert int64 back to uint64 - return int(shared_seed.item()) + _HALF_UINT64 diff --git a/torchdata/dataloader2/random/seed_generator.py b/torchdata/dataloader2/random/seed_generator.py deleted file mode 100644 index 1203e7f75..000000000 --- a/torchdata/dataloader2/random/seed_generator.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Optional, Tuple - -import torch - -from torchdata.dataloader2.random._philox import PhiloxEngine - - -_UINT64_UPPER_BOUND = 2 ** 64 - - -def _get_torch_random_seed(): - with torch.random.fork_rng(): - iinfo = torch.iinfo(torch.int64) - seed = torch.randint(iinfo.min, iinfo.max, ()).item() - # Convert int64 to uint64 - seed += 2 ** 63 - return seed - - -class SeedGenerator: - r""" - ``SeedGenerator`` is used to generate seeds in a deterministic and randomized manner - based on a user-provided initial seed. Internally, it utilizes a counter-based PRNG - called Philox to generate random seeds. - - Args: - seed: The base seed to generate random seeds - """ - _shared_rng: PhiloxEngine - _worker_rng: PhiloxEngine - - def __init__(self, seed: Optional[int] = None, _rngs: Optional[Tuple[PhiloxEngine, PhiloxEngine]] = None) -> None: - if seed is not None and _rngs is not None: - raise ValueError("SeedGenerator doesn't allow both seed and _rng specified at the same time") - if _rngs is None: - self._shared_rng = PhiloxEngine() - self._worker_rng = PhiloxEngine() - self.seed(seed) - else: - assert len(_rngs) == 2 - self._shared_rng, self._worker_rng = _rngs - - def seed(self, seed: Optional[int] = None) -> None: - r""" - Re-seed the ``SeedGenerator``. When ``None`` is provided, a random seed generated - by the default PyTorch RNG. - """ - if seed is None: - seed = _get_torch_random_seed() - if seed >= _UINT64_UPPER_BOUND: - raise ValueError(f"Expected an uint64 seed, but got {seed}.") - self._shared_rng.seed(seed) - self._worker_rng.seed(seed) - - def generate_shared_seed(self) -> int: - r""" - Generate one uint64 random seed that is supposed to be the same across - distributed processes. - """ - return self._shared_rng.generate() - - def generate_seed(self) -> int: - r""" - Generate one unique uint64 random seed based on distributed and multiprocessing - information. - """ - return self._worker_rng.generate() - - def spawn(self, worker_id: int, inplace: bool = False) -> "SeedGenerator": - r""" - Spawn a sub-SeedGenerator based on the provided worker_id. 
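SeedGenerator above keeps two Philox streams: a shared one (identical on every worker, for operations that must agree across ranks) and a worker one (diverging per worker_id via spawn). A small sketch of that behavior, assuming the pre-removal torchdata.dataloader2.random module.

from torchdata.dataloader2.random import SeedGenerator

base = SeedGenerator(2023)
w0, w1 = base.spawn(0), base.spawn(1)

print(w0.generate_shared_seed() == w1.generate_shared_seed())  # True: the shared stream is cloned
print(w0.generate_seed() == w1.generate_seed())                # almost surely False: worker streams diverge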
If inplace is turn on, the SeedGenerator - will evolve itself rather than spawning a new - """ - if worker_id < 0: - raise ValueError(f"Expected `rank` equal or larger than 0, but got {worker_id}.") - - if inplace: - self._worker_rng = self._worker_rng.spawn(worker_id) - return self - return SeedGenerator(seed=None, _rngs=(self._shared_rng.clone(), self._worker_rng.spawn(worker_id))) - - def __getstate__(self): - state = ( - self._shared_rng, - self._worker_rng, - ) - return state - - def __setstate__(self, state): - self._shared_rng, self._worker_rng = state diff --git a/torchdata/dataloader2/reading_service.py b/torchdata/dataloader2/reading_service.py deleted file mode 100644 index 776c6f7ef..000000000 --- a/torchdata/dataloader2/reading_service.py +++ /dev/null @@ -1,676 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import multiprocessing as py_mp -import pickle -import warnings - -from abc import ABC, abstractmethod -from datetime import timedelta -from functools import partial -from multiprocessing.queues import Queue -from typing import Callable, List, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES - -from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s -from torchdata.dataloader2 import communication -from torchdata.dataloader2.graph import DataPipe, list_dps, replace_dp, set_graph_random_seed, traverse_dps -from torchdata.dataloader2.graph._serialization import attach_wrapper -from torchdata.dataloader2.graph.utils import _find_replicable_branches -from torchdata.dataloader2.random import dist_share_seed, SeedGenerator -from torchdata.dataloader2.utils import process_init_fn, WorkerInfo -from torchdata.dataloader2.utils.dispatch import _DummyIterDataPipe, find_lca_round_robin_sharding_dp -from torchdata.datapipes.iter import FullSync - - -class ReadingServiceInterface(ABC): - r""" - Interface for ``ReadingService``. Please extend custom ``ReadingService`` based on this interface class. - - ReadingService must be picklable prior to ``initialize`` being called. This is because a copy of it will be - created by ``DataLoader2`` to avoid the situation where the same ReadingService object is used by - multiple ``DataLoader2``, and its internal state will be modifiable by each of them. - - As a result of this constraint, certain initialization steps may need to take place within the - ``initialize`` method rather than ``__init__`` of the ReadingService class. - """ - - @abstractmethod - def initialize(self, datapipe: DataPipe) -> DataPipe: - r""" - ``ReadingService`` takes a ``DataPipe`` graph, adapts it into a new ``DataPipe`` graph based on the custom need. - Called once in creating ``DataLoader2`` iterator at first time. Prior to calling this method, - the ``ReadingService`` object must be picklable. - - Args: - datapipe: Original ``DataPipe`` graph. - - Return: - An adapted or a new ``DataPipe`` graph. - """ - pass - - def finalize(self) -> None: - r""" - ``ReadingService`` cleans up internal states and fully shuts down the service. - Called in ``DataLoader2``'s ``shutdown`` and ``__del__``. 
- """ - pass - - def initialize_iteration( - self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - r""" - ``ReadingService`` spins up service for an epoch. Called at the beginning - of every time getting ``DataLoader2`` iterator. - - Args: - seed_generator: SeedGenerator object created and managed by DataLoader2. As the single - source of randomness, it will govern the determinism for all of random operations - with the graph of DataPipes. - iter_reset_fn: Optional reset function from the prior ``ReadingServcie`` - when ``SequentialReadingService`` chains multiple ``ReadingServices`` - - Returns: - A new ``iter_reset_fn`` to be used by subseqeuent ``ReadingService`` - - Example: - MultiProcessingReadingService starts setting worker seeds per process and prefetching - items from the graph. - """ - pass - - def finalize_iteration(self) -> None: - r""" - ``ReadingService`` ends service after an epoch is finished. Called when - the iterator of ``DataLoader2`` is depleted. - """ - pass - - def __del__(self): - # Due to non-deterministic order of destruction, by the time `finalize` is called, - # some objects may already be `None`. - try: - self.finalize() - except AttributeError: - pass - - -class CheckpointableReadingServiceInterface(ReadingServiceInterface): - r""" - Extend ``ReadingServiceInterface`` with two additional methods to save/restore the state of the data-processing graph. - """ - - @abstractmethod - def checkpoint(self) -> bytes: - """ - ``ReadingService`` serializes the internal states. Called in ``DataLoader2.state_dict``. - """ - pass - - @abstractmethod - def restore(self, datapipe: DataPipe, serialized_state: bytes) -> DataPipe: - """ - ``ReadingService`` adapts ``DataPipe`` graph based on the serialized state. - Called once in creating ``DataLoader2`` iterator at first time. - Counterpart of ``initialize``, which adapt ``DataPipe`` graph from scratch. - - Args: - datapipe: original ``DataPipe`` graph before adapted by ``ReadingService`` - serialized_state: The serialized state of internal state used to restore the state - of the adapted ``DataPipe`` graph. - - Returns: - Adapted ``DataPipe`` generated from the serialized state. - """ - pass - - -def _collate_no_op(batch): - return batch[0] - - -class PrototypeMultiProcessingReadingService(ReadingServiceInterface): - def __new__(cls, *args, **kwargs): - warnings.warn( - "`PrototypeMultiProcessingReadingService` is deprecated and will be removed in TorchData 0.8. " - "Please use `MultiProcessingReadingService`." - ) - return MultiProcessingReadingService(*args, **kwargs) - - -class InProcessReadingService(ReadingServiceInterface): - r""" - Default ReadingService to serve the ``DataPipe`` graph in the main process, - and apply graph settings like determinism control to the graph. - - Args: - prefetch_cnt: (int, 0 by default): Number of data will be prefetched in the main process. - init_fn: (Callable, optional): Custom function to be called when the main - process starts to iterate over ``DataPipe`` graph. - reset_fn: (Callable, optional): Custom function to be called at the beginning - of each epoch with ``DataPipe``, ``WorkerInfo`` and ``SeedGenerator`` - as the expected arguments. 
- """ - _prefetch_cnt: int - _init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] - _reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] - _end_datapipe: Optional[DataPipe] - - def __init__( - self, - prefetch_cnt: int = 0, - init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None, - reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] = None, - ) -> None: - self._prefetch_cnt = prefetch_cnt - self._init_fn = init_fn - self._reset_fn = reset_fn - self._end_datapipe = None - - def initialize(self, datapipe: DataPipe) -> DataPipe: - worker_info = WorkerInfo(1, 0) - datapipe = process_init_fn(datapipe, worker_info, self._init_fn) - self._end_datapipe = datapipe - return datapipe - - def initialize_iteration( - self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - assert self._end_datapipe is not None - - # Set random seeds for DataPipe that are in the main process (NOT those in worker processes) - # Worker seeds are set in `process_reset_fn` - set_graph_random_seed(self._end_datapipe, seed_generator) - - return None - - def _pause( - self, pause_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - """ - Pauses DataPipes' activities in the main process in order to collect state. - """ - assert self._end_datapipe is not None - - dp_list = list_dps(traverse_dps(self._end_datapipe)) - for dp in dp_list: - if hasattr(dp, "pause") and callable(dp.pause): - dp.pause() - return None - - def _resume( - self, resume_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - """ - Resumes DataPipes' activities. This is required to be called after `_pause` before - the DataLoader can keep yielding elements. - """ - assert self._end_datapipe is not None - - dp_list = list_dps(traverse_dps(self._end_datapipe)) - # Reversed order - for dp in dp_list[::-1]: - if hasattr(dp, "resume") and callable(dp.resume): - dp.resume() - return None - - def _limit( - self, num_batches: Optional[int], limit_fn: Optional[Callable[[DataPipe, Optional[int]], DataPipe]] = None - ) -> Optional[Callable[[DataPipe, Optional[int]], DataPipe]]: - r""" - Apply limit_fn to the DataPipe graph. - """ - if limit_fn is not None: - # TODO: Remove when flexible checkpoint is supported - limit_fn(self._end_datapipe, num_batches) # type: ignore[arg-type] - return None - - -class MultiProcessingReadingService(ReadingServiceInterface): - r""" - Spawns multiple worker processes to load data from the ``DataPipe`` graph. - If any non-replicable ``DataPipe`` (``sharding_round_robin_dispatch``) is presented in the graph, - a separate dispatching process will be created to load data from the lowest common ancestor - of all non-replicable ``DataPipes`` and distributes data to each worker process in the round-robin manner - Then, the subsequent ``DataPipe`` graph in each worker process will process the data from the dispatching - process and eventually return the result to the main process. - - Args: - num_workers (int): How many subprocesses to use for data loading. - multiprocessing_context (str, optional): Multiprocessing starting method. - If method is None then the default context is returned. - Otherwise, method should be 'fork', 'spawn'. - worker_prefetch_cnt: (int, 10 by default): Number of data will be prefetched at - the end of each worker process. 
- main_prefetch_cnt: (int, 10 by default): Number of data will be prefetched - at the end of the whole pipeline in the main process. - worker_init_fn: (Callable, optional): Function to be called when each worker - process launches with ``DataPipe`` and ``WorkerInfo`` as the expected arguments. - worker_reset_fn: (Callable, optional): Function to be called at the beginning - of each epoch in each worker process with ``DataPipe``, ``WorkerInfo`` - and ``SeedGenerator`` as the expected arguments. - """ - num_workers: int - multiprocessing_context: Optional[str] - worker_prefetch_cnt: int - main_prefetch_cnt: int - worker_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] - worker_reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] - _worker_processes: List[Tuple[py_mp.process.BaseProcess, Queue, Queue]] - _dispatch_process: Optional[Tuple[py_mp.process.BaseProcess, List[Queue], List[Queue]]] - _worker_datapipes: List[DataPipe] - _worker_consumer_datapipe: Optional[DataPipe] - _main_prefetch_datapipe: Optional[DataPipe] - _end_datapipe: Optional[DataPipe] - _mp: bool - _finalized: bool = False - - def __init__( - self, - num_workers: int = 0, - multiprocessing_context: Optional[str] = None, - worker_prefetch_cnt: int = 10, - main_prefetch_cnt: int = 10, - worker_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None, - worker_reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] = None, - ) -> None: - if num_workers == 0: - warnings.warn("Please use `InProcessReadingService` for num_workers=0") - self.num_workers = num_workers - - if multiprocessing_context is not None: - _all_start_methods = mp.get_all_start_methods() - assert ( - multiprocessing_context in _all_start_methods - ), f"Please choose one available multiprocessing context from {_all_start_methods}" - self.multiprocessing_context = multiprocessing_context - self.worker_prefetch_cnt = worker_prefetch_cnt - self.main_prefetch_cnt = main_prefetch_cnt - self.worker_init_fn = worker_init_fn - self.worker_reset_fn = worker_reset_fn - self._worker_processes = [] - self._dispatch_process = None - self._worker_datapipes = [] - self._worker_consumer_datapipe = None - self._main_prefetch_datapipe = None - self._end_datapipe = None - self._mp = num_workers > 0 - - def initialize(self, datapipe: DataPipe) -> DataPipe: - r""" - ``MultiProcessingReadingService`` finds information about sharding, - separates graph by multiple pieces and reconnects it using queues. - creates subprocesses. 
- """ - if not self._mp: - # TODO(616): Warn and recommend usage of InProcessReadingService - worker_info = WorkerInfo(1, 0) - datapipe = process_init_fn(datapipe, worker_info, self.worker_init_fn) - self._end_datapipe = datapipe - return datapipe - - ctx = mp.get_context(self.multiprocessing_context) - - # Launch dispatching process for the lowest common ancestor of non-replicable DataPipes - graph = traverse_dps(datapipe) - dispatching_dp = find_lca_round_robin_sharding_dp(graph) - # TODO(ejguan): When the last DataPipe is round_robin_sharding, use InPrcoessReadingService - if dispatching_dp is not None: - dummy_dp = _DummyIterDataPipe() # type: ignore - graph = replace_dp(graph, dispatching_dp, dummy_dp) # type: ignore[arg-type] - datapipe = list(graph.values())[0][0] - # TODO(ejguan): Determine buffer_size at runtime or use unlimited buffer - round_robin_dps = dispatching_dp.round_robin_demux(num_instances=self.num_workers) - # TODO(ejguan): Benchmark if we need to prefetch in dispatching process - worker_info = WorkerInfo(self.num_workers, 0) - process, req_queues, res_queues = communication.eventloop.CreateProcessForMultipleDataPipelines( - ctx, - round_robin_dps, - process_name="dispatching process", - worker_info=worker_info, - custom_reset_fn=self.worker_reset_fn, - ) - assert len(req_queues) == self.num_workers and len(res_queues) == self.num_workers - for req_queue in req_queues: - req_queue.cancel_join_thread() - for res_queue in res_queues: - res_queue.cancel_join_thread() - process.daemon = True - process.start() - self._dispatch_process = (process, req_queues, res_queues) - - # Find replicable branches for worker processes - # The rest of non-replicable DataPipes will remain in the main process - replicable_dps = _find_replicable_branches(graph) - assert ( - len(replicable_dps) == 1 - ), "MultiProcessingReadingService only supports single replicable branch currently" - replicable_dp = replicable_dps[0] - replicable_dp = attach_wrapper(replicable_dp) - - for worker_id in range(self.num_workers): - worker_info = WorkerInfo(self.num_workers, worker_id) - # Dispatching process for non-replicable DataPipes exists - dispatching_req_queue = None if self._dispatch_process is None else self._dispatch_process[1][worker_id] - dispatching_res_queue = None if self._dispatch_process is None else self._dispatch_process[2][worker_id] - call_on_process_init = partial( - process_init_fn, - worker_info=worker_info, - custom_init_fn=self.worker_init_fn, - worker_prefetch_cnt=self.worker_prefetch_cnt, - dispatching_req_queue=dispatching_req_queue, - dispatching_res_queue=dispatching_res_queue, - ) - (process, req_queue, res_queue) = communication.eventloop.CreateProcessForDataPipeline( - ctx, - replicable_dp, - process_name="worker process", - worker_info=worker_info, - call_on_process_init=call_on_process_init, - custom_reset_fn=self.worker_reset_fn, - ) - req_queue.cancel_join_thread() - process.daemon = True - process.start() - self._worker_processes.append((process, req_queue, res_queue)) # These queues are independent - local_datapipe = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue) - ) - self._worker_datapipes.append(local_datapipe) - - end_datapipe = communication.iter._IterateQueueDataPipes(self._worker_datapipes) # type: ignore[assignment] - self._worker_consumer_datapipe = end_datapipe - - if self.main_prefetch_cnt > 0: - end_datapipe = self._worker_consumer_datapipe.prefetch(self.main_prefetch_cnt) # type: 
ignore[union-attr] - self._main_prefetch_datapipe = end_datapipe - - # Attach non-replicable DataPipes - if replicable_dps[0] is not datapipe: - graph = replace_dp(graph, replicable_dps[0], end_datapipe) - end_datapipe = datapipe # type: ignore[assignment] - - self._end_datapipe = end_datapipe - assert self._end_datapipe is not None - - return self._end_datapipe # type: ignore[return-value] - - def initialize_iteration( - self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - assert self._end_datapipe is not None - - # Set random seeds for DataPipe that are in the main process (NOT those in worker processes) - # Worker seeds are set in `process_reset_fn` - set_graph_random_seed(self._end_datapipe, seed_generator) - - if self._mp: - if self.main_prefetch_cnt > 0: - # Stop prefetching first - self._main_prefetch_datapipe.reset() # type: ignore[union-attr] - # Send the shared seed to subprocesses - assert self._worker_consumer_datapipe is not None - self._worker_consumer_datapipe.reset_epoch(seed_generator, iter_reset_fn) - # In-process (num_workers == 0) - else: - # Technically speaking, we should call `_process_reset_fn` to reset global RNGs - # for data-related operations. However, it would pollute the state of global RNGs - # (random, torch and numpy), if users have already seeded them in the main process - # TODO(ejguan): This should be fixed by adding a method to isolate global RNGs - pass - return None - - def finalize(self) -> None: - r""" - ``MultiProcessingReadingService`` invalidate states & properly exits all subprocesses. - """ - if self._finalized: - return - self._finalized = True - - # TODO(618): Check if anyone stuck with messages - # Clean up worker processes - if self.num_workers > 0: - self._worker_consumer_datapipe.request_terminate() # type: ignore[union-attr] - for process, req_queue, _ in self._worker_processes: - try: - process.join(default_dl2_worker_join_timeout_in_s) - except TimeoutError: - pass - req_queue.close() - - # Clean up dispatching process - if self._dispatch_process is not None: - try: - self._dispatch_process[0].join(default_dl2_worker_join_timeout_in_s) - except TimeoutError: - pass - for req_queue in self._dispatch_process[1]: - req_queue.close() - - self._worker_processes = [] - self._dispatch_process = None - - def _pause( - self, pause_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - r""" - Pauses DataPipes' activities such as prefetching within main/worker/dispatching processes, - in order to collect state. The provided ``pause_fn`` will be executed in - worker/dispatching processes. - """ - if self.num_workers == 0: - raise RuntimeError( - "If you would like to use `pause` with `MultiProcessingReadingService`, " - "please use more than 0 worker." - ) - assert self._end_datapipe is not None - # Call pause for DataPipes in the main process (e.g. prefetch, fullsync) - dp_list = list_dps(traverse_dps(self._end_datapipe)) - for dp in dp_list: - if hasattr(dp, "pause") and callable(dp.pause): - dp.pause() - self._worker_consumer_datapipe.request_pause(pause_fn) # type: ignore[union-attr] - return None - - def _resume( - self, resume_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - r""" - Resumes DataPipes' activities. This is required to be called after `_pause` before - the DataLoader can keep yielding elements. 
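The dispatching-process path set up in ``initialize`` above is only taken when the graph contains a non-replicable ``sharding_round_robin_dispatch`` stage; a hedged sketch of such a graph follows, where the shard list is a hypothetical stand-in for any source that cannot be replicated across workers:

```python
# Sketch: the source stays in a single dispatching process and its output is
# distributed round-robin to the worker processes.
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

source = IterableWrapper(["a.tar", "b.tar", "c.tar"])  # hypothetical shard list
# Everything up to and including this call runs in the dispatching process:
dispatched = source.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
# Everything after it is replicated into each worker process:
pipeline = dispatched.map(str.upper)

dl = DataLoader2(pipeline, reading_service=MultiProcessingReadingService(num_workers=2))
```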
- """ - if self.num_workers > 0: - self._worker_consumer_datapipe.request_resume(resume_fn) # type: ignore[union-attr] - else: - raise RuntimeError( - "If you would like to use `resume` with `MultiProcessingReadingService`, " - "please use more than 0 worker." - ) - assert self._end_datapipe is not None - # Call resume for DataPipes in the main process (e.g. prefetch, fullsync) - dp_list = list_dps(traverse_dps(self._end_datapipe)) - for dp in dp_list[::-1]: - if hasattr(dp, "resume") and callable(dp.resume): - dp.resume() - return None - - def _limit( - self, num_batches: Optional[int], limit_fn: Optional[Callable[[DataPipe, Optional[int]], DataPipe]] = None - ) -> Optional[Callable[[DataPipe, Optional[int]], DataPipe]]: - r""" - Send limit_fn to worker/dispatching process to set the limit number to the specified DataPipes. - """ - if limit_fn is not None: - # Only propogate limit when dispatching process exists - num_batches = None if self._dispatch_process is None else num_batches - self._worker_consumer_datapipe.request_limit(num_batches, limit_fn) # type: ignore[union-attr] - # TODO: Remove when flexible checkpoint is supported - limit_fn(self._end_datapipe, num_batches) # type: ignore[arg-type] - return None - - -class DistributedReadingService(ReadingServiceInterface): - r""" - ``DistributedReadingSerivce`` handles distributed sharding on the graph of ``DataPipe`` and - guarantee the randomness by sharing the same seed across the distributed processes. - - Args: - timeout: Timeout for operations executed against the process group in seconds. - Default value equals 30 minutes. - """ - - def __init__(self, timeout: int = default_timeout_in_s): - if not dist.is_available(): - raise RuntimeError("Torch Distributed is required to be available") - self._world_size: int = 1 - self._rank: int = 0 - self._datapipe: Optional[DataPipe] = None - self._timeout: int = timeout - self._pg: Optional[dist.ProcessGroup] = None - - def initialize(self, datapipe: DataPipe) -> DataPipe: - r""" - Launches the ``gloo``-backend distributed process group. Carries out distributed sharding - on the graph of ``DataPipe`` and returns the graph attached with a ``FullSyncIterDataPipe`` - at the end. - """ - if not (dist.is_available() and dist.is_initialized()): - raise RuntimeError("Torch Distributed is required to be initialized") - self._world_size = dist.get_world_size() - self._rank = dist.get_rank() - self._pg = dist.new_group(backend="gloo", timeout=timedelta(seconds=self._timeout)) - torch.utils.data.graph_settings.apply_sharding( - datapipe, self._world_size, self._rank, SHARDING_PRIORITIES.DISTRIBUTED - ) - # Only append FullSyncIterDataPipe if it's not presented at the end of the pipeline - if not isinstance(datapipe, FullSync): - datapipe = datapipe.fullsync(self._timeout) - self._datapipe = datapipe - return datapipe - - def initialize_iteration( - self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - r""" - Shares the same seed from rank 0 to other ranks across the distributed processes - and apply the random seed to the ``DataPipe`` graph. 
- """ - assert self._datapipe is not None - - shared_seed = dist_share_seed(seed_generator.generate_shared_seed(), self._pg) - seed_generator.seed(shared_seed) - seed_generator = seed_generator.spawn(self._rank, inplace=True) - set_graph_random_seed(self._datapipe, seed_generator) - return None - - def finalize(self) -> None: - r""" - Clean up the distributed process group. - """ - if self._pg is not None: - dist.destroy_process_group(self._pg) - self._pg = None - - -class SequentialReadingService(CheckpointableReadingServiceInterface): - def __init__(self, *reading_services): - self.reading_services = reading_services - - # Sequential Order - def initialize(self, datapipe: DataPipe) -> DataPipe: - for rs in self.reading_services: - datapipe = rs.initialize(datapipe) - return datapipe - - # Reversed Order - def finalize(self) -> None: - for rs in reversed(self.reading_services): - rs.finalize() - - # Sequential Order - def initialize_iteration( - self, seed_generator: SeedGenerator, iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - chained_iter_reset_fn = iter_reset_fn - for rs in self.reading_services: - chained_iter_reset_fn = rs.initialize_iteration( - seed_generator=seed_generator, iter_reset_fn=chained_iter_reset_fn - ) - return chained_iter_reset_fn - - # Reversed Order - def finalize_iteration(self) -> None: - for rs in reversed(self.reading_services): - rs.finalize_iteration() - - # Sequential Order - def checkpoint(self) -> bytes: - states = [] - for rs in self.reading_services: - if hasattr(rs, "checkpoint") and callable(rs.checkpoint): - states.append(rs.checkpoint()) - else: - warnings.warn(f"{rs} doesn't support `checkpoint`, skipping...") - states.append(b"") - return pickle.dumps(states) - - # Sequential Order, to align with initialize - def restore(self, datapipe, serialized_state: bytes) -> DataPipe: - states = pickle.loads(serialized_state) - assert len(states) == len(self.reading_services) - for rs, state in zip(self.reading_services, states): - if hasattr(rs, "restore") and callable(rs.restore): - datapipe = rs.restore(datapipe, state) - else: - warnings.warn(f"{rs} doesn't support `restore` from state, initialize from scratch") - datapipe = rs.initialize(datapipe) - return datapipe - - def _pause( - self, pause_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - r""" - Pause the ``DataPipe`` graph defined in all ``ReadingServices``. For example of - ``MultiProcessingReadingService`` would accept a ``pause_fn`` from a prior ``ReadingService`` - to execute custom pause logic within worker/dispatching processes. - """ - for rs in self.reading_services: - if hasattr(rs, "_pause"): - pause_fn = rs._pause(pause_fn) - return pause_fn - - def _resume( - self, resume_fn: Optional[Callable[[DataPipe], DataPipe]] = None - ) -> Optional[Callable[[DataPipe], DataPipe]]: - r""" - Resume the ``DataPipe`` graph defined in all ``ReadingServices``. For example of - ``MultiProcessingReadingService`` would accept a ``resume_fn`` from a prior ``ReadingService`` - to execute custom resume logic within worker/dispatching processes. 
- """ - for rs in self.reading_services: - if hasattr(rs, "_resume"): - resume_fn = rs._resume(resume_fn) - return resume_fn - - def _limit( - self, num_batches: Optional[int], limit_fn: Optional[Callable[[DataPipe, Optional[int]], DataPipe]] = None - ) -> Optional[Callable[[DataPipe, Optional[int]], DataPipe]]: - r""" - Limit the ``DataPipe`` graph defined in all ``ReadingServices``. For example of - ``MultiProcessingReadingService`` would accept a ``limit_fn`` from a prior ``ReadingService`` - to set limit to ``DataPipes` within worker/dispatching processes. - """ - for rs in self.reading_services: - if hasattr(rs, "_limit"): - limit_fn = rs._limit(num_batches, limit_fn) - return limit_fn diff --git a/torchdata/dataloader2/shuffle_spec.py b/torchdata/dataloader2/shuffle_spec.py deleted file mode 100644 index dca09afce..000000000 --- a/torchdata/dataloader2/shuffle_spec.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -import abc - - -class ShuffleSpec(abc.ABC): - """Defines a shuffle specification.""" diff --git a/torchdata/dataloader2/utils/__init__.py b/torchdata/dataloader2/utils/__init__.py deleted file mode 100644 index 79308ea4c..000000000 --- a/torchdata/dataloader2/utils/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from torchdata.dataloader2.utils.worker import process_init_fn, process_reset_fn, WorkerInfo - - -__all__ = [ - "WorkerInfo", - "process_init_fn", - "process_reset_fn", -] - -assert __all__ == sorted(__all__) diff --git a/torchdata/dataloader2/utils/dispatch.py b/torchdata/dataloader2/utils/dispatch.py deleted file mode 100644 index 10cdd8964..000000000 --- a/torchdata/dataloader2/utils/dispatch.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# from multiprocessing.queues import Queue -from typing import Dict, List, Optional, Set - -from torchdata.dataloader2.graph import DataPipe, DataPipeGraph, list_dps, traverse_dps - -from torchdata.datapipes.iter import IterDataPipe, ShardingRoundRobinDispatcher - - -__all__ = ["_DummyIterDataPipe", "find_lca_round_robin_sharding_dp", "find_non_dispatching_branches"] - - -class _DummyIterDataPipe(IterDataPipe): - r""" - This DataPipe is a placeholder to be replaced by the ``QueueWrapper`` - that connects the worker process for non-replicable DataPipe. - """ - # TODO: Revert `_DummyIterDataPipe` as the placeholder when `_SerializationWrapper` - # can handle mp.Queue. See: https://github.com/pytorch/data/issues/934 - # req_queue: Queue - # res_queue: Queue - - -def find_lca_round_robin_sharding_dp(graph: DataPipeGraph) -> Optional[DataPipe]: - r""" - Given the graph of DataPipe generated by ``traverse_dps`` function, return the - DataPipe instance that is the lowest common ancestor of all ``sharding_round_robin_dispatch`` DataPipes - - Note: - - If multiple branches share the same source DataPipe and any branch contains a - non-replicable DataPipe, the lowest common ancestor of all branches is returned. 
- - If there is any non-replicable DataPipe in a circular-referenced (sub)graph, the - whole (sub)graph is treated as non-replicable and the last DataPipe is returned. - """ - assert len(graph) == 1, "DataPipeGraph should only contain a single output DataPipe" - - def _is_round_robin_sharding(dp: DataPipe) -> bool: - return type(dp) == ShardingRoundRobinDispatcher - - dps = list_dps(graph) - non_replicable_dps: Set[int] = set() - for dp in dps: - # Skip when it has been visited - if id(dp) in non_replicable_dps: - continue - if _is_round_robin_sharding(dp): - parent_dps = list_dps(traverse_dps(dp)) - for par_dp in parent_dps: - non_replicable_dps.add(id(par_dp)) - - root_dp_id = list(graph.keys())[0] - root_dp, root_graph = graph[root_dp_id] - - lca_for_subgraph: Dict[int, Optional[DataPipe]] = {} - - def _get_lca_from_graph(root_dp_id, root_dp, root_graph) -> Optional[DataPipe]: # pyre-ignore - if root_dp_id in lca_for_subgraph: - return lca_for_subgraph[root_dp_id] - if root_dp_id in non_replicable_dps: - lca_for_subgraph[root_dp_id] = root_dp - return root_dp - lca_for_subgraph[root_dp_id] = None - non_replicable_parents = [] - for dp_id, (dp, src_graph) in root_graph.items(): - res = _get_lca_from_graph(dp_id, dp, src_graph) - if res is not None: - non_replicable_parents.append(res) - # `root_dp` becomes the lowest common ancestor of this branch, - # if there are more than one unique non-replicable DataPipe prior to it. - if len(non_replicable_parents) > 0: - # One unique non-replicable DataPipe - if len(non_replicable_parents) == 1 or all( - dp == non_replicable_parents[0] for dp in non_replicable_parents - ): - lca_for_subgraph[root_dp_id] = non_replicable_parents[0] - # Multiple non-replicable DataPipes - else: - lca_for_subgraph[root_dp_id] = root_dp - return lca_for_subgraph[root_dp_id] - - return _get_lca_from_graph(root_dp_id, root_dp, root_graph) - - -def find_non_dispatching_branches(graph: DataPipeGraph) -> List[DataPipe]: - r""" - Given the graph of DataPipe generated by ``traverse_dps`` function, return the DataPipe - instances that don't have ``_DummyIterDataPipe`` (dipatching process) in the prior graph. - """ - assert len(graph) == 1, "DataPipeGraph should only contain a single output DataPipe" - - dps: List[DataPipe] = [] - non_dispatching_branches: Dict[int, bool] = {} - - root_dp_id = list(graph.keys())[0] - root_dp, root_graph = graph[root_dp_id] - - def _is_non_dispatching(root_dp_id, root_dp, root_graph) -> bool: # pyre-ignore - if root_dp_id in non_dispatching_branches: - return non_dispatching_branches[root_dp_id] - if type(root_dp) == _DummyIterDataPipe: - non_dispatching_branches[root_dp_id] = False - return False - non_dispatching_branches[root_dp_id] = True - for dp_id, (dp, src_graph) in root_graph.items(): - if not _is_non_dispatching(dp_id, dp, src_graph): - non_dispatching_branches[root_dp_id] = False - # Do not break to go through all children - if not non_dispatching_branches[root_dp_id]: - # All children should have been added to non_dispatching_branches already - for dp_id, (dp, _) in root_graph.items(): - if non_dispatching_branches[dp_id]: - dps.append(dp) - return non_dispatching_branches[root_dp_id] - - if _is_non_dispatching(root_dp_id, root_dp, root_graph): - dps.append(root_dp) - - return dps diff --git a/torchdata/dataloader2/utils/worker.py b/torchdata/dataloader2/utils/worker.py deleted file mode 100644 index 2a5a50ab2..000000000 --- a/torchdata/dataloader2/utils/worker.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
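To make the lowest-common-ancestor search above concrete, a toy graph and the value it would return; the import path reflects the pre-removal location of this module:

```python
# Sketch: a single non-replicable DataPipe, so it is its own lowest common ancestor.
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.dataloader2.graph import traverse_dps
from torchdata.dataloader2.utils.dispatch import find_lca_round_robin_sharding_dp
from torchdata.datapipes.iter import IterableWrapper

source = IterableWrapper(range(8))
dispatched = source.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
pipeline = dispatched.map(lambda x: x + 1)

graph = traverse_dps(pipeline)
lca = find_lca_round_robin_sharding_dp(graph)
assert lca is dispatched  # the only round-robin dispatcher in the graph
```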
and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import random - -from dataclasses import dataclass -from multiprocessing.queues import Queue -from typing import Callable, Optional - -import torch - -from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES - -from torchdata.dataloader2 import communication -from torchdata.dataloader2.graph import ( - DataPipe, - find_dps, - list_dps, - replace_dp, - set_datapipes_seed, - set_graph_random_seed, - traverse_dps, -) -from torchdata.dataloader2.random import SeedGenerator -from torchdata.dataloader2.utils.dispatch import _DummyIterDataPipe, find_non_dispatching_branches -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.map import MapDataPipe - -try: - import numpy - - HAS_NUMPY = True -except ModuleNotFoundError: - HAS_NUMPY = False - - -@dataclass(frozen=True) -class WorkerInfo: - r""" - Message class for keeping track of worker information. - - Args: - num_workers (int): Total number of worker processes - worker_id (int): Worker ID for the current worker process - """ - num_workers: int - worker_id: int - - -def process_init_fn( - datapipe: DataPipe, - worker_info: WorkerInfo, - custom_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None, - worker_prefetch_cnt: int = 0, - dispatching_req_queue: Optional[Queue] = None, - dispatching_res_queue: Optional[Queue] = None, -) -> DataPipe: - r""" - Based on the worker information, shard the ``DataPipe`` graph dynamically. - """ - # Find if there is non-replicable DataPipe - graph = traverse_dps(datapipe) - non_replicable_dp = find_dps(graph, _DummyIterDataPipe) # type: ignore - - # There are two cases for DataPipe graph in terms of mp sharding: - # 1) All DataPipes are replicable, apply mp sharding to the whole graph - if len(non_replicable_dp) == 0: - torch.utils.data.graph_settings.apply_sharding( - datapipe, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING - ) - assert dispatching_req_queue is None and dispatching_res_queue is None - # 2) There is non-replicable DataPipe. Since we have replaced the lowest common - # ancestor by a `_DummyIterDataPipe`, we would only apply mp sharding - # to replicable branches that don't have `_DummyIterDataPipe`. 
- else: - assert len(non_replicable_dp) == 1 - assert not (dispatching_req_queue is None and dispatching_res_queue is None) - dispatching_req_queue.cancel_join_thread() # type: ignore[union-attr] - non_dispatching_branches = find_non_dispatching_branches(graph) - for dp in non_dispatching_branches: - torch.utils.data.graph_settings.apply_sharding( - dp, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING - ) - - queue_wrapper = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(dispatching_req_queue, dispatching_res_queue) - ) - dispatch_process_dp = communication.iter._IterateQueueDataPipes([queue_wrapper]) - graph = replace_dp(graph, non_replicable_dp[0], dispatch_process_dp) - datapipe = list(graph.values())[0][0] - - if custom_init_fn is not None: - datapipe = custom_init_fn(datapipe, worker_info) - assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) - - if worker_prefetch_cnt > 0: - datapipe = datapipe.prefetch(worker_prefetch_cnt) - - return datapipe - - -def _set_global_random_state(seed_generator: SeedGenerator, distributed_shared: bool = False) -> None: - py_seed = seed_generator.generate_shared_seed() if distributed_shared else seed_generator.generate_seed() - random.seed(py_seed) - - torch_seed = seed_generator.generate_shared_seed() if distributed_shared else seed_generator.generate_seed() - torch.manual_seed(torch_seed) - - if HAS_NUMPY: - # Convert uint64 to uint32 for Numpy - np_seed = seed_generator.generate_shared_seed() if distributed_shared else seed_generator.generate_seed() - np_seed = np_seed >> 32 - numpy.random.seed(np_seed) - - -def process_reset_fn( - datapipe: DataPipe, - worker_info: WorkerInfo, - seed_generator: SeedGenerator, - distributed_shared_seed: bool = False, - iter_reset_fn: Optional[Callable[[DataPipe], DataPipe]] = None, - custom_reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] = None, -) -> DataPipe: - r""" - Based on the distributed shared random seed and worker id, this function is used to - reset the random state of the ``DataPipe`` graph and the global random states for ``torch``, - ``random`` and ``numpy``. - """ - # Set global random states - _set_global_random_state(seed_generator, distributed_shared=distributed_shared_seed) - - if distributed_shared_seed: - graph = traverse_dps(datapipe) - dps = list_dps(graph) - set_datapipes_seed(dps, seed_generator=seed_generator, distributed_shared=distributed_shared_seed) - else: - set_graph_random_seed(datapipe, seed_generator) - - if iter_reset_fn is not None: - datapipe = iter_reset_fn(datapipe) - assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) - - if custom_reset_fn is not None: - datapipe = custom_reset_fn(datapipe, worker_info, seed_generator) - assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) - - return datapipe diff --git a/torchdata/datapipes/__init__.py b/torchdata/datapipes/__init__.py deleted file mode 100644 index 8739bfad2..000000000 --- a/torchdata/datapipes/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torch.utils.data import DataChunk, functional_datapipe - -from torchdata import _extension # noqa: F401 - -from . 
import iter, map, utils - -__all__ = ["DataChunk", "functional_datapipe", "iter", "map", "utils"] - - -from torchdata import deprecation_warning - -deprecation_warning() diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py deleted file mode 100644 index 683200272..000000000 --- a/torchdata/datapipes/iter/__init__.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -############################################################################### -# Reference From PyTorch Core -############################################################################### -from torch.utils.data import IterDataPipe -from torch.utils.data.datapipes.iter import ( - Batcher, - Collator, - Concater, - Demultiplexer, - FileLister, - FileOpener, - Filter, - Forker, - Grouper, - IterableWrapper, - Mapper, - Multiplexer, - RoutedDecoder, - Sampler, - ShardingFilter, - Shuffler, - StreamReader, - UnBatcher, - Zipper, -) -from torchdata.datapipes.iter.load.aisio import ( - AISFileListerIterDataPipe as AISFileLister, - AISFileLoaderIterDataPipe as AISFileLoader, -) - -############################################################################### -# TorchData -############################################################################### -from torchdata.datapipes.iter.load.fsspec import ( - FSSpecFileListerIterDataPipe as FSSpecFileLister, - FSSpecFileOpenerIterDataPipe as FSSpecFileOpener, - FSSpecSaverIterDataPipe as FSSpecSaver, -) - -from torchdata.datapipes.iter.load.huggingface import HuggingFaceHubReaderIterDataPipe as HuggingFaceHubReader - -from torchdata.datapipes.iter.load.iopath import ( - IoPathFileListerIterDataPipe as IoPathFileLister, - IoPathFileOpenerIterDataPipe as IoPathFileOpener, - IoPathSaverIterDataPipe as IoPathSaver, -) - -from torchdata.datapipes.iter.load.online import ( - GDriveReaderDataPipe as GDriveReader, - HTTPReaderIterDataPipe as HttpReader, - OnlineReaderIterDataPipe as OnlineReader, -) -from torchdata.datapipes.iter.load.s3io import ( - S3FileListerIterDataPipe as S3FileLister, - S3FileLoaderIterDataPipe as S3FileLoader, -) -from torchdata.datapipes.iter.transform.bucketbatcher import ( - BucketBatcherIterDataPipe as BucketBatcher, - InBatchShufflerIterDataPipe as InBatchShuffler, - MaxTokenBucketizerIterDataPipe as MaxTokenBucketizer, -) -from torchdata.datapipes.iter.transform.callable import ( - BatchAsyncMapperIterDataPipe as BatchAsyncMapper, - BatchMapperIterDataPipe as BatchMapper, - DropperIterDataPipe as Dropper, - FlatMapperIterDataPipe as FlatMapper, - FlattenIterDataPipe as Flattener, - ShuffledFlatMapperIterDataPipe as ShuffledFlatMapper, - SliceIterDataPipe as Slicer, - ThreadPoolMapperIterDataPipe as ThreadPoolMapper, -) -from torchdata.datapipes.iter.util.bz2fileloader import Bz2FileLoaderIterDataPipe as Bz2FileLoader -from torchdata.datapipes.iter.util.cacheholder import ( - EndOnDiskCacheHolderIterDataPipe as EndOnDiskCacheHolder, - InMemoryCacheHolderIterDataPipe as InMemoryCacheHolder, - OnDiskCacheHolderIterDataPipe as OnDiskCacheHolder, -) -from torchdata.datapipes.iter.util.combining import ( - IterKeyZipperIterDataPipe as IterKeyZipper, - MapKeyZipperIterDataPipe as MapKeyZipper, - RoundRobinDemultiplexerIterDataPipe as RoundRobinDemultiplexer, - UnZipperIterDataPipe as UnZipper, -) -from torchdata.datapipes.iter.util.cycler 
import CyclerIterDataPipe as Cycler, RepeaterIterDataPipe as Repeater -from torchdata.datapipes.iter.util.dataframemaker import ( - DataFrameMakerIterDataPipe as DataFrameMaker, - ParquetDFLoaderIterDataPipe as ParquetDataFrameLoader, -) -from torchdata.datapipes.iter.util.decompressor import ( - DecompressorIterDataPipe as Decompressor, - ExtractorIterDataPipe as Extractor, -) -from torchdata.datapipes.iter.util.distributed import FullSyncIterDataPipe as FullSync -from torchdata.datapipes.iter.util.hashchecker import HashCheckerIterDataPipe as HashChecker -from torchdata.datapipes.iter.util.header import HeaderIterDataPipe as Header, LengthSetterIterDataPipe as LengthSetter -from torchdata.datapipes.iter.util.indexadder import ( - EnumeratorIterDataPipe as Enumerator, - IndexAdderIterDataPipe as IndexAdder, -) -from torchdata.datapipes.iter.util.jsonparser import JsonParserIterDataPipe as JsonParser -from torchdata.datapipes.iter.util.mux_longest import MultiplexerLongestIterDataPipe as MultiplexerLongest -from torchdata.datapipes.iter.util.paragraphaggregator import ParagraphAggregatorIterDataPipe as ParagraphAggregator -from torchdata.datapipes.iter.util.plain_text_reader import ( - CSVDictParserIterDataPipe as CSVDictParser, - CSVParserIterDataPipe as CSVParser, - LineReaderIterDataPipe as LineReader, -) -from torchdata.datapipes.iter.util.prefetcher import ( - PinMemoryIterDataPipe as PinMemory, - PrefetcherIterDataPipe as Prefetcher, -) -from torchdata.datapipes.iter.util.randomsplitter import RandomSplitterIterDataPipe as RandomSplitter -from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader -from torchdata.datapipes.iter.util.rows2columnar import Rows2ColumnarIterDataPipe as Rows2Columnar -from torchdata.datapipes.iter.util.samplemultiplexer import SampleMultiplexerDataPipe as SampleMultiplexer -from torchdata.datapipes.iter.util.saver import SaverIterDataPipe as Saver -from torchdata.datapipes.iter.util.shardexpander import ShardExpanderIterDataPipe as ShardExpander -from torchdata.datapipes.iter.util.sharding import ( - ShardingRoundRobinDispatcherIterDataPipe as ShardingRoundRobinDispatcher, -) -from torchdata.datapipes.iter.util.tararchiveloader import TarArchiveLoaderIterDataPipe as TarArchiveLoader -from torchdata.datapipes.iter.util.tfrecordloader import ( - TFRecordExample, - TFRecordExampleSpec, - TFRecordLoaderIterDataPipe as TFRecordLoader, -) -from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset -from torchdata.datapipes.iter.util.xzfileloader import XzFileLoaderIterDataPipe as XzFileLoader -from torchdata.datapipes.iter.util.zip_longest import ZipperLongestIterDataPipe as ZipperLongest -from torchdata.datapipes.iter.util.ziparchiveloader import ZipArchiveLoaderIterDataPipe as ZipArchiveLoader -from torchdata.datapipes.map.util.converter import MapToIterConverterIterDataPipe as MapToIterConverter - -__all__ = [ - "AISFileLister", - "AISFileLoader", - "BatchAsyncMapper", - "BatchMapper", - "Batcher", - "BucketBatcher", - "Bz2FileLoader", - "CSVDictParser", - "CSVParser", - "Collator", - "Concater", - "Cycler", - "DataFrameMaker", - "Decompressor", - "Demultiplexer", - "Dropper", - "EndOnDiskCacheHolder", - "Enumerator", - "Extractor", - "FSSpecFileLister", - "FSSpecFileOpener", - "FSSpecSaver", - "FileLister", - "FileOpener", - "Filter", - "FlatMapper", - "Flattener", - "Forker", - "FullSync", - "GDriveReader", - "Grouper", - "HashChecker", - "Header", - "HttpReader", - 
"HuggingFaceHubReader", - "InBatchShuffler", - "InMemoryCacheHolder", - "IndexAdder", - "IoPathFileLister", - "IoPathFileOpener", - "IoPathSaver", - "IterDataPipe", - "IterKeyZipper", - "IterableWrapper", - "JsonParser", - "LengthSetter", - "LineReader", - "MapKeyZipper", - "MapToIterConverter", - "Mapper", - "MaxTokenBucketizer", - "Multiplexer", - "MultiplexerLongest", - "OnDiskCacheHolder", - "OnlineReader", - "ParagraphAggregator", - "ParquetDataFrameLoader", - "PinMemory", - "Prefetcher", - "RandomSplitter", - "RarArchiveLoader", - "Repeater", - "RoundRobinDemultiplexer", - "RoutedDecoder", - "Rows2Columnar", - "S3FileLister", - "S3FileLoader", - "SampleMultiplexer", - "Sampler", - "Saver", - "ShardExpander", - "ShardingFilter", - "ShardingRoundRobinDispatcher", - "ShuffledFlatMapper", - "Shuffler", - "Slicer", - "StreamReader", - "TFRecordLoader", - "TarArchiveLoader", - "ThreadPoolMapper", - "UnBatcher", - "UnZipper", - "WebDataset", - "XzFileLoader", - "ZipArchiveLoader", - "Zipper", - "ZipperLongest", -] - -# Please keep this list sorted -assert __all__ == sorted(__all__) diff --git a/torchdata/datapipes/iter/__init__.pyi.in b/torchdata/datapipes/iter/__init__.pyi.in deleted file mode 100644 index 773ba85b1..000000000 --- a/torchdata/datapipes/iter/__init__.pyi.in +++ /dev/null @@ -1,46 +0,0 @@ -${init_base} - -######################################################################################################################## -# The part below is generated by parsing through the Python files where IterDataPipes are defined. -# This base template ("__init__.pyi.in") is generated from mypy stubgen with minimal editing for code injection -# The output file will be "__init__.pyi". The generation function is called by "setup.py". -# Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other -# classes/objects here, even though we are not injecting extra code into them at the moment. - -from .util.decompressor import CompressionType -from torchdata._constants import default_timeout_in_s -from torchdata.datapipes.map import MapDataPipe -from torchdata.datapipes.utils import pin_memory_fn -from torch.utils.data import DataChunk, IterableDataset, default_collate -from torch.utils.data.datapipes._typing import _DataPipeMeta -from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES - -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, TypeVar, Union, Hashable - -try: - import torcharrow -except ImportError: - torcharrow = None - -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) -ForkIterDataPipeCopyOptions = Literal["shallow", "deep"] - -class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta): - functions: Dict[str, Callable] = ... - reduce_ex_hook: Optional[Callable] = ... - getstate_hook: Optional[Callable] = ... - def __getattr__(self, attribute_name: Any): ... - @classmethod - def register_function(cls, function_name: Any, function: Any) -> None: ... - @classmethod - def register_datapipe_as_function( - cls, function_name: Any, cls_to_register: Any, enable_df_api_tracing: bool = ... - ): ... - def __getstate__(self): ... - def __reduce_ex__(self, *args: Any, **kwargs: Any): ... - @classmethod - def set_getstate_hook(cls, hook_fn: Any) -> None: ... - @classmethod - def set_reduce_ex_hook(cls, hook_fn: Any) -> None: ... 
- ${IterDataPipeMethods} diff --git a/torchdata/datapipes/iter/load/README.md b/torchdata/datapipes/iter/load/README.md deleted file mode 100644 index 5fd020924..000000000 --- a/torchdata/datapipes/iter/load/README.md +++ /dev/null @@ -1,88 +0,0 @@ -# Iterable Datapipes - -## S3 IO Datapipe Documentation - -**WARNING**: S3 IO Datapipes have been deprecated. Consider using -[S3 Connector for PyTorch](https://github.com/awslabs/s3-connector-for-pytorch). - -### Build from Source - -`ninja` is required to link PyThon implementation to C++ source code. - -```bash -conda install ninja -``` - -S3 IO datapipes are included when building with flag `BUILD_S3=1`. The following commands can build `torchdata` from -source with S3 datapipes. - -```bash -BUILD_S3=1 pip install . -``` - -We also offer nightly and official (>=0.4.0) TorchData releases integrated with `AWSSDK` on the most of platforms. -Please check the [link](https://github.com/pytorch/data/tree/main/packaging#awssdk) for the list of supported platforms -with the pre-assembled binaries. - -If you'd like to use customized installations of `pybind11` or `aws-sdk-cpp`, you may set the following flags when -building from source. - -``` -USE_SYSTEM_PYBIND11=1 -USE_SYSTEM_AWS_SDK_CPP=1 -USE_SYSTEM_LIBS=1 # uses both pre-installed pybind11 and aws-sdk-cpp -``` - -Note: refer to the official documentation for detailed installtion instructions of -[aws-sdk-cpp](https://github.com/aws/aws-sdk-cpp). - -### Example - -Please refer to the documentation: - -- [`S3FileLister`](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.S3FileLister.html#s3filelister) -- [`S3FileLoader`](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.S3FileLoader.html#s3fileloader) - -### Note - -Your environment must be properly configured for AWS to use the DataPipes. It is possible to do that via the AWS Command -Line Interface (`aws configure`). - -It's recommended to set up a detailed configuration file with the `AWS_CONFIG_FILE` environment variable. The following -environment variables are also parsed: `HOME`, `S3_USE_HTTPS`, `S3_VERIFY_SSL`, `S3_ENDPOINT_URL`, `AWS_REGION` (would -be overwritten by the `region` variable). - -### Troubleshooting - -If you get `Access Denied` or no response, it's very possibly a -[wrong region configuration](https://github.com/aws/aws-sdk-cpp/issues/1211) or an -[accessing issue with `aws-sdk-cpp`](https://aws.amazon.com/premiumsupport/knowledge-center/s3-access-denied-aws-sdk/). - -## AIStore IO Datapipe - -[AIStore](https://github.com/NVIDIA/aistore) (AIS for short) is a highly available lightweight object storage system -that specifically focuses on petascale deep learning. As a reliable redundant storage, AIS supports n-way mirroring and -erasure coding. But it is not purely – or not only – a storage system: it’ll shuffle user datasets and run custom -extract-transform-load workloads. - -AIS is an elastic cluster that can grow and shrink at runtime and can be ad-hoc deployed, with or without Kubernetes, -anywhere from a single Linux machine to a bare-metal cluster of any size. - -AIS fully supports Amazon S3, Google Cloud, and Microsoft Azure backends, providing a unified namespace across multiple -connected backends and/or other AIS clusters, and [more](https://github.com/NVIDIA/aistore#features). 
Getting started -with AIS will take only a few minutes (prerequisites boil down to having a Linux with a disk) and can be done either by -running a prebuilt all-in-one docker image or directly from the open-source. - -### Dependency - -The `AISFileLister` and `AISFileLoader` under [`aisio.py`](/torchdata/datapipes/iter/load/aisio.py) internally use the -[Python SDK](https://github.com/NVIDIA/aistore/tree/master/sdk/python) for AIStore. - -Run `pip install aistore` or `conda install aistore` to install the [python package](https://pypi.org/project/aistore/). - -### Example - -Please refer to the documentation: - -- [`AISFileLister`](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.AISFileLister.html#aisfilelister) -- [`AISFileLoader`](https://pytorch.org/data/main/generated/torchdata.datapipes.iter.AISFileLoader.html#aisfileloader) diff --git a/torchdata/datapipes/iter/load/__init__.py b/torchdata/datapipes/iter/load/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/torchdata/datapipes/iter/load/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torchdata/datapipes/iter/load/aisio.py b/torchdata/datapipes/iter/load/aisio.py deleted file mode 100644 index c12af5539..000000000 --- a/torchdata/datapipes/iter/load/aisio.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Iterator, Tuple - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.utils import StreamWrapper - -try: - from aistore.client import Client - from aistore.pytorch.utils import parse_url, unparse_url - - HAS_AIS = True - -except ImportError: - HAS_AIS = False - -try: - import aistore - from packaging.version import parse as parse_version - - AIS_VERSION_CHECK = parse_version(aistore.__version__) >= parse_version("1.0.2") - -except (ImportError, AttributeError): - AIS_VERSION_CHECK = False - - -def _assert_aistore() -> None: - if not HAS_AIS: - raise ModuleNotFoundError( - "Package `aistore` (>=1.0.2) is required to be installed to use this datapipe." - "Please run `pip install --upgrade aistore` or `conda install aistore` to install the package" - "For more info visit: https://github.com/NVIDIA/aistore/blob/master/sdk/python/" - ) - - -def _assert_aistore_version() -> None: - if not AIS_VERSION_CHECK: - raise ImportError( - "AIStore version >= 1.0.2 required" - "Please run `pip install --upgrade aistore` or `conda update aistore` to install the latest version" - ) - - -@functional_datapipe("list_files_by_ais") -class AISFileListerIterDataPipe(IterDataPipe[str]): - """ - Iterable Datapipe that lists files from the AIStore backends with the given URL prefixes - (functional name: ``list_files_by_ais``). - Acceptable prefixes include but not limited to - `ais://bucket-name`, `ais://bucket-name/` - - Note: - - This function also supports files from multiple backends (`aws://..`, `gcp://..`, `azure://..`, etc) - - Input must be a list and direct URLs are not supported. - - length is -1 by default, all calls to len() are invalid as - not all items are iterated at the start. 
- - This internally uses AIStore Python SDK. - - Args: - source_datapipe(IterDataPipe[str]): a DataPipe that contains URLs/URL - prefixes to objects on AIS - url(str): AIStore endpoint - length(int): length of the datapipe - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister - >>> ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws:bucket-name/folder/', 'ais://bucket-name/folder/', ...]) - >>> dp_ais_urls = AISFileLister(url='localhost:8080', source_datapipe=ais_prefixes) - >>> for url in dp_ais_urls: - ... pass - >>> # Functional API - >>> dp_ais_urls = ais_prefixes.list_files_by_ais(url='localhost:8080') - >>> for url in dp_ais_urls: - ... pass - """ - - def __init__(self, source_datapipe: IterDataPipe[str], url: str, length: int = -1) -> None: - _assert_aistore() - _assert_aistore_version() - self.source_datapipe: IterDataPipe[str] = source_datapipe - self.length: int = length - self.client = Client(url) - - def __iter__(self) -> Iterator[str]: - for prefix in self.source_datapipe: - provider, bck_name, prefix = parse_url(prefix) - obj_iter = self.client.bucket(bck_name, provider).list_objects_iter(prefix=prefix) - for entry in obj_iter: - yield unparse_url(provider=provider, bck_name=bck_name, obj_name=entry.name) - - def __len__(self) -> int: - if self.length == -1: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - return self.length - - -@functional_datapipe("load_files_by_ais") -class AISFileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - """ - Iterable DataPipe that loads files from AIStore with the given URLs (functional name: ``load_files_by_ais``). - Iterates all files in BytesIO format and returns a tuple (url, BytesIO). - - Note: - - This function also supports files from multiple backends (`aws://..`, `gcp://..`, `azure://..`, etc) - - Input must be a list and direct URLs are not supported. - - This internally uses AIStore Python SDK. - - Args: - source_datapipe(IterDataPipe[str]): a DataPipe that contains URLs/URL prefixes to objects - url(str): AIStore endpoint - length(int): length of the datapipe - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister,AISFileLoader - >>> ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws:bucket-name/folder/', 'ais://bucket-name/folder/', ...]) - >>> dp_ais_urls = AISFileLister(url='localhost:8080', source_datapipe=ais_prefixes) - >>> dp_cloud_files = AISFileLoader(url='localhost:8080', source_datapipe=dp_ais_urls) - >>> for url, file in dp_cloud_files: - ... pass - >>> # Functional API - >>> dp_cloud_files = dp_ais_urls.load_files_by_ais(url='localhost:8080') - >>> for url, file in dp_cloud_files: - ... 
pass - """ - - def __init__(self, source_datapipe: IterDataPipe[str], url: str, length: int = -1) -> None: - _assert_aistore() - _assert_aistore_version() - self.source_datapipe: IterDataPipe[str] = source_datapipe - self.length = length - self.client = Client(url) - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for url in self.source_datapipe: - provider, bck_name, obj_name = parse_url(url) - yield url, StreamWrapper( - self.client.bucket(bck_name=bck_name, provider=provider).object(obj_name=obj_name).get().raw() - ) - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/load/fsspec.py b/torchdata/datapipes/iter/load/fsspec.py deleted file mode 100644 index 39a875a94..000000000 --- a/torchdata/datapipes/iter/load/fsspec.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import posixpath - -from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple, Union - -from torch.utils.data.datapipes.utils.common import match_masks - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterableWrapper, IterDataPipe -from torchdata.datapipes.utils import StreamWrapper - -try: - import fsspec - -except ImportError: - fsspec = None - -U = Union[bytes, bytearray, str] - - -def _assert_fsspec() -> None: - if fsspec is None: - raise ModuleNotFoundError( - "Package `fsspec` is required to be installed to use this datapipe." - "Please use `pip install fsspec` or `conda install -c conda-forge fsspec`" - "to install the package" - ) - - -@functional_datapipe("list_files_by_fsspec") -class FSSpecFileListerIterDataPipe(IterDataPipe[str]): - r""" - Lists the contents of the directory at the provided ``root`` pathname or URL, - and yields the full pathname or URL for each file within the - directory (functional name: ``list_files_by_fsspec``). - - Args: - root: The root `fsspec` path directory or list of path directories to list files from - masks: Unix style filter string or string list for filtering file name(s) - kwargs: Extra options that make sense to a particular storage connection, - e.g. host, port, username, password, etc. - - Example: - - .. testsetup:: - - dir_path = "path" - - .. 
testcode:: - - from torchdata.datapipes.iter import FSSpecFileLister - - datapipe = FSSpecFileLister(root=dir_path) - """ - - def __init__( - self, - root: Union[str, Sequence[str], IterDataPipe], - masks: Union[str, List[str]] = "", - **kwargs, - ) -> None: - _assert_fsspec() - - if isinstance(root, str): - root = [ - root, - ] - if not isinstance(root, IterDataPipe): - self.datapipe: IterDataPipe = IterableWrapper(root) # type: ignore[assignment] - else: - self.datapipe = root - self.masks = masks - self.kwargs_for_connection = kwargs - - def __iter__(self) -> Iterator[str]: - for root in self.datapipe: - fs, path = fsspec.core.url_to_fs(root, **self.kwargs_for_connection) - - if isinstance(fs.protocol, str): - protocol_list = [fs.protocol] - else: - protocol_list = fs.protocol - - # fspec.core.url_to_fs will return "abfs" for both, "az://" and "abfs://" urls - if "abfs" in protocol_list: - protocol_list.append("az") - - is_local = fs.protocol == "file" or not any(root.startswith(protocol) for protocol in protocol_list) - if fs.isfile(path): - yield root - else: - for file_name in fs.ls(path, detail=False): # Ensure it returns List[str], not List[Dict] - if not match_masks(file_name, self.masks): - continue - - # ensure the file name has the full fsspec protocol path - if any(file_name.startswith(protocol) for protocol in protocol_list): - yield file_name - else: - if is_local: - abs_path = os.path.join(path, file_name) - elif not file_name.startswith(path): - abs_path = posixpath.join(path, file_name) - else: - abs_path = file_name - - starts_with = False - for protocol in protocol_list: - if root.startswith(protocol): - starts_with = True - yield protocol + "://" + abs_path - break - - if not starts_with: - yield abs_path - - -@functional_datapipe("open_files_by_fsspec") -class FSSpecFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r""" - Opens files from input datapipe which contains `fsspec` paths and yields a tuple of - pathname and opened file stream (functional name: ``open_files_by_fsspec``). - - Args: - source_datapipe: Iterable DataPipe that provides the pathnames or URLs - mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default) - kwargs_for_open: Optional Dict to specify kwargs for opening files (``fs.open()``) - kwargs: Extra options that are used to establish a particular storage connection, - e.g. host, port, username, password, etc. - - Example: - - .. testsetup:: - - dir_path = "path" - - .. 
testcode:: - - from torchdata.datapipes.iter import FSSpecFileLister - - datapipe = FSSpecFileLister(root=dir_path) - file_dp = datapipe.open_files_by_fsspec() - """ - - def __init__( - self, source_datapipe: IterDataPipe[str], mode: str = "r", *, kwargs_for_open: Optional[Dict] = None, **kwargs - ) -> None: - _assert_fsspec() - - self.source_datapipe: IterDataPipe[str] = source_datapipe - self.mode: str = mode - self.kwargs_for_open = kwargs_for_open if kwargs_for_open is not None else {} - self.kwargs_for_connection = kwargs - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for file_uri in self.source_datapipe: - fs, path = fsspec.core.url_to_fs(file_uri, **self.kwargs_for_connection) - file = fs.open(path, self.mode, **self.kwargs_for_open) - yield file_uri, StreamWrapper(file) - - def __len__(self) -> int: - return len(self.source_datapipe) - - -@functional_datapipe("save_by_fsspec") -class FSSpecSaverIterDataPipe(IterDataPipe[str]): - r""" - Takes in a DataPipe of tuples of metadata and data, saves the data to the target - path (generated by the filepath_fn and metadata), and yields the resulting `fsspec` - path (functional name: ``save_by_fsspec``). - - Args: - source_datapipe: Iterable DataPipe with tuples of metadata and data - mode: Mode in which the file will be opened for write the data (``"w"`` by default) - filepath_fn: Function that takes in metadata and returns the target path of the new file - kwargs_for_open: Optional Dict to specify kwargs for opening files (``fs.open()``) - kwargs: Extra options that are used to establish a particular storage connection, - e.g. host, port, username, password, etc. - - - Example: - - .. testsetup:: - - file_prefix = "file" - - .. testcode:: - - from torchdata.datapipes.iter import IterableWrapper - - - def filepath_fn(name: str) -> str: - return file_prefix + name - - - name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"} - source_dp = IterableWrapper(sorted(name_to_data.items())) - fsspec_saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb") - res_file_paths = list(fsspec_saver_dp) - - .. testcleanup:: - - import os - - for name in name_to_data.keys(): - os.remove(file_prefix + name) - """ - - def __init__( - self, - source_datapipe: IterDataPipe[Tuple[Any, U]], - mode: str = "w", - filepath_fn: Optional[Callable] = None, - *, - kwargs_for_open: Optional[Dict] = None, - **kwargs, - ): - _assert_fsspec() - - self.source_datapipe: IterDataPipe[Tuple[Any, U]] = source_datapipe - self.mode: str = mode - self.filepath_fn: Optional[Callable] = filepath_fn - self.kwargs_for_open = kwargs_for_open if kwargs_for_open is not None else {} - self.kwargs_for_connection = kwargs - - def __iter__(self) -> Iterator[str]: - for meta, data in self.source_datapipe: - filepath = meta if self.filepath_fn is None else self.filepath_fn(meta) - fs, path = fsspec.core.url_to_fs(filepath, **self.kwargs_for_connection) - with fs.open(path, self.mode, **self.kwargs_for_open) as f: - f.write(data) - yield filepath - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/load/huggingface.py b/torchdata/datapipes/iter/load/huggingface.py deleted file mode 100644 index 783fa9fb6..000000000 --- a/torchdata/datapipes/iter/load/huggingface.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Any, Iterator, Tuple - -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.utils import StreamWrapper - -try: - import datasets -except ImportError: - datasets = None - - -def _get_response_from_huggingface_hub( - dataset: str, - streaming: bool = True, - **config_kwargs, -) -> Iterator[Any]: - hf_dataset = datasets.load_dataset(path=dataset, streaming=streaming, **config_kwargs) - return iter(hf_dataset) - - -class HuggingFaceHubReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r""" - Takes in dataset names and returns an Iterable HuggingFace dataset. - Please refer to https://huggingface.co/docs/datasets/loading for the meaning and type of each argument. - Contrary to their implementation, default behavior differs in the following: - - * ``streaming`` is set to ``True`` - - Args: - dataset: path or name of the dataset - **config_kwargs: additional arguments for ``datasets.load_dataset()`` - - Example: - - .. testsetup:: - - import datasets - from torchdata.datapipes.iter import IterableWrapper, HuggingFaceHubReader - from unittest.mock import MagicMock - - datasets.load_dataset = MagicMock(return_value=datasets.Dataset.from_dict( - {"id": ["7bd227d9-afc9-11e6-aba1-c4b301cdf627", "7bd22905-afc9-11e6-a5dc-c4b301cdf627" ], "package_name": ["com.mantz_it.rfanalyzer"] * 2} - )) - - .. testcode:: - - huggingface_reader_dp = HuggingFaceHubReader("lhoestq/demo1", revision="main") - elem = next(iter(huggingface_reader_dp)) - assert elem["package_name"] == "com.mantz_it.rfanalyzer" - - """ - - def __init__( - self, - dataset: str, - **config_kwargs, - ) -> None: - if datasets is None: - raise ModuleNotFoundError( - "Package `datasets` is required to be installed to use this datapipe." - "Please use `pip install datasets` or `conda install -c conda-forge datasets`" - "to install the package" - ) - - self.dataset = dataset - self.config_kwargs = config_kwargs - - def __iter__(self) -> Iterator[Any]: - return _get_response_from_huggingface_hub(dataset=self.dataset, **self.config_kwargs) - - def __len__(self) -> int: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/torchdata/datapipes/iter/load/iopath.py b/torchdata/datapipes/iter/load/iopath.py deleted file mode 100644 index cb334ebe4..000000000 --- a/torchdata/datapipes/iter/load/iopath.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import os - -from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union - -from torch.utils.data.datapipes.utils.common import match_masks - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterableWrapper, IterDataPipe -from torchdata.datapipes.utils import StreamWrapper - -try: - import iopath - -except ImportError: - iopath = None - -U = Union[bytes, bytearray, str] - - -def _create_default_pathmanager(): - from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathManager - - pathmgr = PathManager() - pathmgr.register_handler(HTTPURLHandler(), allow_override=True) - pathmgr.register_handler(OneDrivePathHandler(), allow_override=True) - # S3PathHandler is not included in 0.1.8 - try: - from iopath.common.s3 import S3PathHandler - - pathmgr.register_handler(S3PathHandler(), allow_override=True) - except ImportError: - pass - return pathmgr - - -@functional_datapipe("list_files_by_iopath") -class IoPathFileListerIterDataPipe(IterDataPipe[str]): - r""" - Lists the contents of the directory at the provided ``root`` pathname or URL, - and yields the full pathname or URL for each file within the directory (functional name: ``list_files_by_iopath``). - - Args: - root: The root local filepath or URL directory or list of roots to list files from - masks: Unix style filter string or string list for filtering file name(s) - pathmgr: Custom ``iopath.PathManager``. If not specified, a default ``PathManager`` is created. - - Note: - Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL. - S3 URL is supported only with ``iopath``>=0.1.9. - - Example: - - .. testsetup:: - - s3_url = "path" - - .. testcode:: - - from torchdata.datapipes.iter import IoPathFileLister - - datapipe = IoPathFileLister(root=s3_url) - """ - - def __init__( - self, - root: Union[str, Sequence[str], IterDataPipe], - masks: Union[str, List[str]] = "", - *, - pathmgr=None, - handler=None, - ) -> None: - if iopath is None: - raise ModuleNotFoundError( - "Package `iopath` is required to be installed to use this datapipe." - "Please use `pip install iopath` or `conda install -c conda-forge iopath`" - "to install the package" - ) - - if isinstance(root, str): - root = [ - root, - ] - if not isinstance(root, IterDataPipe): - self.datapipe: IterDataPipe = IterableWrapper(root) # type: ignore[assignment] - else: - self.datapipe = root - self.pathmgr = _create_default_pathmanager() if pathmgr is None else pathmgr - self.masks = masks - if handler is not None: - self.register_handler(handler, allow_override=True) - - def register_handler(self, handler, allow_override=False): - self.pathmgr.register_handler(handler, allow_override=allow_override) - - def __iter__(self) -> Iterator[str]: - for path in self.datapipe: - if self.pathmgr.isfile(path): - yield path - else: - for file_name in self.pathmgr.ls(path): - if match_masks(file_name, self.masks): - yield os.path.join(path, file_name) - - -@functional_datapipe("open_files_by_iopath") -class IoPathFileOpenerIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r""" - Opens files from input datapipe which contains pathnames or URLs, - and yields a tuple of pathname and opened file stream (functional name: ``open_files_by_iopath``). - - Args: - source_datapipe: Iterable DataPipe that provides the pathnames or URLs - mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default) - pathmgr: Custom ``iopath.PathManager``. 
If not specified, a default ``PathManager`` is created. - - Note: - Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL. - S3 URL is supported only with `iopath`>=0.1.9. - - Example: - - .. testsetup:: - - s3_url = "path" - - .. testcode:: - - from torchdata.datapipes.iter import IoPathFileLister - - datapipe = IoPathFileLister(root=s3_url) - file_dp = datapipe.open_files_by_iopath() - """ - - def __init__(self, source_datapipe: IterDataPipe[str], mode: str = "r", pathmgr=None, handler=None) -> None: - if iopath is None: - raise ModuleNotFoundError( - "Package `iopath` is required to be installed to use this datapipe." - "Please use `pip install iopath` or `conda install -c conda-forge iopath`" - "to install the package" - ) - - self.source_datapipe: IterDataPipe[str] = source_datapipe - self.pathmgr = _create_default_pathmanager() if pathmgr is None else pathmgr - self.mode: str = mode - if handler is not None: - self.register_handler(handler, allow_override=True) - - def register_handler(self, handler, allow_override=False): - self.pathmgr.register_handler(handler, allow_override=allow_override) - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for file_uri in self.source_datapipe: - file = self.pathmgr.open(file_uri, self.mode) - yield file_uri, StreamWrapper(file) - - def __len__(self) -> int: - return len(self.source_datapipe) - - -@functional_datapipe("save_by_iopath") -class IoPathSaverIterDataPipe(IterDataPipe[str]): - - r""" - Takes in a DataPipe of tuples of metadata and data, saves the data - to the target path which is generated by the ``filepath_fn`` and metadata, and yields the resulting path - in `iopath` format (functional name: ``save_by_iopath``). - - Args: - source_datapipe: Iterable DataPipe with tuples of metadata and data - mode: Mode in which the file will be opened for write the data (``"w"`` by default) - filepath_fn: Function that takes in metadata and returns the target path of the new file - pathmgr: Custom ``iopath.PathManager``. If not specified, a default ``PathManager`` is created. - - Note: - Default ``PathManager`` currently supports local file path, normal HTTP URL and OneDrive URL. - S3 URL is supported only with `iopath`>=0.1.9. - - Example: - - .. testsetup:: - - s3_url = "url" - - .. testcode:: - - from torchdata.datapipes.iter import IterableWrapper - - - def filepath_fn(name: str) -> str: - return s3_url + name - - - name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"} - source_dp = IterableWrapper(sorted(name_to_data.items())) - iopath_saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb") - res_file_paths = list(iopath_saver_dp) - - .. testcleanup:: - - import os - - for file in ["1.txt", "1.txt.lock", "2.txt", "2.txt.lock", "3.txt", "3.txt.lock"]: - os.remove(s3_url + file) - """ - - def __init__( - self, - source_datapipe: IterDataPipe[Tuple[Any, U]], - mode: str = "w", - filepath_fn: Optional[Callable] = None, - *, - pathmgr=None, - handler=None, - ): - if iopath is None: - raise ModuleNotFoundError( - "Package `iopath` is required to be installed to use this datapipe." 
- "Please use `pip install iopath` or `conda install -c conda-forge iopath`" - "to install the package" - ) - - self.source_datapipe: IterDataPipe[Tuple[Any, U]] = source_datapipe - self.mode: str = mode - self.filepath_fn: Optional[Callable] = filepath_fn - self.pathmgr = _create_default_pathmanager() if pathmgr is None else pathmgr - if handler is not None: - self.register_handler(handler, allow_override=True) - - def __iter__(self) -> Iterator[str]: - for meta, data in self.source_datapipe: - filepath = meta if self.filepath_fn is None else self.filepath_fn(meta) - with iopath.file_lock(filepath): - if not os.path.exists(filepath): - with self.pathmgr.open(filepath, self.mode) as f: - f.write(data) - yield filepath - - def register_handler(self, handler, allow_override=False): - self.pathmgr.register_handler(handler, allow_override=allow_override) - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/load/online.py b/torchdata/datapipes/iter/load/online.py deleted file mode 100644 index 7b9d8d9dc..000000000 --- a/torchdata/datapipes/iter/load/online.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import re -import urllib -import warnings -from typing import Any, Dict, Iterator, Optional, Tuple - -import requests - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.utils import StreamWrapper - - -# TODO(642): Remove this helper function when https://bugs.python.org/issue42627 is resolved -def _get_proxies() -> Optional[Dict[str, str]]: - import os - - if os.name == "nt": - proxies = urllib.request.getproxies() - address = proxies.get("https") - # The default proxy type of Windows is HTTP - if address and address.startswith("https"): - address = "http" + address[5:] - proxies["https"] = address - return proxies - return None - - -def _get_response_from_http( - url: str, *, timeout: Optional[float], **query_params: Optional[Dict[str, Any]] -) -> Tuple[str, StreamWrapper]: - with requests.Session() as session: - proxies = _get_proxies() - r = session.get(url, timeout=timeout, proxies=proxies, stream=True, **query_params) # type: ignore[arg-type] - r.raise_for_status() - return url, StreamWrapper(r.raw) - - -@functional_datapipe("read_from_http") -class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r""" - Takes file URLs (HTTP URLs pointing to files), and yields tuples of file URL and - IO stream (functional name: ``read_from_http``). - - Args: - source_datapipe: a DataPipe that contains URLs - timeout: timeout in seconds for HTTP request - skip_on_error: whether to skip over urls causing problems, otherwise an exception is raised - **kwargs: a Dictionary to pass optional arguments that requests takes. For the full list check out https://docs.python-requests.org/en/master/api/ - - Example: - - .. 
testcode:: - - from torchdata.datapipes.iter import IterableWrapper, HttpReader - - file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE" - query_params = {"auth" : ("fake_username", "fake_password"), "allow_redirects" : True} - timeout = 120 - http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params) - reader_dp = http_reader_dp.readlines() - it = iter(reader_dp) - path, line = next(it) - print((path, line)) - - Output: - - .. testoutput:: - - ('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License') - """ - - def __init__( - self, - source_datapipe: IterDataPipe[str], - timeout: Optional[float] = None, - skip_on_error: bool = False, - **kwargs: Optional[Dict[str, Any]], - ) -> None: - self.source_datapipe: IterDataPipe[str] = source_datapipe - self.timeout = timeout - self.skip_on_error = skip_on_error - self.query_params = kwargs - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for url in self.source_datapipe: - try: - yield _get_response_from_http(url, timeout=self.timeout, **self.query_params) - except Exception as e: - if self.skip_on_error: - warnings.warn(f"{e}, skipping...") - else: - raise - - def __len__(self) -> int: - return len(self.source_datapipe) - - -def _extract_gdrive_api_response(content: str) -> Optional[str]: - match = re.search("Google Drive - (?P<api_response>.+?)", content) - return match["api_response"] if match is not None else None - - -def _get_response_from_google_drive( - url: str, *, timeout: Optional[float], **query_params: Optional[Dict[str, Any]] -) -> Tuple[str, StreamWrapper]: - confirm_token = None - - with requests.Session() as session: - response = session.get(url, timeout=timeout, stream=True, **query_params) # type: ignore[arg-type] - response.raise_for_status() - - for k, v in response.cookies.items(): - if k.startswith("download_warning"): - confirm_token = v - break - else: - api_response = _extract_gdrive_api_response(response.text) - if api_response == "Virus scan warning": - confirm_token = "t" - elif api_response == "Quota exceeded": - raise RuntimeError(f"Google drive link {url} is currently unavailable, because the quota was exceeded.") - - if confirm_token: - url = url + "&confirm=" + confirm_token - - response = session.get(url, timeout=timeout, stream=True, **query_params) # type: ignore[arg-type] - response.raise_for_status() - - if "content-disposition" not in response.headers: - raise RuntimeError( - f"Google drive link {url} internal error: " - "headers don't contain content-disposition. This is usually caused by " - "using a sharing/viewing link instead of a download link. Click 'Download' on the " - "Google Drive page, which should redirect you to a download page, and use the link " - "of that page." - ) - - filename = re.findall('filename="(.+)"', response.headers["content-disposition"]) - if filename is None: - raise RuntimeError(f"Google drive link {url}: filename could not be autodetected") - - return filename[0], StreamWrapper(response.raw) - - -@functional_datapipe("read_from_gdrive") -class GDriveReaderDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r""" - Takes URLs pointing at GDrive files, and yields tuples of file name and - IO stream (functional name: ``read_from_gdrive``). 
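# A minimal sketch of the streamed-``requests`` pattern behind the removed
# ``read_from_http`` / ``_get_response_from_http``: a GET with ``stream=True``
# whose body is consumed incrementally instead of being buffered up front.
# Assumes ``requests`` is installed and that the URL (taken from the docstring
# example above) is reachable.
import requests

url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
with requests.Session() as session:
    r = session.get(url, timeout=120, stream=True)
    r.raise_for_status()                        # surface HTTP errors before reading
    print(next(r.iter_lines()))                 # b'BSD 3-Clause License'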
- - Args: - source_datapipe: a DataPipe that contains URLs to GDrive files - timeout: timeout in seconds for HTTP request - skip_on_error: whether to skip over urls causing problems, otherwise an exception is raised - **kwargs: a Dictionary to pass optional arguments that requests takes. For the full list check out https://docs.python-requests.org/en/master/api/ - - Example: - - .. testsetup:: - - from torchdata.datapipes.iter import GDriveReader - - GDriveReader.readlines = lambda self: [ - ("https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile", b"") - ] - - .. testcode:: - - from torchdata.datapipes.iter import IterableWrapper, GDriveReader - - gdrive_file_url = "https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile" - gdrive_reader_dp = GDriveReader(IterableWrapper([gdrive_file_url])) - reader_dp = gdrive_reader_dp.readlines() - it = iter(reader_dp) - path, line = next(it) - print((path, line)) - - Output: - - .. testoutput:: - - ('https://drive.google.com/uc?export=download&id=SomeIDToAGDriveFile', b'') - """ - source_datapipe: IterDataPipe[str] - - def __init__( - self, - source_datapipe: IterDataPipe[str], - *, - timeout: Optional[float] = None, - skip_on_error: bool = False, - **kwargs: Optional[Dict[str, Any]], - ) -> None: - self.source_datapipe = source_datapipe - self.timeout = timeout - self.skip_on_error = skip_on_error - self.query_params = kwargs - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for url in self.source_datapipe: - try: - yield _get_response_from_google_drive(url, timeout=self.timeout, **self.query_params) - except Exception as e: - if self.skip_on_error: - warnings.warn(f"{e}, skipping...") - else: - raise - - def __len__(self) -> int: - return len(self.source_datapipe) - - -@functional_datapipe("read_from_remote") -class OnlineReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r""" - Takes file URLs (can be HTTP URLs pointing to files or URLs to GDrive files), and - yields tuples of file URL and IO stream (functional name: ``read_from_remote``). - - Args: - source_datapipe: a DataPipe that contains URLs - timeout: timeout in seconds for HTTP request - skip_on_error: whether to skip over urls causing problems, otherwise an exception is raised - **kwargs: a Dictionary to pass optional arguments that requests takes. For the full list check out https://docs.python-requests.org/en/master/api/ - - Example: - - .. testcode:: - - from torchdata.datapipes.iter import IterableWrapper, OnlineReader - - file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE" - online_reader_dp = OnlineReader(IterableWrapper([file_url])) - reader_dp = online_reader_dp.readlines() - it = iter(reader_dp) - path, line = next(it) - print((path, line)) - - Output: - - .. 
testoutput:: - - ('https://raw.githubusercontent.com/pytorch/data/main/LICENSE', b'BSD 3-Clause License') - """ - source_datapipe: IterDataPipe[str] - - def __init__( - self, - source_datapipe: IterDataPipe[str], - *, - timeout: Optional[float] = None, - skip_on_error: bool = False, - **kwargs: Optional[Dict[str, Any]], - ) -> None: - self.source_datapipe = source_datapipe - self.timeout = timeout - self.skip_on_error = skip_on_error - self.query_params = kwargs - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for url in self.source_datapipe: - parts = urllib.parse.urlparse(url) - - if re.match(r"(drive|docs)[.]google[.]com", parts.netloc): - try: - yield _get_response_from_google_drive(url, timeout=self.timeout, **self.query_params) - except Exception as e: - if self.skip_on_error: - warnings.warn(f"{e}, skipping...") - else: - raise - else: - try: - yield _get_response_from_http(url, timeout=self.timeout, **self.query_params) - except Exception as e: - if self.skip_on_error: - warnings.warn(f"{e}, skipping...") - else: - raise - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/load/s3io.py b/torchdata/datapipes/iter/load/s3io.py deleted file mode 100644 index 82214920f..000000000 --- a/torchdata/datapipes/iter/load/s3io.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from io import BytesIO -from typing import Iterator, List, Tuple, Union - -import torchdata - -from torch.utils.data.datapipes.utils.common import match_masks -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.utils import StreamWrapper - - -@functional_datapipe("list_files_by_s3") -class S3FileListerIterDataPipe(IterDataPipe[str]): - r"""[DEPRECATED] Use https://github.com/awslabs/s3-connector-for-pytorch instead. - - Iterable DataPipe that lists Amazon S3 file URLs with the given prefixes (functional name: ``list_files_by_s3``). - Acceptable prefixes include ``s3://bucket-name``, ``s3://bucket-name/``, ``s3://bucket-name/folder``. - - Note: - 1. ``source_datapipe`` **must** contain a list of valid S3 URLs - 2. ``length`` is `-1` by default, and any call to ``__len__()`` is invalid, because the length is unknown - until all files are iterated. - 3. ``request_timeout_ms`` and ``region`` will overwrite settings in the configuration file or - environment variables. - 4. The lack of AWS proper configuration can lead empty response. For more details related to S3 IO DataPipe - setup and AWS config, please see the `README file`_. - - .. _README file: - https://github.com/pytorch/data/tree/main/torchdata/datapipes/iter/load#s3-io-datapipe-documentation - - Args: - source_datapipe: a DataPipe that contains URLs/URL prefixes to s3 files - length: Nominal length of the datapipe - request_timeout_ms: timeout setting for each reqeust (3,000ms by default) - region: region for access files (inferred from credentials by default) - - Example: - - .. testsetup:: - - from unittest import mock - from torchdata.datapipes.iter import IterableWrapper, S3FileLister - - file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([])) - file_lister_patch.start() - - .. 
testcode:: - - from torchdata.datapipes.iter import IterableWrapper, S3FileLister - - s3_prefixes = IterableWrapper(['s3://bucket-name/folder/', ...]) - - dp_s3_urls = S3FileLister(s3_prefixes) - for d in dp_s3_urls: - pass - - # Functional API - dp_s3_urls = s3_prefixes.list_files_by_s3(request_timeout_ms=100) - for d in dp_s3_urls: - pass - - .. testcleanup:: - - file_lister_patch.stop() - """ - - def __init__( - self, - source_datapipe: IterDataPipe[str], - length: int = -1, - request_timeout_ms=-1, - region="", - masks: Union[str, List[str]] = "", - ) -> None: - if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"): - raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.") - - self.source_datapipe: IterDataPipe[str] = source_datapipe - self.length: int = length - self.handler = torchdata._torchdata.S3Handler(request_timeout_ms, region) - self.masks = masks - - def __iter__(self) -> Iterator[str]: - for prefix in self.source_datapipe: - while True: - urls = self.handler.list_files(prefix) - for url in urls: - if match_masks(url, self.masks): - yield url - if not urls: - break - self.handler.clear_marker() - - def __len__(self) -> int: - if self.length == -1: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - return self.length - - -@functional_datapipe("load_files_by_s3") -class S3FileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r"""[DEPRECATED] Use https://github.com/awslabs/s3-connector-for-pytorch instead. - - Iterable DataPipe that loads Amazon S3 files from the given S3 URLs (functional name: ``load_files_by_s3``). - ``S3FileLoader`` iterates all given S3 URLs in ``BytesIO`` format with ``(url, BytesIO)`` tuples. - - Note: - 1. ``source_datapipe`` **must** contain a list of valid S3 URLs. - 2. ``request_timeout_ms`` and ``region`` will overwrite settings in the - configuration file or environment variables. - 3. The lack of AWS proper configuration can lead empty response. For more details related to S3 IO DataPipe - setup and AWS config, please see the `README file`_. - - .. _README file: - https://github.com/pytorch/data/tree/main/torchdata/datapipes/iter/load#s3-io-datapipe-documentation - - Args: - source_datapipe: a DataPipe that contains URLs to s3 files - request_timeout_ms: timeout setting for each reqeust (3,000ms by default) - region: region for access files (inferred from credentials by default) - buffer_size: buffer size of each chunk to download large files progressively (128Mb by default) - multi_part_download: flag to split each chunk into small packets and download those packets in parallel (enabled by default) - - Example: - - .. testsetup:: - - from unittest import mock - from torchdata.datapipes.iter import S3FileLister - - file_lister_patch = mock.patch.object(S3FileLister, "__iter__", return_value=iter([])) - file_lister_patch.start() - - .. testcode:: - - from torchdata.datapipes.iter import IterableWrapper, S3FileLoader - - dp_s3_urls = IterableWrapper(['s3://bucket-name/folder/', ...]).list_files_by_s3() - # In order to make sure data are shuffled and sharded in the - # distributed environment, `shuffle` and `sharding_filter` - # are required. 
For detail, please check our tutorial in: - # https://pytorch.org/data/main/tutorial.html#working-with-dataloader - sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter() - - dp_s3_files = S3FileLoader(sharded_s3_urls) - for url, fd in dp_s3_files: # Start loading data - data = fd.read() - - # Functional API - dp_s3_files = sharded_s3_urls.load_files_by_s3(buffer_size=256) - for url, fd in dp_s3_files: - data = fd.read() - - .. testcleanup:: - - file_lister_patch.stop() - """ - - def __init__( - self, - source_datapipe: IterDataPipe[str], - request_timeout_ms=-1, - region="", - buffer_size=None, - multi_part_download=None, - ) -> None: - if not hasattr(torchdata, "_torchdata") or not hasattr(torchdata._torchdata, "S3Handler"): - raise ModuleNotFoundError("TorchData must be built with BUILD_S3=1 to use this datapipe.") - - self.source_datapipe: IterDataPipe[str] = source_datapipe - self.handler = torchdata._torchdata.S3Handler(request_timeout_ms, region) - if buffer_size: - self.handler.set_buffer_size(buffer_size) - if multi_part_download: - self.handler.set_multi_part_download(multi_part_download) - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for url in self.source_datapipe: - yield url, StreamWrapper(BytesIO(self.handler.s3_read(url))) - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/transform/__init__.py b/torchdata/datapipes/iter/transform/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/torchdata/datapipes/iter/transform/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torchdata/datapipes/iter/transform/bucketbatcher.py b/torchdata/datapipes/iter/transform/bucketbatcher.py deleted file mode 100644 index fb2a7f617..000000000 --- a/torchdata/datapipes/iter/transform/bucketbatcher.py +++ /dev/null @@ -1,319 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import heapq -import random - -from dataclasses import dataclass, field -from functools import partial -from typing import Callable, final, Generic, Iterator, List, Optional, TypeVar - -import torch - -from torchdata.datapipes import DataChunk, functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - - -T = TypeVar("T") -T_co = TypeVar("T_co", covariant=True) - - -@functional_datapipe("in_batch_shuffle") -class InBatchShufflerIterDataPipe(IterDataPipe[DataChunk[T_co]]): - r""" - Shuffles each mini-batch from the prior DataPipe (functional name: ``in_batch_shuffle``). 
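# A minimal sketch of what ``in_batch_shuffle`` does to each batch: a seeded
# ``random.Random().sample`` over the batch, i.e. a permutation that never
# moves elements across batch boundaries. Pure standard library.
import random

batches = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

rng = random.Random()
rng.seed(0)                                     # fixed seed -> reproducible permutations
shuffled = [rng.sample(batch, len(batch)) for batch in batches]
print(shuffled)                                 # each inner list is a permutation of its input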
- - Args: - datapipe: Iterable DataPipe with batched data - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper(range(10)) - >>> batch_dp = source_dp.batch(batch_size=3, drop_last=True) - >>> list(batch_dp) - [[0, 1, 2], [3, 4, 5], [6, 7, 8]] - >>> in_batch_shuffle_dp = batch_dp.in_batch_shuffle() - >>> list(in_batch_shuffle_dp) - [[2, 0, 1], [3, 5, 4], [7, 8, 6]] - """ - - def __init__(self, datapipe: IterDataPipe[DataChunk[T_co]]): - self.datapipe = datapipe - self._enabled = True - self._seed: Optional[int] = None - self._rng = random.Random() - - def set_shuffle(self, shuffle=True): - self._enabled = shuffle - return self - - def set_seed(self, seed: int): - self._seed = seed - return self - - def __iter__(self) -> Iterator[DataChunk[T_co]]: - if not self._enabled: - for batch in self.datapipe: - yield batch - else: - for batch in self.datapipe: - new_batch = self._rng.sample(batch, len(batch)) - yield DataChunk(new_batch) - - @final - def reset(self) -> None: - if self._enabled: - if self._seed is None: - self._seed = int(torch.empty((), dtype=torch.int64).random_().item()) - self._rng.seed(self._seed) - self._seed = None - - def __len__(self) -> int: - return len(self.datapipe) - - def __getstate__(self): - state = ( - self.datapipe, - self._enabled, - self._seed, - self._rng.getstate(), - self._valid_iterator_id, - self._number_of_samples_yielded, - ) - if IterDataPipe.getstate_hook is not None: - return IterDataPipe.getstate_hook(state) - return state - - def __setstate__(self, state): - ( - self.datapipe, - self._enabled, - self._seed, - rng_state, - self._valid_iterator_id, - self._number_of_samples_yielded, - ) = state - self._rng = random.Random() - self._rng.setstate(rng_state) - - -@functional_datapipe("bucketbatch") -class BucketBatcherIterDataPipe(IterDataPipe[DataChunk[T_co]]): - r""" - Creates mini-batches of data from sorted bucket (functional name: ``bucketbatch``). An outer - dimension will be added as ``batch_size`` if ``drop_last`` is set to ``True``, - or ``length % batch_size`` for the last batch if ``drop_last`` is set to ``False``. - - The purpose of this DataPipe is to batch samples with some similarity according to the sorting function - being passed. For an example in the text domain, it may be batching examples with similar number of tokens - to minimize padding and to increase throughput. - - Args: - datapipe: Iterable DataPipe being batched - batch_size: The size of each batch - drop_last: Option to drop the last batch if it's not full - batch_num: Number of batches within a bucket (i.e. `bucket_size = batch_size * batch_num`) - bucket_num: Number of buckets to consist a pool for shuffling (i.e. 
`pool_size = bucket_size * bucket_num`) - sort_key: Callable to sort a bucket (list) - use_in_batch_shuffle: if True, do in-batch shuffle; if False, buffer shuffle - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper(range(10)) - >>> batch_dp = source_dp.bucketbatch(batch_size=3, drop_last=True) - >>> list(batch_dp) - [[5, 6, 7], [9, 0, 1], [4, 3, 2]] - >>> def sort_bucket(bucket): - >>> return sorted(bucket) - >>> batch_dp = source_dp.bucketbatch( - >>> batch_size=3, drop_last=True, batch_num=100, - >>> bucket_num=1, use_in_batch_shuffle=False, sort_key=sort_bucket - >>> ) - >>> list(batch_dp) - [[3, 4, 5], [6, 7, 8], [0, 1, 2]] - """ - datapipe: IterDataPipe[T_co] - batch_size: int - drop_last: bool - batch_num: int - bucket_num: int - sort_key: Optional[Callable] - use_in_batch_shuffle: bool - - def __new__( - cls, - datapipe: IterDataPipe[T_co], - batch_size: int, - drop_last: bool = False, - batch_num: int = 100, - bucket_num: int = 1, - sort_key: Optional[Callable] = None, - use_in_batch_shuffle: bool = True, - ): - assert batch_size > 0, "Batch size is required to be larger than 0!" - assert batch_num > 0, "Number of batches is required to be larger than 0!" - assert bucket_num > 0, "Number of buckets is required to be larger than 0!" - - bucket_size = batch_size * batch_num - pool_size = bucket_size * bucket_num - - # Shuffle by pool_size - if bucket_num > 1 or sort_key is None: - if use_in_batch_shuffle: - datapipe = datapipe.batch(batch_size=pool_size, drop_last=False).in_batch_shuffle().unbatch() - else: - datapipe = datapipe.shuffle(buffer_size=pool_size) - # Sort by bucket_size if sort_key is given - if sort_key is not None: - datapipe = datapipe.batch(bucket_size).map(fn=sort_key).unbatch() - # Batch and drop last (if needed) - datapipe = datapipe.batch(batch_size, drop_last=drop_last) - # Shuffle the batched data - if sort_key is not None: - # In-batch shuffle each bucket seems not that useful, it seems misleading since .batch is called prior. - if use_in_batch_shuffle: - datapipe = datapipe.batch(batch_size=bucket_num, drop_last=False).in_batch_shuffle().unbatch() - else: - datapipe = datapipe.shuffle(buffer_size=bucket_size) - return datapipe - - -def _default_len_fn(token): - return len(token) - - -@dataclass(order=True, frozen=True) -class PrioritizedItem(Generic[T_co]): - length: int - data: T_co = field(compare=False) - - -def _token_len_fn(token: T, len_fn: Callable) -> PrioritizedItem[T]: - return PrioritizedItem(length=len_fn(token), data=token) - - -def _token_filter_fn(data, *, min_len, max_len): - return data.length >= min_len and data.length <= max_len - - -@functional_datapipe("max_token_bucketize") -class MaxTokenBucketizerIterDataPipe(IterDataPipe[DataChunk[T_co]]): - r""" - Creates mini-batches of data from a min-heap with limited size, and the total length of samples - returned by ``len_fn`` within each batch will be limited by ``max_token_count`` - (functional name: ``max_token_bucketize``). If ``min_len`` or ``max_len`` is set, the samples with - length that is out of ``[min_len, max_len]`` will be filtered out. - - The purpose of this DataPipe is to batch samples with similar length according to ``len_fn``. - Min-heap is used here to make sure the samples are sorted incrementally based on the length. And, - the total length of samples in each batch is guaranteed to be smaller than ``max_token_count``. - For an example in the audio domain, it may be batching samples with similar length. 
Then, given the - ``max_token_count``, each batch may be concatenated to a Tensor with the same size and minimum padding. - - If ``include_padding`` is set to ``True``, the token count of each batch includes the padding a succeeding - DataPipe could add. This guarentees that even after the batch is padded, ``max_token_count`` will not be exceeded. - This can prevent out-of-memory issues for data with large variations in length. - - Note that batches are bucketized starting from the smallest size in a buffer. - This can limit the variablity of batches if ``buffer_size`` is large. - To increase variablity, apply ``torchdata.datapipes.iter.Shuffler`` before and after this DataPipe, - and keep ``buffer_size`` small. - - - Args: - datapipe: Iterable DataPipe being batched - max_token_count: Maximum length of total length of data in each batch - len_fn: Function to be applied to each element to get lengths. ``len(data)`` is used by default. - min_len: Optional minimum length to be included into each batch - max_len: Optional maximum length to be included into each batch. - buffer_size: This restricts how many samples are taken from prior DataPipe to bucketize - include_padding: If True, the size of each batch includes the extra padding to the largest length in the batch. - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper(['1', '11', '1', '1111', '111', '1', '11', '11', '111']) - >>> # Using default len_fn to sort samples based on length (string length in this case) - >>> batch_dp = source_dp.max_token_bucketize(max_token_count=5) - >>> list(batch_dp) - [['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']] - >>> batch_dp = source_dp.max_token_bucketize(max_token_count=4, buffer_size=4) - >>> list(batch_dp) - [['1', '1', '1'], ['11', '11'], ['11'], ['111'], ['111'], ['1111']] - """ - datapipe: IterDataPipe[PrioritizedItem[T_co]] - max_token_count: int - len_fn: Callable - min_len: int - max_len: Optional[int] - buffer_size: int - - def __init__( - self, - datapipe: IterDataPipe[T_co], - max_token_count: int, - len_fn: Callable = _default_len_fn, - min_len: int = 0, - max_len: Optional[int] = None, - buffer_size: int = 1000, - include_padding: bool = False, - ) -> None: - if max_len is None: - max_len = max_token_count - - if min_len < 0 or min_len > max_len: - raise ValueError("``min_len`` should be larger than 0 and equal to or smaller than ``max_len``.") - if max_len > max_token_count: - raise ValueError("``max_token_count`` must be equal to or greater than ``max_len``.") - if buffer_size <= 0: - raise ValueError("'buffer_size' is required to be a positive integer.") - self.datapipe = datapipe.map(partial(_token_len_fn, len_fn=len_fn)) - self.datapipe = self.datapipe.filter(partial(_token_filter_fn, min_len=min_len, max_len=max_len)) - self.max_token_count = max_token_count - self.buffer_size = buffer_size - self.include_padding = include_padding - - def __iter__(self) -> Iterator[DataChunk[T_co]]: - buffer: List[PrioritizedItem[T_co]] = [] - batch: List[T_co] = [] - batch_size: int = 0 - max_length: int = 0 - for d in self.datapipe: - heapq.heappush(buffer, d) - if len(buffer) == self.buffer_size: - buffer, batch, batch_size, max_length, data_chunk = self._pop_buffer( - buffer, batch, batch_size, max_length - ) - if data_chunk is not None: - yield data_chunk - while buffer: - buffer, batch, batch_size, max_length, data_chunk = self._pop_buffer(buffer, batch, batch_size, max_length) - if data_chunk is not None: - yield data_chunk 
- if batch: - yield DataChunk(batch) - - def _pop_buffer(self, buffer: List[PrioritizedItem[T_co]], batch: List[T_co], batch_size: int, max_length: int): - data_chunk_to_yield = None - d: PrioritizedItem[T_co] = heapq.heappop(buffer) - length = d.length - token = d.data - - if self.include_padding: - max_length = max(length, max_length) - new_batch_size = (len(batch) + 1) * max_length - else: - new_batch_size = batch_size + length - - if new_batch_size > self.max_token_count: - data_chunk_to_yield = DataChunk(batch) - batch = [token] - batch_size = length - max_length = length - else: - batch.append(token) - batch_size = new_batch_size - - return buffer, batch, batch_size, max_length, data_chunk_to_yield diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py deleted file mode 100644 index a876f0b71..000000000 --- a/torchdata/datapipes/iter/transform/callable.py +++ /dev/null @@ -1,954 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import asyncio -import inspect -import random -import warnings -from collections import deque -from concurrent import futures - -from typing import Callable, Hashable, Iterator, List, Optional, Set, Sized, TypeVar, Union - -import torch -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn, validate_input_col -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -T_co = TypeVar("T_co", covariant=True) - - -def _no_op_fn(*args): - """ - No-operation function, returns passed arguments. - """ - if len(args) == 1: - return args[0] - return args - - -@functional_datapipe("map_batches") -class BatchMapperIterDataPipe(IterDataPipe[T_co]): - r""" - Combines elements from the source DataPipe to batches and applies a function - over each batch, then flattens the outputs to a single, unnested IterDataPipe - (functional name: ``map_batches``). - - Args: - datapipe: Source IterDataPipe - fn: The function to be applied to each batch of data - batch_size: The size of batch to be aggregated from ``datapipe`` - input_col: Index or indices of data which ``fn`` is applied, such as: - - - ``None`` as default to apply ``fn`` to the data directly. - - Integer(s) is used for list/tuple. - - Key(s) is used for dict. - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> def fn(batch): - >>> return [d + 1 for d in batch] - >>> source_dp = IterableWrapper(list(range(5))) - >>> mapped_dp = source_dp.map_batches(fn, batch_size=3) - >>> list(mapped_dp) - [1, 2, 3, 4, 5] - - Notes: - Compared with ``map``, the reason that ``map_batches`` doesn't take - ``output_col`` argument is the size of ``fn`` output is not guaranteed - to be the same as input batch. With different size, this operation cannot - assign data back to original data structure. - - And, this operation is introduced based on the use case from `TorchText`. - A pybinded C++ vectorized function can be applied for efficiency. - """ - datapipe: IterDataPipe - fn: Callable - batch_size: int - - def __init__( - self, - datapipe: IterDataPipe, - fn: Callable, - batch_size: int, - input_col=None, - ) -> None: - self.datapipe = datapipe - - _check_unpickable_fn(fn) - self.fn = fn # type: ignore[assignment] - - assert batch_size > 0, "Batch size is required to be larger than 0!" 
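# A simplified sketch of the bucketizing idea behind ``max_token_bucketize``
# above: keep a min-heap ordered by sample length and emit a batch whenever
# adding the next-shortest sample would exceed the token budget. This omits
# the real DataPipe's ``buffer_size`` bound and padding accounting.
import heapq

def bucketize(samples, max_token_count, len_fn=len):
    heap = [(len_fn(s), i, s) for i, s in enumerate(samples)]   # index breaks length ties
    heapq.heapify(heap)
    batch, batch_size = [], 0
    while heap:
        length, _, sample = heapq.heappop(heap)
        if batch and batch_size + length > max_token_count:
            yield batch
            batch, batch_size = [], 0
        batch.append(sample)
        batch_size += length
    if batch:
        yield batch

print(list(bucketize(['1', '11', '1', '1111', '111', '1', '11', '11', '111'], max_token_count=5)))
# [['1', '1', '1', '11'], ['11', '11'], ['111'], ['111'], ['1111']]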
- self.batch_size = batch_size - self.input_col = input_col - - def _apply_fn(self, batch): - if self.input_col is None: - return self.fn(batch) - - if isinstance(self.input_col, (list, tuple)): - args = [[data[idx] for idx in self.input_col] for data in batch] - else: - args = [data[self.input_col] for data in batch] - return self.fn(args) - - def __iter__(self) -> Iterator[T_co]: - batch: List = [] - for d in self.datapipe: - batch.append(d) - if len(batch) == self.batch_size: - yield from self._apply_fn(batch) - batch = [] - if batch: - yield from self._apply_fn(batch) - - def __len__(self) -> int: - raise TypeError(f"{type(self).__name__}'s length relies on the output of its function.") - - -@functional_datapipe("flatmap") -class FlatMapperIterDataPipe(IterDataPipe[T_co]): - r""" - Applies a function over each item from the source DataPipe, then - flattens the outputs to a single, unnested IterDataPipe (functional name: ``flatmap``). - - Note: - The output from ``fn`` must be a Sequence. Otherwise, an error will be raised. - If ``fn`` is ``None``, source DataPipe will be just flattened vertically, provided that items can be unpacked. - - Args: - datapipe: Source IterDataPipe - fn: the function to be applied to each element in the DataPipe, the output must be a Sequence - input_col: Index or indices of data which ``fn`` is applied, such as: - - - ``None`` as default to apply ``fn`` to the data directly. - - Integer(s) is/are used for list/tuple. - - Key(s) is/are used for dict. - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> def fn(e): - >>> return [e, e * 10] - >>> source_dp = IterableWrapper(list(range(5))) - >>> flatmapped_dp = source_dp.flatmap(fn) - >>> list(flatmapped_dp) - [0, 0, 1, 10, 2, 20, 3, 30, 4, 40] - >>> - >>> source_dp = IterableWrapper([[1, 2, 3], [4, 5, 6]]) - >>> flatmapped_dp = source_dp.flatmap() - >>> list(flatmapped_dp) - [1, 2, 3, 4, 5, 6] - """ - datapipe: IterDataPipe - fn: Optional[Callable] - - def __init__(self, datapipe: IterDataPipe, fn: Optional[Callable] = None, input_col=None) -> None: - self.datapipe = datapipe - - if fn is None: - fn = _no_op_fn - _check_unpickable_fn(fn) - self.fn = fn # type: ignore[assignment] - self.input_col = input_col - validate_input_col(fn, input_col) - - def _apply_fn(self, data): - if self.input_col is None: - return self.fn(data) # type: ignore[misc] - elif isinstance(self.input_col, (list, tuple)): - args = tuple(data[col] for col in self.input_col) - return self.fn(*args) # type: ignore[misc] - else: - return self.fn(data[self.input_col]) # type: ignore[misc] - - def __iter__(self) -> Iterator[T_co]: - for d in self.datapipe: - yield from self._apply_fn(d) - - def __len__(self) -> int: - raise TypeError(f"{type(self).__name__}'s length relies on the output of its function.") - - -@functional_datapipe("shuffled_flatmap") -class ShuffledFlatMapperIterDataPipe(IterDataPipe): - r""" - Applies a function over each item from the source DataPipe, - then collects the iterables returned in a buffer, - then, at every iteration, chooses at random one of the iterables in the buffer - and yields one item from this iterable (functional name: ``shuffled_flatmap``). - - When the buffer is full, the DataPipe will begin to yield elements from iterables within the buffer. - New iterables will be added to the buffer once the existing ones run out of elements. - Note: - The output from ``fn`` must be an Iterable. Otherwise, an error will be raised. 
- If ``fn`` is ``None``, source DataPipe will be just flattened vertically, provided that items can be unpacked. - - Args: - datapipe: Source IterDataPipe - fn: the function to be applied to each element in the DataPipe, the output must be a Sequence - input_col: Index or indices of data which ``fn`` is applied, such as: - - - ``None`` as default to apply ``fn`` to the data directly. - - Integer(s) is/are used for list/tuple. - - Key(s) is/are used for dict. - buffer_size: the max number of iterables this DataPipe can hold at a time (default to ``100``) - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper([[1, 2, 3, 4], 'abcd', 'ABCD']) - >>> shuffled_flatmapped_dp = source_dp.shuffled_flatmap(buffer_size=2) - >>> list(shuffled_flatmapped_dp) - ['a', 'b', 'c', 1, 'd', 'A', 'B', 'C', 2, 'D', 3, 4] - >>> - >>> # To shuffle all the elements, you can combine `shuffled_flatmap` with `in_batch_shuffle` like this: - >>> fully_shuffled_flatmapped_dp = source_dp.in_batch_shuffle() - >>> fully_shuffled_flatmapped_dp = fully_shuffled_flatmapped_dp.shuffled_flatmap() - >>> list(fully_shuffled_flatmapped_dp) - ['b', 3, 'c', 'd', 'C', 'A', 'a', 2, 'B', 'D', 4, 1] - """ - datapipe: IterDataPipe - fn: Optional[Callable] - buffer_size: int - _buffer: List[Iterator] - _enabled: bool - _seed: Optional[int] - _rng: random.Random - _no_op_fn: bool = False - - def __init__( - self, datapipe: IterDataPipe, fn: Optional[Callable] = None, input_col=None, buffer_size: int = 100 - ) -> None: - super().__init__() - self._buffer = [] - self.datapipe = datapipe - - if fn is None: - fn = _no_op_fn - self._no_op_fn = True - _check_unpickable_fn(fn) - self.fn = fn # type: ignore[assignment] - self.input_col = input_col - validate_input_col(fn, input_col) - - assert buffer_size > 0, "buffer_size should be larger than 0" - self.buffer_size = buffer_size - self._enabled = True - self._seed = None - self._rng = random.Random() - - def set_shuffle(self, shuffle=True): - self._enabled = shuffle - return self - - def set_seed(self, seed: int): - self._seed = seed - return self - - def reset(self) -> None: - self._buffer = [] - if self._enabled: - if self._seed is None: - self._seed = int(torch.empty((), dtype=torch.int64).random_().item()) - self._rng.seed(self._seed) - self._seed = None - - def _apply_fn(self, data): - if self.input_col is None: - return self.fn(data) # type: ignore[misc] - elif isinstance(self.input_col, (list, tuple)): - args = tuple(data[col] for col in self.input_col) - return self.fn(*args) # type: ignore[misc] - else: - return self.fn(data[self.input_col]) # type: ignore[misc] - - def __iter__(self) -> Iterator[T_co]: - if not self._enabled: # equivalent to flatmap - for x in self.datapipe: - yield from self._apply_fn(x) - else: - idx = self._rng.randint(0, self.buffer_size - 1) - for x in self.datapipe: - while len(self._buffer) == self.buffer_size: - try: - yield next(self._buffer[idx]) - idx = self._rng.randint(0, self.buffer_size - 1) - except StopIteration: - self._buffer.pop(idx) - self._buffer.append(iter(self._apply_fn(x))) - while self._buffer: - try: - idx = self._rng.randint(0, len(self._buffer) - 1) - yield next(self._buffer[idx]) - except StopIteration: - self._buffer.pop(idx) - - def __len__(self) -> int: - if self._no_op_fn: - return sum(map(len, self.datapipe)) - raise TypeError(f"{type(self).__name__}'s length relies on the output of its function.") - - def __getstate__(self): - state = ( - self.datapipe, - self.fn, - self.input_col, - 
self.buffer_size, - self._buffer, - self._enabled, - self._seed, - self._rng.getstate(), - self._valid_iterator_id, - self._number_of_samples_yielded, - ) - if IterDataPipe.getstate_hook is not None: - return IterDataPipe.getstate_hook(state) - return state - - def __setstate__(self, state): - ( - self.datapipe, - self.fn, - self.input_col, - self.buffer_size, - self._buffer, - self._enabled, - self._seed, - rng_state, - self._valid_iterator_id, - self._number_of_samples_yielded, - ) = state - self._rng = random.Random() - self._rng.setstate(rng_state) - - def __del__(self): - self._buffer.clear() - - -@functional_datapipe("drop") -class DropperIterDataPipe(IterDataPipe[T_co]): - r""" - Drop columns/elements in input DataPipe via its indices (functional name: ``drop``). - - Args: - datapipe: IterDataPipe with columns to be dropped - indices: a single column index to be dropped or a list of indices - - - Integer(s) is/are used for list/tuple. - - Key(s) is/are used for dict. - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, ZipperMapDataPipe - >>> dp1 = IterableWrapper(range(5)) - >>> dp2 = IterableWrapper(range(10, 15)) - >>> dp = dp1.zip(dp2) - >>> list(dp) - [(0, 10), (1, 11), (2, 12), (3, 13), (4, 14)] - >>> drop_dp = dp.drop(1) - >>> list(drop_dp) - [(0), (1), (2), (3), (4)] - """ - datapipe: IterDataPipe - - def __init__( - self, - datapipe: IterDataPipe, - indices: Union[Hashable, List[Hashable]], - ) -> None: - super().__init__() - self.datapipe = datapipe - if isinstance(indices, list): - self.indices = set(indices) - else: - self.indices = {indices} - - def __iter__(self) -> Iterator[T_co]: - for old_item in self.datapipe: - if isinstance(old_item, tuple): - new_item = tuple(x for i, x in enumerate(old_item) if i not in self.indices) # type: ignore[assignment] - elif isinstance(old_item, list): - new_item = [x for i, x in enumerate(old_item) if i not in self.indices] # type: ignore[assignment] - elif isinstance(old_item, dict): - new_item = {k: v for (k, v) in old_item.items() if k not in self.indices} # type: ignore[assignment] - else: - new_item = old_item - warnings.warn( - "The next item was not an iterable and cannot be filtered, " - "please be aware that no filter was done or new item created." - ) - - # check to make sure all indices requested were in the item. warn if not - try: - for i in self.indices: - old_item[i] - except (IndexError, KeyError): - warnings.warn( - "At least one index in the filter is not present in the item being returned," - " please be aware that expected columns/keys may be missing." - ) - - yield new_item # type: ignore[misc] - - def __len__(self) -> int: - if isinstance(self.datapipe, Sized): - return len(self.datapipe) - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - - -@functional_datapipe("slice") -class SliceIterDataPipe(IterDataPipe[T_co]): - r""" - returns a slice of elements in input DataPipe via start/stop/step or indices (functional name: ``slice``). - - Args: - datapipe: IterDataPipe with iterable elements - index: a single start index for the slice or a list of indices to be returned instead of a start/stop slice - - - Integer(s) is/are used for list/tuple. - - Key(s) is/are used for dict. - - - stop: the slice stop. ignored if index is a list or if element is a dict - step: step to be taken from start to stop. 
ignored if index is a list or if element is a dict - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper([(0, 10, 100), (1, 11, 111), (2, 12, 122), (3, 13, 133), (4, 14, 144)]) - >>> slice_dp = dp.slice(0, 2) - >>> list(slice_dp) - [(0, 10), (1, 11), (2, 12), (3, 13), (4, 14)] - """ - datapipe: IterDataPipe - - def __init__( - self, - datapipe: IterDataPipe, - index: Union[int, List[Hashable]], - stop: Optional[int] = None, - step: Optional[int] = None, - ) -> None: - super().__init__() - self.datapipe = datapipe - - self.index = index - self.stop = stop - self.step = step - - if isinstance(index, list): - if stop or step: - warnings.warn( - "A list of indices was passed as well as a stop or step for the slice, " - "these arguments can't be used together so only the indices list will be used." - ) - - def __iter__(self) -> Iterator[T_co]: - for old_item in self.datapipe: - if isinstance(old_item, tuple): - if isinstance(self.index, list): - new_item = tuple(x for i, x in enumerate(old_item) if i in self.index) # type: ignore[assignment] - else: - new_item = old_item[self.index : self.stop : self.step] # type: ignore[assignment] - elif isinstance(old_item, list): - if isinstance(self.index, list): - new_item = [x for i, x in enumerate(old_item) if i in self.index] # type: ignore[assignment] - else: - new_item = old_item[self.index : self.stop : self.step] # type: ignore[assignment] - elif isinstance(old_item, dict): - if isinstance(self.index, list): - new_item = {k: v for (k, v) in old_item.items() if k in self.index} # type: ignore[assignment] - elif self.index in old_item.keys(): - new_item = {self.index: old_item.get(self.index)} # type: ignore[assignment] - else: - new_item = old_item # type: ignore[assignment] - warnings.warn( - "Dictionaries are not sliced by steps, only direct index. " - "Please be aware that no filter was done or new item created." - ) - else: - new_item = old_item # type: ignore[assignment] - warnings.warn( - "The next item was not an iterable and cannot be filtered, " - "please be aware that no filter was done or new item created." - ) - - if isinstance(self.index, list): - # check to make sure all indices requested were in the item. warn if not - try: - for i in self.index: - old_item[i] - except (IndexError, KeyError): - warnings.warn( - "At least one index in the filter is not present in the item being returned," - " please be aware that expected columns/keys may be missing." - ) - - yield new_item # type: ignore[misc] - - def __len__(self) -> int: - if isinstance(self.datapipe, Sized): - return len(self.datapipe) - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - - -@functional_datapipe("flatten") -class FlattenIterDataPipe(IterDataPipe[T_co]): - r""" - returns a flattened copy of the input DataPipe at the per sample/element level based on provided indices (functional name: ``flatten``). - - Note: - no args will flatten each item in the datapipe 1 level - - Args: - datapipe: IterDataPipe with iterable elements - indices: a single index/key for the item to flatten from an iterator item or a list of indices/keys to be flattened - - - Integer(s) is/are used for list/tuple. - - Key(s) is/are used for dict. 
- - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper([(0, 10, (100, 1000)), (1, 11, (111, 1001)), (2, 12, (122, 1002)), (3, 13, (133, 1003)), (4, 14, (144, 1004))]) - >>> flatten_dp = dp.flatten(2) - >>> list(flatten_dp) - [(0, 10, 100, 1000), (1, 11, 111, 1001), (2, 12, 122, 1002), (3, 13, 133, 1003), (4, 14, 144, 1004)] - >>> - >>> dp = IterableWrapper([(0, (1, 2)), (3, (4, 5)), (6, (7, 8))]) - >>> flatten_dp = dp.flatten() - >>> list(flatten_dp) - [(0, 1, 2), (3, 4, 5), (6, 7, 8)] - """ - datapipe: IterDataPipe - indices: Set[Hashable] = set() - - def __init__( - self, - datapipe: IterDataPipe, - indices: Optional[Union[Hashable, List[Hashable]]] = None, - ) -> None: - super().__init__() - self.datapipe = datapipe - if indices: - if isinstance(indices, list): - self.indices = set(indices) - else: - self.indices = {indices} - - def __iter__(self) -> Iterator[T_co]: - flatten_all = False - if not self.indices: - flatten_all = True - for old_item in self.datapipe: - if isinstance(old_item, dict): - new_item = {} # type: ignore[assignment] - for k, v in old_item.items(): - if k in self.indices: - pass - if (flatten_all or (k in self.indices)) and isinstance(v, dict): - for k_sub, v_sub in v.items(): - if k_sub not in old_item: - new_item[k_sub] = v_sub - else: - warnings.warn( - "Flattener tried to insert the same key twice into the dict item," - "the second key,value pair has been dropped." - ) - else: - if k not in new_item: - new_item[k] = v - else: - warnings.warn( - "Flattener tried to insert the same key twice into the dict item," - "the second key,value pair has been dropped." - ) - else: - is_tuple = False - new_item = [] # type: ignore[assignment] - if isinstance(old_item, tuple): - is_tuple = True - old_item = list(old_item) - for i, item in enumerate(old_item): - if (flatten_all or (i in self.indices)) and isinstance(item, (list, tuple)): - new_item.extend(list(item)) # type: ignore[attr-defined] - else: - new_item.append(item) # type: ignore[attr-defined] - if is_tuple: - new_item = tuple(new_item) # type: ignore[assignment] - - # check to make sure all indices requested were in the item. warn if not - try: - if self.indices: - for index in self.indices: - old_item[index] - except (IndexError, KeyError): - warnings.warn( - "At least one index in the filter is not present in the item being returned," - " please be aware that expected columns/keys may be missing." 
- ) - yield new_item # type: ignore[misc] - - def __len__(self) -> int: - if isinstance(self.datapipe, Sized): - return len(self.datapipe) - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - - -class _BatchAsyncMapperIterDataPipe(IterDataPipe): - datapipe: IterDataPipe - async_fn: Callable - - def __init__( - self, - source_datapipe: IterDataPipe, - async_fn: Callable, - input_col=None, - output_col=None, - max_concurrency: int = 32, - ): - self.source_datapipe = source_datapipe - if not inspect.iscoroutinefunction(async_fn): - raise ValueError(f"Expected a corotine function with an async def syntax, but got a {type(async_fn)}") - self.async_fn = async_fn # type: ignore[assignment] - if input_col is None and output_col is not None: - raise ValueError("`output_col` must be None when `input_col` is None.") - self.input_col = input_col - if isinstance(output_col, (list, tuple)): - if len(output_col) > 1: - raise ValueError("`output_col` must be a single-element list or tuple") - output_col = output_col[0] - self.output_col = output_col - self.max_concurrency = max_concurrency - - def __iter__(self): - policy = asyncio.get_event_loop_policy() - loop = policy.new_event_loop() - try: - for batch in self.source_datapipe: - policy.set_event_loop(loop) - new_batch = loop.run_until_complete(self.processbatch(batch)) - yield new_batch - finally: - loop.run_until_complete(loop.shutdown_asyncgens()) - loop.close() - - async def processbatch(self, batch): - sem = asyncio.Semaphore(self.max_concurrency) - - async def controlled_async_fn(async_fn, *data): - async with sem: - return await async_fn(*data) - - coroutines = [] - if self.input_col is None: - for data in batch: - coroutines.append(controlled_async_fn(self.async_fn, data)) - results = await asyncio.gather(*coroutines) - return results - - for data in batch: - if isinstance(self.input_col, (list, tuple)): - args = tuple(data[col] for col in self.input_col) - coroutines.append(controlled_async_fn(self.async_fn, *args)) - else: - coroutines.append(controlled_async_fn(self.async_fn, data[self.input_col])) - results = await asyncio.gather(*coroutines) - - new_batch = [] - for data, res in zip(batch, results): - t_flag = isinstance(data, tuple) - if t_flag: - data = list(data) - - if self.output_col is None: - if isinstance(self.input_col, (list, tuple)): - data[self.input_col[0]] = res - for idx in sorted(self.input_col[1:], reverse=True): - del data[idx] - else: - data[self.input_col] = res - elif self.output_col == -1: - data.append(res) - else: - data[self.output_col] = res - - if t_flag: - data = tuple(data) - - new_batch.append(data) - return new_batch - - def __len__(self): - return len(self.source_datapipe) - - -@functional_datapipe("async_map_batches") -class BatchAsyncMapperIterDataPipe(IterDataPipe): - r""" - Combines elements from the source DataPipe to batches and applies a coroutine function - over each element within the batch concurrently, then flattens the outpus to a - single, unnested IterDataPipe (functional name: ``async_map_batches``). - - Args: - source_datapipe: Source IterDataPipe - async_fn: The coroutine function to be applied to each batch of data - batch_size: The size of batch to be aggregated from ``source_datapipe`` - input_col: Index or indices of data which ``fn`` is applied, such as: - - - ``None`` as default to apply ``fn`` to the data directly. - - Integer(s) is used for list/tuple. - - Key(s) is used for dict. - - output_col: Index of data where result of ``fn`` is placed. 
``output_col`` can be specified - only when ``input_col`` is not ``None`` - - - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with - multiple indices, the left-most one is used, and other indices will be removed. - - Integer is used for list/tuple. ``-1`` represents to append result at the end. - - Key is used for dict. New key is acceptable. - - max_concurrency: Maximum concurrency to call async functions. (Default: ``32``) - flatten: Determine if the batches get flatten in the end (Default: ``True``) - If ``False``, outputs will be in batches of size ``batch_size`` - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> async def mul_ten(x): - ... await asyncio.sleep(1) - ... return x * 10 - >>> dp = IterableWrapper(range(50)) - >>> dp = dp.async_map_batches(mul_ten, 16) - >>> list(dp) - [0, 10, 20, 30, ...] - >>> dp = IterableWrapper([(i, i) for i in range(50)]) - >>> dp = dp.async_map_batches(mul_ten, 16, input_col=1) - >>> list(dp) - [(0, 0), (1, 10), (2, 20), (3, 30), ...] - >>> dp = IterableWrapper([(i, i) for i in range(50)]) - >>> dp = dp.async_map_batches(mul_ten, 16, input_col=1, output_col=-1) - >>> list(dp) - [(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...] - # Async fetching html from remote - >>> from aiohttp import ClientSession - >>> async def fetch_html(url: str, **kwargs): - ... async with ClientSession() as session: - ... resp = await session.request(method="GET", url=url, **kwargs) - ... resp.raise_for_status() - ... html = await resp.text() - ... return html - >>> dp = IterableWrapper(urls) - >>> dp = dp.async_map_batches(fetch_html, 16) - """ - - def __new__( - self, - source_datapipe, - async_fn: Callable, - batch_size: int, - input_col=None, - output_col=None, - max_concurrency: int = 32, - flatten: bool = True, - ): - dp = source_datapipe.batch(batch_size) - dp = _BatchAsyncMapperIterDataPipe(dp, async_fn, input_col, output_col, max_concurrency) - if flatten: - dp = dp.flatmap() - try: - source_length = len(source_datapipe) - if isinstance(source_length, int) and source_length >= 0: - dp = dp.set_length(source_length) - except (TypeError, NotImplementedError): - pass - return dp - - -@functional_datapipe("threadpool_map") -class ThreadPoolMapperIterDataPipe(IterDataPipe[T_co]): - r""" - Applies a function over each item from the source DataPipe concurrently - using ``ThreadPoolExecutor`` (functional name: ``threadpool_map``). - The function can be any regular Python function or partial object. Lambda - function is not recommended as it is not supported by pickle. - - Args: - source_datapipe: Source IterDataPipe - fn: Function being applied over each item - input_col: Index or indices of data which ``fn`` is applied, such as: - - - ``None`` as default to apply ``fn`` to the data directly. - - Integer(s) is used for list/tuple. - - Key(s) is used for dict. - - output_col: Index of data where result of ``fn`` is placed. ``output_col`` can be specified - only when ``input_col`` is not ``None`` - - - ``None`` as default to replace the index that ``input_col`` specified; For ``input_col`` with - multiple indices, the left-most one is used, and other indices will be removed. - - Integer is used for list/tuple. ``-1`` represents to append result at the end. - - Key is used for dict. New key is acceptable. 
- - scheduled_tasks: How many tasks will be scheduled at any given time (Default value: 128) - max_workers: Maximum number of threads to execute function calls - **threadpool_kwargs: additional arguments to be given to the ``ThreadPoolExecutor`` - - Note: - For more information about ``max_workers`` and additional arguments for the ``ThreadPoolExecutor`` - please refer to: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor - - Note: - For optimal use of all threads, ``scheduled_tasks`` > ``max_workers`` is strongly recommended. The higher the - variance of the time needed to finish execution of the given ``fn`` is, the higher the value - of ``scheduled_tasks`` needs to be to avoid threads sitting idle while waiting - for the next result (as results are returned in correct order). - - However, too high value of ``scheduled_tasks`` might lead to long waiting period until the first element is yielded - as ``next`` is called ``scheduled_tasks`` many times on ``source_datapipe`` before yielding. - - We encourage you to try out different values of ``max_workers`` and ``scheduled_tasks`` - in search for optimal values for your use-case. - - Example: - - .. testsetup:: - - from torchdata.datapipes.iter import IterableWrapper - import requests - import time - from unittest.mock import MagicMock - - requests.get = MagicMock() - urls = [] - - .. testcode:: - - # fetching html from remote - def fetch_html(url: str, **kwargs): - r = requests.get(url, **kwargs) - r.raise_for_status() - return r.content - dp = IterableWrapper(urls) - dp = dp.threadpool_map(fetch_html,max_workers=16) - - .. testcode:: - - def mul_ten(x): - time.sleep(0.1) - return x * 10 - - dp = IterableWrapper([(i, i) for i in range(50)]) - dp = dp.threadpool_map(mul_ten, input_col=1) - print(list(dp)) - - .. testoutput:: - - [(0, 0), (1, 10), (2, 20), (3, 30), ...] - - .. testcode:: - - dp = IterableWrapper([(i, i) for i in range(50)]) - dp = dp.threadpool_map(mul_ten, input_col=1, output_col=-1) - print(list(dp)) - - .. testoutput:: - - [(0, 0, 0), (1, 1, 10), (2, 2, 20), (3, 3, 30), ...] 
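To make the Note above about ``scheduled_tasks`` versus ``max_workers`` concrete, here is a minimal, standalone sketch of the same bounded look-ahead pattern, written with plain ``concurrent.futures`` rather than the DataPipe machinery. It mirrors the ``__iter__`` implementation that follows, but the names ``bounded_map``, ``slow_fn`` and ``items`` are illustrative only and are not part of TorchData.

```python
from collections import deque
from concurrent.futures import ThreadPoolExecutor
import time


def slow_fn(x):
    # Simulate I/O-bound work with variable latency.
    time.sleep(0.01 * (x % 3))
    return x * 10


def bounded_map(items, fn, scheduled_tasks=8, max_workers=4):
    """Yield fn(item) in input order, keeping at most `scheduled_tasks` futures in flight."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = deque()
        it = iter(items)
        has_next = True
        # Prefill the queue so worker threads are never idle while we block on the oldest result.
        for _ in range(scheduled_tasks):
            try:
                futures.append(executor.submit(fn, next(it)))
            except StopIteration:
                has_next = False
                break
        while futures:
            if has_next:
                try:
                    futures.append(executor.submit(fn, next(it)))
                except StopIteration:
                    has_next = False
            # popleft() preserves input order; .result() blocks only on the oldest outstanding task.
            yield futures.popleft().result()


print(list(bounded_map(range(10), slow_fn)))  # [0, 10, 20, ..., 90]
```

Because results are consumed strictly in submission order, a single slow item stalls everything behind it; raising ``scheduled_tasks`` above ``max_workers`` simply gives the pool more queued work to chew on while the caller waits on that oldest future.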
- - """ - - datapipe: IterDataPipe - fn: Callable - - def __init__( - self, - source_datapipe: IterDataPipe, - fn: Callable, - input_col=None, - output_col=None, - scheduled_tasks: int = 128, - max_workers: Optional[int] = None, - **threadpool_kwargs, - ) -> None: - super().__init__() - self.datapipe = source_datapipe - - _check_unpickable_fn(fn) - self.fn = fn # type: ignore[assignment] - - if scheduled_tasks <= 0: - raise ValueError("'scheduled_tasks' is required to be a positive integer.") - self.scheduled_tasks = scheduled_tasks - if max_workers is not None and max_workers <= 0: - raise ValueError("'max_workers' is required to be a positive integer.") - self.max_workers = max_workers - self.threadpool_kwargs = threadpool_kwargs - - self.input_col = input_col - if input_col is None and output_col is not None: - raise ValueError("`output_col` must be None when `input_col` is None.") - if isinstance(output_col, (list, tuple)): - if len(output_col) > 1: - raise ValueError("`output_col` must be a single-element list or tuple") - output_col = output_col[0] - self.output_col = output_col - validate_input_col(fn, input_col) - - def _apply_fn(self, data): - if self.input_col is None and self.output_col is None: - return self.fn(data) - - if self.input_col is None: - res = self.fn(data) - elif isinstance(self.input_col, (list, tuple)): - args = tuple(data[col] for col in self.input_col) - res = self.fn(*args) - else: - res = self.fn(data[self.input_col]) - - # Copy tuple to list and run in-place modification because tuple is immutable. - if isinstance(data, tuple): - t_flag = True - data = list(data) - else: - t_flag = False - - if self.output_col is None: - if isinstance(self.input_col, (list, tuple)): - data[self.input_col[0]] = res - for idx in sorted(self.input_col[1:], reverse=True): - del data[idx] - else: - data[self.input_col] = res - else: - if self.output_col == -1: - data.append(res) - else: - data[self.output_col] = res - - # Convert list back to tuple - return tuple(data) if t_flag else data - - def __iter__(self) -> Iterator[T_co]: - with futures.ThreadPoolExecutor(max_workers=self.max_workers, **self.threadpool_kwargs) as executor: - futures_deque: deque = deque() - has_next = True - itr = iter(self.datapipe) - for _ in range(self.scheduled_tasks): - try: - futures_deque.append(executor.submit(self._apply_fn, next(itr))) - except StopIteration: - has_next = False - break - - while len(futures_deque) > 0: - if has_next: - try: - futures_deque.append(executor.submit(self._apply_fn, next(itr))) - except StopIteration: - has_next = False - yield futures_deque.popleft().result() - - def __len__(self) -> int: - if isinstance(self.datapipe, Sized): - return len(self.datapipe) - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/torchdata/datapipes/iter/util/__init__.py b/torchdata/datapipes/iter/util/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/torchdata/datapipes/iter/util/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torchdata/datapipes/iter/util/bz2fileloader.py b/torchdata/datapipes/iter/util/bz2fileloader.py deleted file mode 100644 index 554b4d8d6..000000000 --- a/torchdata/datapipes/iter/util/bz2fileloader.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import bz2 -import warnings -from io import BufferedIOBase -from typing import Iterable, Iterator, Tuple - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -from torchdata.datapipes.utils import StreamWrapper -from torchdata.datapipes.utils.common import validate_pathname_binary_tuple - - -@functional_datapipe("load_from_bz2") -class Bz2FileLoaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): - r""" - Decompresses bz2 binary streams from an Iterable DataPipe which contains tuples of - path name and bz2 binary streams, and yields a tuple of path name and extracted binary - stream (functional name: ``load_from_bz2``). - - Args: - datapipe: Iterable DataPipe that provides tuples of path name and bz2 binary stream - length: Nominal length of the DataPipe - - Note: - The opened file handles will be closed automatically if the default ``DecoderDataPipe`` - is attached. Otherwise, user should be responsible to close file handles explicitly - or let Python's GC close them periodically. - - Example: - >>> from torchdata.datapipes.iter import FileLister, FileOpener - >>> datapipe1 = FileLister(".", "*.bz2") - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> bz2_loader_dp = datapipe2.load_from_bz2() - >>> for _, stream in bz2_loader_dp: - >>> print(stream.read()) - b'0123456789abcdef' - """ - - def __init__(self, datapipe: Iterable[Tuple[str, BufferedIOBase]], length: int = -1) -> None: - super().__init__() - self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe - self.length: int = length - - def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: - for data in self.datapipe: - validate_pathname_binary_tuple(data) - pathname, data_stream = data - try: - extracted_fobj = bz2.open(data_stream, mode="rb") # type: ignore[call-overload] - new_pathname = pathname.rstrip(".bz2") # https://github.com/pytorch/data/issues/1240 - yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=new_pathname) # type: ignore[misc] - except Exception as e: - warnings.warn(f"Unable to extract files from corrupted bzip2 stream {pathname} due to: {e}, abort!") - raise e - finally: - if isinstance(data_stream, StreamWrapper): - data_stream.autoclose() - - def __len__(self) -> int: - if self.length == -1: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - return self.length diff --git a/torchdata/datapipes/iter/util/cacheholder.py b/torchdata/datapipes/iter/util/cacheholder.py deleted file mode 100644 index ca4c705b5..000000000 --- a/torchdata/datapipes/iter/util/cacheholder.py +++ /dev/null @@ -1,577 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import hashlib -import inspect -import os.path -import sys -import time -import uuid -import warnings - -from collections import deque -from functools import partial -from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, Tuple, TypeVar - -try: - import portalocker -except ImportError: - portalocker = None - -from torch.utils._import_utils import dill_available -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn - -from torch.utils.data.graph import traverse_dps -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterableWrapper, IterDataPipe - -if dill_available(): - import dill - - dill.extend(use_dill=False) - - -def _assert_portalocker() -> None: - try: - import portalocker # noqa: F401 - except ImportError as e: - if os.name == "nt" and str(e).startswith("DLL load failed while importing"): - print( - "Please take a look at FAQ in https://github.com/pytorch/data#frequently-asked-questions-faq" - "for the solution of this Error." - ) - raise - else: - raise ModuleNotFoundError( - "Package `portalocker` is required to be installed to use this datapipe." - "Please use `pip install 'portalocker>=2.0.0'` or" - "`conda install -c conda-forge 'portalocker>=2.0.0'`" - "to install the package" - ) - - -T_co = TypeVar("T_co", covariant=True) - -PROMISE_FILE_DELETE_TIMEOUT = 30 -PROMISE_FILE_DELETE_RETRY_INTERVAL = 0.005 - -from enum import IntEnum - - -class CacheState(IntEnum): - UNCACHED = 0 - CACHED_SINGLE_ENTITY = 1 - CACHED_MULTIPLE_ENTITIES = 2 - - -@functional_datapipe("in_memory_cache") -class InMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]): - r""" - Stores elements from the source DataPipe in memory, up to a size limit - if specified (functional name: ``in_memory_cache``). This cache is FIFO - once the cache is full, - further elements will not be added to the cache until the previous ones are yielded and popped off from the cache. - - Args: - source_dp: source DataPipe from which elements are read and stored in memory - size: The maximum size (in megabytes) that this DataPipe can hold in memory. This defaults to unlimited. 
- - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper(range(10)) - >>> cache_dp = source_dp.in_memory_cache(size=5) - >>> list(cache_dp) - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - """ - size: Optional[int] = None - idx: int - - def __init__(self, source_dp: IterDataPipe[T_co], size: Optional[int] = None) -> None: - self.source_dp: IterDataPipe[T_co] = source_dp - # cache size in MB - if size is not None: - self.size = size * 1024 * 1024 - self.cache: Optional[Deque] = None - self.idx: int = 0 - - def __iter__(self) -> Iterator[T_co]: - if self.cache: - if self.idx > 0: - for idx, data in enumerate(self.source_dp): - if idx < self.idx: - yield data - else: - break - yield from self.cache - else: - # Local cache - cache: Deque = deque() - idx = 0 - for data in self.source_dp: - cache.append(data) - # Cache reaches limit - if self.size is not None and sys.getsizeof(cache) > self.size: - cache.popleft() - idx += 1 - yield data - self.cache = cache - self.idx = idx - - def __len__(self) -> int: - try: - return len(self.source_dp) - except TypeError: - if self.cache: - return self.idx + len(self.cache) - else: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length until the cache is loaded.") - - -def _generator_to_list(gen_fn): - def list_fn(*args, **kwargs): - gen = gen_fn(*args, **kwargs) - return list(gen) - - return list_fn - - -def _hash_check(filepath, hash_dict, hash_type): - - if filepath not in hash_dict: - return False - - if hash_type == "sha256": - hash_func = hashlib.sha256() - else: - hash_func = hashlib.md5() - - # with portalocker.Lock(filepath, "rb", flags=portalocker.LockFlags.SHARED) as f: - # TODO(634): Line above will require all readers (Win) to obtain proper locks, - # I'm putting it on hold as we need to modify PyTorch core codebase heavily. - with open(filepath, "rb") as f: - chunk = f.read(1024 ** 2) - while chunk: - hash_func.update(chunk) - chunk = f.read(1024 ** 2) - - return hash_func.hexdigest() == hash_dict[filepath] - - -def _promise_filename(filename, cache_uuid): - return filename + ".promise." + str(cache_uuid) - - -@functional_datapipe("on_disk_cache") -class OnDiskCacheHolderIterDataPipe(IterDataPipe): - """ - Caches the outputs of multiple DataPipe operations to local files, which are - typically performance bottleneck such download, decompress, and etc (functional name: ``on_disk_cache``). - - Must use ``.end_caching()`` to stop tracing the sequence of DataPipe operations and save the results to local files. - - Args: - source_datapipe: IterDataPipe - filepath_fn: Given data from ``source_datapipe``, returns file path(s) on local file system. - Single file path is only allowed as output of the function. - If resulted file name is different from the filename generated by the filename function of the end_cache - original file name used to store list of yield files (and as cached items availability check) - hash_dict: A Dictionary mapping file names to their corresponding hashes. If ``hash_dict`` is specified, - the extra hash check will be attached before saving data to local file system. If the data - doesn't meet the hash, the pipeline will raise an Error. - hash_type: The type of hash function to apply - extra_check_fn: Optional function to carry out extra validation on - the given file path from ``filepath_fn``. 
- - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader - >>> url = IterableWrapper(["https://path/to/filename", ]) - >>> def _filepath_fn(url): - >>> temp_dir = tempfile.gettempdir() - >>> return os.path.join(temp_dir, os.path.basename(url)) - >>> hash_dict = {"expected_filepath": expected_MD5_hash} - >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5") - >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files. - >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn) - """ - - _temp_dict: Dict = {} - - def __init__( - self, - source_datapipe: IterDataPipe, - filepath_fn: Optional[Callable] = None, - hash_dict: Optional[Dict[str, str]] = None, - hash_type: str = "sha256", - extra_check_fn: Optional[Callable[[str], bool]] = None, - ): - _assert_portalocker() - - self.source_datapipe = source_datapipe - - if filepath_fn is not None: - _check_unpickable_fn(filepath_fn) - assert not inspect.isgeneratorfunction(filepath_fn) # BC breaking, now only str is accepted as return - - if hash_dict is not None and hash_type not in ("sha256", "md5"): - raise ValueError("Invalid hash_type requested, should be one of {}".format(("sha256", "md5"))) - - # TODO(VitalyFedyunin): We need some way to generate pipe uuids which will have similar result for - # same graph but different nodes of distributed system - self._uuid = uuid.uuid4() - - OnDiskCacheHolderIterDataPipe._temp_dict[self] = (filepath_fn, hash_dict, hash_type, extra_check_fn, self._uuid) - - self._end_caching_flag: bool = False - self._download_everything = False # This is internal field used for load testing only - - def __iter__(self): - if self._end_caching_flag: - yield from self.source_datapipe - else: - # In case of BC breaking, use RuntimeError for now. Warning is another option - raise RuntimeError("Please call `end_caching()` before iteration.") - - def __add__(self, other_datapipe): - raise RuntimeError("`OnDiskCacheHolder` doesn't support add operation") - - # Since Demux is using this function, we should not attach it to OnDiskCacheHolder instance. 
- # Otherwise, it would cause infinite recursion in graph traversal - @staticmethod - def _cache_check_fn(data, filepath_fn, hash_dict, hash_type, extra_check_fn, cache_uuid): - filepath = data if filepath_fn is None else filepath_fn(data) - assert not isinstance(filepath, (list, tuple)) # BC breaking, now only str is accepted as return - - result = CacheState.CACHED_SINGLE_ENTITY - cached_file_exists = True - if os.path.exists(_get_list_filename(filepath)): - return int(CacheState.CACHED_MULTIPLE_ENTITIES) - if not os.path.exists(filepath): - cached_file_exists = False - elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type): - # TODO: It is safer to assume that entire cache is compromised and require user to wipe it - cached_file_exists = False - elif extra_check_fn is not None and not extra_check_fn(filepath): - # TODO: It is safer to assume that entire cache is compromised and require user to wipe it - cached_file_exists = False - if not cached_file_exists: - promise_filepath = _promise_filename(filepath, cache_uuid) - dirname = os.path.dirname(promise_filepath) - os.makedirs(dirname, exist_ok=True) - - with portalocker.Lock(promise_filepath, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as promise_fh: - promise_fh.seek(0) - data = promise_fh.read() - # TODO(VitalyFedyunin): Potentially there is old .promise file from previous failed run, we - # need to somehow propagate uniq session id for dataloader, save and compare it here, - # raising error - file_exists = len(data) > 0 - if not file_exists: - result = CacheState.UNCACHED - promise_fh.seek(0) - data = promise_fh.read() - # TODO(635): Potentially there is old .promise file from previous failed run, we - # need to somehow propagate uniq session id for dataloader, save and compare it here, - # raising error - file_exists = len(data) > 0 - if not file_exists: - promise_fh.seek(0) - promise_fh.write("[dataloader session uid]") - promise_fh.truncate() - promise_fh.flush() - - return int(result) - - def _end_caching(self): - filepath_fn, hash_dict, hash_type, extra_check_fn, cache_uuid = OnDiskCacheHolderIterDataPipe._temp_dict.pop( - self - ) - - todo_dp: Any - cached_dp: Any - one_many_cached_dp: Any - - if self._download_everything: - - todo_dp = self.source_datapipe - cached_dp = IterableWrapper([]) - one_many_cached_dp = IterableWrapper([]) - - else: - - todo_dp, cached_dp, one_many_cached_dp = self.source_datapipe.demux( - 3, - partial( - OnDiskCacheHolderIterDataPipe._cache_check_fn, - filepath_fn=filepath_fn, - hash_dict=hash_dict, - hash_type=hash_type, - extra_check_fn=extra_check_fn, - cache_uuid=cache_uuid, - ), - ) - # Cached: keep filepath(s) - cached_dp = cached_dp.map(fn=filepath_fn) - - one_many_cached_dp = one_many_cached_dp.map(fn=filepath_fn) - one_many_cached_dp = _ExtractFilesFromList(one_many_cached_dp) - - self.source_datapipe = todo_dp.memory_cell() - self._end_caching_flag = True - return cached_dp, one_many_cached_dp - - -def _read_bytes(fd): - return b"".join(fd) - - -def _read_str(fd): - return "".join(fd) - - -def _is_promise_pending(promise_filename): - return os.path.exists(promise_filename) - - -class _WaitPendingCacheItemIterDataPipe(IterDataPipe): - def __init__(self, source_datapipe, timeout=300, input_col=None, cache_uuid=None): - self.source_datapipe = source_datapipe - self.timeout = timeout - self.input_col = input_col - self._cache_uuid = cache_uuid - - def set_timeout(self, timeout): - self.timeout = timeout - - def __iter__(self): - for data in self.source_datapipe: - if 
self.input_col is not None: - filename = data[self.input_col] - else: - filename = data - promise_filename = _promise_filename(filename, self._cache_uuid) - start = time.time() - while _is_promise_pending(promise_filename): - time.sleep(0.01) - if time.time() - start > self.timeout: - raise Exception( - f"OnDiskCache Exception: {filename} expected to be written by different process, " - + f"but file is not ready in {self.timeout} seconds." - ) - yield data - - -@functional_datapipe("memory_cell") -class _MemoryCellIterDataPipe(IterDataPipe): - def __init__(self, source_datapipe, remember_elements=1000): - self.source_datapipe = source_datapipe - self.buffer: List[Optional[Tuple[Any, Any]]] = [None for i in range(remember_elements)] - self.remember_elements = remember_elements - self.buffer_pos = -1 - # TODO(VitalyFedyunin): Make it friendly to save/restore state - - def __iter__(self): - for item in self.source_datapipe: - item_id = uuid.uuid4() - self.buffer_pos = (self.buffer_pos + 1) % self.remember_elements - self.buffer[self.buffer_pos] = (item_id, item) - yield item - - def get_last(self): - # Returns tuple of elements, autogenerated id of the last returned row and its value - return self.buffer[self.buffer_pos] - - def get_buffer(self): - # Returns last returned id+element and others in the order from latest to oldest. - result = [] - for i in range(self.remember_elements): - idx = (self.buffer_pos - i) % self.remember_elements - if self.buffer[idx] is not None: - result.append(self.buffer[idx]) - return result - - -def _get_list_filename(file_name): - return file_name + ".torchdata_list" - - -class _ExtractFilesFromList(IterDataPipe): - def __init__(self, source_datapipe): - self.source_datapipe = source_datapipe - - def __iter__(self): - for filename in self.source_datapipe: - with open(_get_list_filename(filename)) as fh: - for line in fh: - inner_file_name = line.rstrip() - yield filename, inner_file_name - - -class _FulfilledPromisesIterDataPipe(IterDataPipe): - def __init__(self, source_datapipe, memory_cell_dp, first_filepath_fn, cache_uuid): - self.source_datapipe = source_datapipe - self.memory_cell_dp = memory_cell_dp - self.first_filepath_fn = first_filepath_fn - self._cache_uuid = cache_uuid - - @staticmethod - def _del_promise_file(promise_filename, filename): - if os.path.exists(promise_filename): - retry = True - start = time.time() - while retry: - retry = False - try: - os.unlink(promise_filename) - except Exception as e: - # Workaround about Windows not letting to delete file, while it is open by another process - retry = True - if time.time() - start > PROMISE_FILE_DELETE_TIMEOUT: - raise Exception("Timeout while trying to recover from the ", type(e), e) - time.sleep(PROMISE_FILE_DELETE_RETRY_INTERVAL) - else: - warnings.warn( - f"Attempt to mark {promise_filename} promise (base of file {filename}) as fulfilled failed. Potentially missmatching filename functions of on_disk_cache and end_cache." 
- ) - - def __iter__(self): - last_record_uuid = None - one_to_many_detected = False - one_to_one_detected = False - - def fulfill_old_promises(buffer, last_record_uuid, first_filepath_fn, cache_uuid): - for old_rec_uuid, old_rec in buffer: - original_file_name = first_filepath_fn(old_rec) - old_promise_filename = _promise_filename(original_file_name, cache_uuid) - self._del_promise_file(old_promise_filename, original_file_name) - if old_rec_uuid == last_record_uuid: - break - # TODO(VitalyFedyunin): If no match found, that means we exceeded length of memory_cell - # and there is aggressive amount 1-to-zero cases, raise error and explain how to fix - - try: - - for filename in self.source_datapipe: - rec_uuid, record = self.memory_cell_dp.get_last() - original_file_name = self.first_filepath_fn(record) - # TODO(VitalyFedyunin): For debug mode we can detect duplicate keys situations here and warn user - if original_file_name != filename: - # Situations when every archive unpacks to single file only are also considered as 1-M - one_to_many_detected = True - if one_to_one_detected: - raise Exception("Disovered different keys when one-to-one mode previously assumed") - # We are dealing with one-to-many situation now - with open(_get_list_filename(original_file_name), "a") as fh: - fh.write(f"{filename}\n") - else: - one_to_one_detected = True - if one_to_many_detected: - # Keys should be always the same (1-1 situation) or always different (1-many) sutuation - raise Exception("first key somehow equal to secondary key") - if rec_uuid != last_record_uuid: - fulfill_old_promises( - self.memory_cell_dp.get_buffer()[1:], last_record_uuid, self.first_filepath_fn, self._cache_uuid - ) - last_record_uuid = rec_uuid - yield filename - finally: - if last_record_uuid is not None: - fulfill_old_promises( - self.memory_cell_dp.get_buffer(), last_record_uuid, self.first_filepath_fn, self._cache_uuid - ) - - -def _leave_second(x): - return x[1] - - -@functional_datapipe("end_caching") -class EndOnDiskCacheHolderIterDataPipe(IterDataPipe): - """ - Indicates when the result of prior DataPipe will be saved local files specified - by ``filepath_fn`` (functional name: ``end_caching``). Moreover, the result of source DataPipe - is required to be a tuple of metadata and data, or a tuple of metadata and file handle. - - Args: - datapipe: IterDataPipe with at least one ``OnDiskCacheHolder`` in the graph. - mode: Mode in which the cached files are opened to write the data on disk. This is needed - to be aligned with the type of data or file handle from ``datapipe``. ``"wb"`` is used by default. - filepath_fn: Optional function to extract filepath from the metadata from ``datapipe``. - By default, it would directly use the ?metadata? as file path. - same_filepath_fn: Set to ``True`` to use same ``filepath_fn`` from the ``OnDiskCacheHolder``. - skip_read: Boolean value to skip reading the file handle from ``datapipe``. - By default, reading is enabled and reading function is created based on the ``mode``. 
- timeout: Integer value of seconds to wait for uncached item to be written to disk - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader - >>> url = IterableWrapper(["https://path/to/filename", ]) - >>> def _filepath_fn(url): - >>> temp_dir = tempfile.gettempdir() - >>> return os.path.join(temp_dir, os.path.basename(url)) - >>> hash_dict = {"expected_filepath": expected_MD5_hash} - >>> # You must call ``.on_disk_cache`` at some point before ``.end_caching`` - >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5") - >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files. - >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn) - """ - - def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False, timeout=300): - if filepath_fn is not None and same_filepath_fn: - raise ValueError("`filepath_fn` is mutually exclusive with `same_filepath_fn`") - - graph = traverse_dps(datapipe) - # Get the last CacheHolder - cache_holder = EndOnDiskCacheHolderIterDataPipe._recursive_search(graph) - if cache_holder is None: - raise RuntimeError("Expected `OnDiskCacheHolder` existing in pipeline when `end_caching` is invoked") - if cache_holder._end_caching_flag: - raise RuntimeError("`end_caching` can only be invoked once per `OnDiskCacheHolder`") - - first_filepath_fn, _hash_dict, _hash_type, _, cache_uuid = OnDiskCacheHolderIterDataPipe._temp_dict[ - cache_holder - ] - cached_dp, one_many_cached_dp = cache_holder._end_caching() - cached_dp = _WaitPendingCacheItemIterDataPipe(cached_dp, timeout=timeout, cache_uuid=cache_uuid) - one_many_cached_dp = _WaitPendingCacheItemIterDataPipe( - one_many_cached_dp, timeout=timeout, cache_uuid=cache_uuid, input_col=0 - ) - one_many_cached_dp = one_many_cached_dp.map(_leave_second) - memory_cell_dp = cache_holder.source_datapipe - - if same_filepath_fn: - filepath_fn = first_filepath_fn - - todo_dp = datapipe - if not skip_read: - if "t" in mode: - todo_dp = todo_dp.map(fn=_read_str, input_col=1) - else: - todo_dp = todo_dp.map(fn=_read_bytes, input_col=1) - - if filepath_fn is not None: - todo_dp = todo_dp.map(fn=filepath_fn, input_col=0) - - # Extra hash check here when hash is provided. - # And, raise Error if data returned from prior operations doesn't meet hash - if _hash_dict is not None: - todo_dp = todo_dp.check_hash(_hash_dict, _hash_type) - - todo_dp = todo_dp.save_to_disk(mode=mode) - todo_dp = _FulfilledPromisesIterDataPipe(todo_dp, memory_cell_dp, first_filepath_fn, cache_uuid=cache_uuid) - - # TODO(VitalyFedyunin): This impacts determinism for partial cache situations - return todo_dp.concat(cached_dp).concat(one_many_cached_dp) - - @staticmethod - def _recursive_search(graph): - for dp, _ in graph.values(): - # Find the closest CacheHolder - if isinstance(dp, OnDiskCacheHolderIterDataPipe): - return dp - for _, sub_graph in graph.values(): - res = EndOnDiskCacheHolderIterDataPipe._recursive_search(sub_graph) - if res is not None: - return res - return None diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py deleted file mode 100644 index ad98e4ff1..000000000 --- a/torchdata/datapipes/iter/util/combining.py +++ /dev/null @@ -1,384 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import warnings - -from collections import OrderedDict -from typing import Callable, final, Iterator, List, Optional, Sequence, TypeVar - -from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe -from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, _DemultiplexerIterDataPipe, _ForkerIterDataPipe -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn - -from torchdata.datapipes.utils.janitor import janitor - -T_co = TypeVar("T_co", covariant=True) -T = TypeVar("T") - - -@functional_datapipe("zip_with_iter") -class IterKeyZipperIterDataPipe(IterDataPipe[T_co]): - r""" - Zips two IterDataPipes together based on the matching key (functional name: ``zip_with_iter``). The keys - are computed by ``key_fn`` and ``ref_key_fn`` for the two IterDataPipes, respectively. When there isn't a match - between the elements of the two IterDataPipes, the element from ``ref_datapipe`` is stored in a buffer. Then, the - next element from ``ref_datapipe`` is tried. After a match is found, the ``merge_fn`` determines how they will - be combined and returned (a tuple is generated by default). - - Args: - source_datapipe: IterKeyZipper will yield data based on the order of this IterDataPipe - ref_datapipe: Reference IterDataPipe from which IterKeyZipper will find items - with matching key for ``source_datapipe`` - key_fn: Callable function that will compute keys using elements from ``source_datapipe`` - ref_key_fn: Callable function that will compute keys using elements from ``ref_datapipe`` - If it's not specified, the ``key_fn`` will also be applied to elements from ``ref_datapipe`` - keep_key: Option to yield the matching key along with the items in a tuple, - resulting in `(key, merge_fn(item1, item2))`. - buffer_size: The size of buffer used to hold key-data pairs from reference DataPipe until a match is found. - If it's specified as ``None``, the buffer size is set as infinite. 
- merge_fn: Function that combines the item from ``source_datapipe`` and the item from ``ref_datapipe``, - by default a tuple is created - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> from operator import itemgetter - >>> def merge_fn(t1, t2): - >>> return t1[1] + t2[1] - >>> dp1 = IterableWrapper([('a', 100), ('b', 200), ('c', 300)]) - >>> dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 3), ('d', 4)]) - >>> res_dp = dp1.zip_with_iter(dp2, key_fn=itemgetter(0), - >>> ref_key_fn=itemgetter(0), keep_key=True, merge_fn=merge_fn) - >>> list(res_dp) - [('a', 101), ('b', 202), ('c', 303)] - """ - - def __init__( - self, - source_datapipe: IterDataPipe, - ref_datapipe: IterDataPipe, - key_fn: Callable, - ref_key_fn: Optional[Callable] = None, - keep_key: bool = False, - buffer_size: int = 10000, - merge_fn: Optional[Callable] = None, - ) -> None: - if not isinstance(ref_datapipe, IterDataPipe): - raise TypeError(f"ref_datapipe must be a IterDataPipe, but its type is {type(ref_datapipe)} instead.") - self.source_datapipe = source_datapipe - self.ref_datapipe = ref_datapipe - _check_unpickable_fn(key_fn) - self.key_fn = key_fn - if ref_key_fn is not None: - _check_unpickable_fn(ref_key_fn) - self.ref_key_fn = key_fn if ref_key_fn is None else ref_key_fn - self.keep_key = keep_key - if merge_fn is not None: - _check_unpickable_fn(merge_fn) - self.merge_fn = merge_fn - if buffer_size is not None and buffer_size <= 0: - raise ValueError("'buffer_size' is required to be either None or a positive integer.") - self.buffer_size: int = buffer_size - self.buffer: OrderedDict = OrderedDict() - - def __iter__(self) -> Iterator: - ref_it = iter(self.ref_datapipe) - warn_once_flag = True - try: - for data in self.source_datapipe: - key = self.key_fn(data) - while key not in self.buffer: - try: - ref_data = next(ref_it) - except StopIteration: - raise BufferError( - f"No matching key can be found from reference DataPipe for the data {data}. " - "Please consider increasing the buffer size." - ) - ref_key = self.ref_key_fn(ref_data) - if ref_key in self.buffer: - raise ValueError("Duplicate key is found in reference DataPipe") - if self.buffer_size is not None and len(self.buffer) > self.buffer_size: - if warn_once_flag: - warn_once_flag = False - warnings.warn( - "Buffer reaches the upper limit, so reference key-data pair begins to " - "be removed from buffer in FIFO order. Please consider increase buffer size." 
- ) - self.buffer.popitem(last=False) - self.buffer[ref_key] = ref_data - res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key)) - if self.keep_key: - yield key, res - else: - yield res - finally: - del ref_it - # TODO(633): This should be Exception or warn when debug mode is enabled - if self.buffer: - for _, v in self.buffer.items(): - janitor(v) - self.buffer.clear() - - def __len__(self) -> int: - return len(self.source_datapipe) - - @final - def reset(self) -> None: - self.buffer = OrderedDict() - - def __getstate__(self): - state = ( - self.source_datapipe, - self.ref_datapipe, - self.key_fn, - self.ref_key_fn, - self.keep_key, - self.merge_fn, - self.buffer_size, - ) - if IterDataPipe.getstate_hook is not None: - return IterDataPipe.getstate_hook(state) - return state - - def __setstate__(self, state): - ( - self.source_datapipe, - self.ref_datapipe, - self.key_fn, - self.ref_key_fn, - self.keep_key, - self.merge_fn, - self.buffer_size, - ) = state - self.buffer = OrderedDict() - - def __del__(self): - if self.buffer: - for _, v in self.buffer.items(): - janitor(v) - self.buffer.clear() - - -@functional_datapipe("zip_with_map") -class MapKeyZipperIterDataPipe(IterDataPipe[T_co]): - r""" - Joins the items from the source IterDataPipe with items from a MapDataPipe (functional name: ``zip_with_map``). - The matching is done by the provided ``key_fn``, which maps an item from ``source_iterdatapipe`` to - a key that should exist in the ``map_datapipe``. The return value is created by the ``merge_fn``, which returns - a tuple of the two items by default. - - Args: - source_iterdatapipe: IterDataPipe from which items are yield and will be combined with an item - from ``map_datapipe`` - map_datapipe: MapDataPipe that takes a key from ``key_fn``, and returns an item - key_fn: Function that maps each item from ``source_iterdatapipe`` to a key that exists in ``map_datapipe`` - keep_key: Option to yield the matching key along with the items in a tuple, - resulting in ``(key, merge_fn(item1, item2))``. - merge_fn: Function that combines the item from ``source_iterdatapipe`` and the matching item - from ``map_datapipe``, by default a tuple is created - - Example: - - .. testsetup:: - - from operator import itemgetter - - .. testcode:: - - from torchdata.datapipes.iter import IterableWrapper - from torchdata.datapipes.map import SequenceWrapper - - - def merge_fn(tuple_from_iter, value_from_map): - return tuple_from_iter[0], tuple_from_iter[1] + value_from_map - - - dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)]) - mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400}) - res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn) - - print(list(res_dp)) - - Output: - - .. 
testoutput:: - - [('a', 101), ('b', 202), ('c', 303)] - - """ - - def __init__( - self, - source_iterdatapipe: IterDataPipe, - map_datapipe: MapDataPipe, - key_fn: Callable, - merge_fn: Optional[Callable] = None, - keep_key: bool = False, - ): - if not isinstance(map_datapipe, MapDataPipe): - raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.") - self.source_iterdatapipe: IterDataPipe = source_iterdatapipe - self.map_datapipe: MapDataPipe = map_datapipe - _check_unpickable_fn(key_fn) - self.key_fn: Callable = key_fn - if merge_fn is not None: - _check_unpickable_fn(merge_fn) - self.merge_fn: Optional[Callable] = merge_fn - self.keep_key = keep_key - - def __iter__(self) -> Iterator: - for item in self.source_iterdatapipe: - key = self.key_fn(item) - try: - map_item = self.map_datapipe[key] - except (KeyError, IndexError): - raise KeyError(f"key_fn maps {item} to {key}, which is not a valid key in the given MapDataPipe.") - res = self.merge_fn(item, map_item) if self.merge_fn else (item, map_item) - if self.keep_key: - yield key, res - else: - yield res - - def __len__(self) -> int: - return len(self.source_iterdatapipe) - - -def _drop_index(idx_data): - _, data = idx_data - return data - - -@functional_datapipe("round_robin_demux") -class RoundRobinDemultiplexerIterDataPipe(IterDataPipe): - r""" - Splits the input DataPipe into multiple child DataPipes in the round-robin order (functional name: ``round_robin_demux``). - A list of the child DataPipes is returned from this operation. - - Args: - datapipe: Iterable DataPipe being filtered - num_instances: number of instances of the DataPipe to create - buffer_size: this defines the maximum number of inputs that the buffer can hold across all child - DataPipes while waiting for their values to be yielded. - Defaults to ``1000``. Use ``-1`` for the unlimited buffer. 
- - Examples: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper(range(5)) - >>> dp1, dp2 = source_dp.round_robin_demux(2) - >>> list(dp1) - [0, 2, 4] - >>> len(dp1) - 3 - >>> list(dp2) - [1, 3] - >>> len(dp2) - 2 - """ - - def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = 1000): - if num_instances < 1: - raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found") - if num_instances == 1: - warnings.warn( - "The operation of `round_robin_demux` with `num_instances=1` is an no-op and returns the provided `datapipe` in a list directly" - ) - return [datapipe] - - datapipe = datapipe.enumerate() - container = _RoundRobinDemultiplexerIterDataPipe(datapipe, num_instances, buffer_size=buffer_size) # type: ignore - return [_ChildDataPipe(container, i).map(_drop_index) for i in range(num_instances)] - - -class _RoundRobinDemultiplexerIterDataPipe(_DemultiplexerIterDataPipe): - def __init__(self, datapipe: IterDataPipe[T_co], num_instances: int, buffer_size: int): - super().__init__(datapipe, num_instances, self._round_robin_fn, drop_none=False, buffer_size=buffer_size) - - def _round_robin_fn(self, idx_data) -> int: - idx, _ = idx_data - return idx % self.num_instances - - def get_length_by_instance(self, instance_id: int) -> int: - n = len(self.main_datapipe) - avg_length = n // self.num_instances - return avg_length + 1 if n - avg_length * self.num_instances > instance_id else avg_length - - -@functional_datapipe("unzip") -class UnZipperIterDataPipe(IterDataPipe[T]): - r""" - Takes in a DataPipe of Sequences, unpacks each Sequence, and return the elements in separate DataPipes - based on their position in the Sequence (functional name: ``unzip``). The number of instances produced equals to - the sequence length minus the number of columns to skip. - - Note: - Each sequence within the DataPipe should have the same length, specified by - the input argument `sequence_length`. - - Args: - source_datapipe: Iterable DataPipe with sequences of data - sequence_length: Length of the sequence within the source_datapipe. All elements should have the same length. - buffer_size: this restricts how far ahead the leading child DataPipe can read relative - to the slowest child DataPipe. Use -1 for the unlimited buffer. - columns_to_skip: optional indices of columns that the DataPipe should skip (each index should be - an integer from 0 to sequence_length - 1) - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper([(i, i + 10, i + 20) for i in range(3)]) - >>> dp1, dp2, dp3 = source_dp.unzip(sequence_length=3) - >>> list(dp1) - [0, 1, 2] - >>> list(dp2) - [10, 11, 12] - >>> list(dp3) - [20, 21, 22] - """ - - def __new__( - cls, - source_datapipe: IterDataPipe[Sequence[T]], - sequence_length: int, - buffer_size: int = 1000, - columns_to_skip: Optional[Sequence[int]] = None, - ): - if columns_to_skip is None: - instance_ids = list(range(sequence_length)) - else: - skips = set(columns_to_skip) - instance_ids = [i for i in range(sequence_length) if i not in skips] - - if len(instance_ids) == 0: - raise RuntimeError( - "All instances are being filtered out in UnZipperIterDataPipe. Please check" - "the input `sequence_length` and `columns_to_skip`." 
- ) - - # The implementation basically uses Forker but only yields a specific element within the sequence - container = _UnZipperIterDataPipe(source_datapipe, instance_ids, buffer_size) # type: ignore - return [_ChildDataPipe(container, i) for i in range(len(instance_ids))] - - -class _UnZipperIterDataPipe(_ForkerIterDataPipe): - def __init__(self, datapipe: IterDataPipe, instance_ids: List[int], buffer_size: int = 1000): - super().__init__(datapipe, len(instance_ids), buffer_size) # type: ignore[arg-type] - self.instance_ids = instance_ids - - def get_next_element_by_instance(self, instance_id: int): - r""" - Note: - Each element returned from the source datapipe is required to be a sequnce that can - be subscribed with a column index - """ - for return_val in super().get_next_element_by_instance(instance_id): - yield return_val[self.instance_ids[instance_id]] - - def __getstate__(self): - state = super().__getstate__() - return (*state, self.instance_ids) - - def __setstate__(self, state): - super().__setstate__(state[:-1]) - self.instance_ids = state[-1] diff --git a/torchdata/datapipes/iter/util/converter.py b/torchdata/datapipes/iter/util/converter.py deleted file mode 100644 index 1f5e7c9ef..000000000 --- a/torchdata/datapipes/iter/util/converter.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import warnings - -from typing import Callable, Dict, Optional - -from torch.utils._import_utils import dill_available - -from torch.utils.data import IterDataPipe, MapDataPipe -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn - -if dill_available(): - import dill - - dill.extend(use_dill=False) - - -# @functional_datapipe("to_map_datapipe") # This line must be kept for .pyi signature parser -class IterToMapConverterMapDataPipe(MapDataPipe): - r""" - Lazily load data from ``IterDataPipe`` to construct a ``MapDataPipe`` with - the key-value pair generated by ``key_value_fn`` (functional name: ``to_map_datapipe``). - If ``key_value_fn`` is not given, each data from the source IterDataPipe must itself be an iterable - with exactly two objects. The first object of each item becomes a key in - the new dictionary, and the second object the corresponding value. - - For the opposite converter, use :class:`.MapToIterConverter`. - - Args: - datapipe: Source IterDataPipe - key_value_fn: Function being applied over each data to generate key-value pair - - Note: - If a key being added is already present, the corresponding value - will be replaced by the new value. 
- - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper([(i, i) for i in range(10)]) - >>> map_dp = source_dp.to_map_datapipe() - >>> list(map_dp) - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] - >>> source_dp2 = IterableWrapper([('a', 1), ('b', 2), ('c', 1)]) - >>> map_dp2 = source_dp2.to_map_datapipe() - >>> map_dp2['a'] - 1 - >>> def row_to_tuple(row): - >>> label = row[0] - >>> data = row[1:] - >>> return label, data - >>> source_dp3 = IterableWrapper([('a', 1, 1, 1, 1, 1, 1), ('b', 2, 2, 2, 2, 2, 2), ('c', 3, 3, 3, 3, 3, 3)]) - >>> map_dp3 = source_dp3.to_map_datapipe(key_value_fn=row_to_tuple) - >>> map_dp3['a'] - (1, 1, 1, 1, 1, 1) - """ - datapipe: IterDataPipe - key_value_fn: Optional[Callable] - _map: Optional[Dict] - _length: int - - def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = None): - if not isinstance(datapipe, IterDataPipe): - raise TypeError(f"IterToMapConverter can only apply on IterDataPipe, but found {type(datapipe)}") - self.datapipe = datapipe - if key_value_fn is not None: - _check_unpickable_fn(key_value_fn) - self.key_value_fn = key_value_fn # type: ignore[assignment] - self._map = None - - def _load_map(self): - self._map = {} - for d in self.datapipe: - inp = d if self.key_value_fn is None else self.key_value_fn(d) - try: - length = len(inp) - except TypeError: - raise TypeError(f"Cannot convert dictionary update element {type(inp)} ({inp}) to a sequence") - if length != 2: - raise ValueError(f"dictionary update sequence element has length {length}, 2 is required") - key, value = inp - if key in self._map: - warnings.warn(f"Found duplicate key {key}. Please check your `key_value_fn`") - self._map[key] = value - - def __getitem__(self, index): - try: - if self._map is None: - self._load_map() - return self._map[index] # type: ignore[index] - except KeyError: - raise IndexError(f"Index {index} is invalid for IterToMapConverter.") - - def __len__(self): - if self._map is not None: - return len(self._map) # type: ignore[arg-type] - try: - return len(self.datapipe) - except (TypeError, NotImplementedError): - pass - warnings.warn( - "Data from prior DataPipe are loaded to get length of" - "IterToMapConverter before execution of the pipeline." - "Please consider removing len()." - ) - self._load_map() - return len(self._map) # type: ignore[arg-type] - - def __getstate__(self): - if dill_available(): - dill_key_value_fn = dill.dumps(self.key_value_fn) - else: - dill_key_value_fn = self.key_value_fn - return ( - self.datapipe, - dill_key_value_fn, - self._map, - ) - - def __setstate__(self, state): - (self.datapipe, dill_key_value_fn, self._map) = state - if dill_available(): - self.key_value_fn = dill.loads(dill_key_value_fn) # type: ignore[assignment] - else: - self.key_value_fn = dill_key_value_fn # type: ignore[assignment] - - -# Register for functional API -# See https://github.com/pytorch/data/issues/200 -IterDataPipe.register_datapipe_as_function("to_map_datapipe", IterToMapConverterMapDataPipe) diff --git a/torchdata/datapipes/iter/util/cycler.py b/torchdata/datapipes/iter/util/cycler.py deleted file mode 100644 index 851872b3c..000000000 --- a/torchdata/datapipes/iter/util/cycler.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Iterator, Optional, TypeVar - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -T_co = TypeVar("T_co", covariant=True) - - -@functional_datapipe("cycle") -class CyclerIterDataPipe(IterDataPipe[T_co]): - """ - Cycles the specified input in perpetuity by default, or for the specified number - of times (functional name: ``cycle``). - - If the ordering does not matter (e.g. because you plan to ``shuffle`` later) or if you would like to - repeat an element multiple times before moving onto the next element, use :class:`.Repeater`. - - Args: - source_datapipe: source DataPipe that will be cycled through - count: the number of times to read through ``source_datapipe` (if ``None``, it will cycle in perpetuity) - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper(range(3)) - >>> dp = dp.cycle(2) - >>> list(dp) - [0, 1, 2, 0, 1, 2] - """ - - def __init__(self, source_datapipe: IterDataPipe[T_co], count: Optional[int] = None) -> None: - self.source_datapipe: IterDataPipe[T_co] = source_datapipe - self.count: Optional[int] = count - if count is not None and count < 0: - raise ValueError(f"Expected non-negative count, got {count}") - - def __iter__(self) -> Iterator[T_co]: - i = 0 - while self.count is None or i < self.count: - yield from self.source_datapipe - i += 1 - - def __len__(self) -> int: - if self.count is None: - raise TypeError( - f"This {type(self).__name__} instance cycles forever, and therefore doesn't have valid length" - ) - else: - return self.count * len(self.source_datapipe) - - -@functional_datapipe("repeat") -class RepeaterIterDataPipe(IterDataPipe[T_co]): - """ - Repeatedly yield each element of source DataPipe for the specified number of times before - moving onto the next element (functional name: ``repeat``). Note that no copy is made in this DataPipe, - the same element is yielded repeatedly. - - If you would like to yield the whole DataPipe in order multiple times, use :class:`.Cycler`. - - Args: - source_datapipe: source DataPipe that will be iterated through - times: the number of times an element of ``source_datapipe`` will be yielded before moving onto the next element - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper(range(3)) - >>> dp = dp.repeat(2) - >>> list(dp) - [0, 0, 1, 1, 2, 2] - """ - - def __init__(self, source_datapipe: IterDataPipe[T_co], times: int) -> None: - self.source_datapipe: IterDataPipe[T_co] = source_datapipe - self.times: int = times - if times <= 1: - raise ValueError(f"The number of repetition must be > 1, got {times}") - - def __iter__(self) -> Iterator[T_co]: - for element in self.source_datapipe: - for _ in range(self.times): - yield element - - def __len__(self) -> int: - return self.times * len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/util/dataframemaker.py b/torchdata/datapipes/iter/util/dataframemaker.py deleted file mode 100644 index a7e5c27b0..000000000 --- a/torchdata/datapipes/iter/util/dataframemaker.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from functools import partial -from typing import List, Optional, TypeVar - -from torch.utils._import_utils import dill_available - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -try: # TODO(637): Create dependency on TorchArrow? - import pyarrow.parquet as parquet - import torcharrow -except ImportError: - torcharrow = None - parquet = None - -if dill_available(): - import dill - - dill.extend(use_dill=False) - -T_co = TypeVar("T_co") - - -def _construct_dataframe(data, dtype=None, dtype_generator=None, columns=None, device=None): - if dtype is None: - dtype = dtype_generator() - return torcharrow.dataframe(data, dtype=dtype, columns=columns, device=device) - - -@functional_datapipe("dataframe") -class DataFrameMakerIterDataPipe(IterDataPipe): # IterDataPipe[torcharrow.IDataFrame[T_co]] - r""" - Takes rows of data, batches a number of them together and creates `TorchArrow` - DataFrames (functional name: ``dataframe``). - - Note: - There is a trade-off between having a large number of rows within a DataFrame and usage of memory. Please - choose a value carefully. - - Args: - source_dp: IterDataPipe containing rows of data - dataframe_size: number of rows of data within each DataFrame, page size can be option - dtype: specify the `TorchArrow` dtype for the DataFrame, use ``torcharrow.dtypes.DType`` - dtype_generator: function with no input argument that generates a torcharrow.dtypes.DType, - which overrides dtype if both are given. This is useful for when the desired dtype is - not serializable. - columns: List of str that specifies the column names of the DataFrame - device: specify the device on which the DataFrame will be stored - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> import torcharrow.dtypes as dt - >>> source_data = [(i,) for i in range(3)] - >>> source_dp = IterableWrapper(source_data) - >>> DTYPE = dt.Struct([dt.Field("Values", dt.int32)]) - >>> df_dp = source_dp.dataframe(dtype=DTYPE) - >>> list(df_dp)[0] - index Values - ------- -------- - 0 0 - 1 1 - 2 2 - dtype: Struct([Field('Values', int32)]), count: 3, null_count: 0 - """ - - def __new__( - cls, - source_dp: IterDataPipe[T_co], - dataframe_size: int = 1000, - dtype=None, - dtype_generator=None, - columns: Optional[List[str]] = None, - device: str = "", - ): - if torcharrow is None: - raise ImportError( - "The library 'torcharrow' is necessary for this DataPipe but it is not available." - "Please visit https://github.com/facebookresearch/torcharrow/ to install it." - ) - # In this version, DF tracing is not available, which would allow DataPipe to run DataFrame operations - batch_dp = source_dp.batch(dataframe_size) - df_dp = batch_dp.map( - partial(_construct_dataframe, dtype=dtype, dtype_generator=dtype_generator, columns=columns, device=device) - ) - return df_dp - - -@functional_datapipe("load_parquet_as_df") -class ParquetDFLoaderIterDataPipe(IterDataPipe): # IterDataPipe[torcharrow.IDataFrame[T_co]] - r""" - Takes in paths to Parquet files and return a `TorchArrow` DataFrame for each row group - within a Parquet file (functional name: ``load_parquet_as_df``). 
- - Args: - source_dp: source DataPipe containing paths to the Parquet files - columns: List of `str` that specifies the column names of the DataFrame - use_threads: if ``True``, Parquet reader will perform multi-threaded column reads - dtype: specify the `TorchArrow` dtype for the DataFrame, use ``torcharrow.dtypes.DType`` - device: specify the device on which the DataFrame will be stored - - Example: - >>> from torchdata.datapipes.iter import FileLister - >>> import torcharrow.dtypes as dt - >>> DTYPE = dt.Struct([dt.Field("Values", dt.int32)]) - >>> source_dp = FileLister(".", masks="df*.parquet") - >>> parquet_df_dp = source_dp.load_parquet_as_df(dtype=DTYPE) - >>> list(parquet_df_dp)[0] - index Values - ------- -------- - 0 0 - 1 1 - 2 2 - dtype: Struct([Field('Values', int32)]), count: 3, null_count: 0 - """ - - def __init__( - self, - source_dp: IterDataPipe[str], - dtype=None, - columns: Optional[List[str]] = None, - device: str = "", - use_threads: bool = False, - ): - if torcharrow is None: - raise ImportError( - "The library 'torcharrow' is necessary for this DataPipe but it is not available." - "Please visit https://github.com/facebookresearch/torcharrow/ to install it." - ) - if parquet is None: - raise ImportError("The library 'parquet' is necessary for this DataPipe but it is not available.") - self.source_dp = source_dp - self.columns = columns - self.use_threads = use_threads - self.dtype = dtype - self.device = device - - def __iter__(self): - for path in self.source_dp: - parquet_file = parquet.ParquetFile(path) - num_row_groups = parquet_file.num_row_groups - for i in range(num_row_groups): - # TODO(638): More fine-grain control over the number of rows or row group per DataFrame - row_group = parquet_file.read_row_group(i, columns=self.columns, use_threads=self.use_threads) - yield torcharrow.from_arrow(row_group, dtype=self.dtype) - - def __getstate__(self): - if dill_available(): - dill_dtype = dill.dumps(self.dtype) - else: - dill_dtype = self.dtype - state = (self.source_dp, dill_dtype, self.columns, self.device, self.use_threads) - if IterDataPipe.getstate_hook is not None: - return IterDataPipe.getstate_hook(state) - return state - - def __setstate__(self, state): - (self.source_dp, dill_dtype, self.columns, self.device, self.use_threads) = state - if dill_available(): - self.dtype = dill.loads(dill_dtype) # type: ignore[assignment] - else: - self.dtype = dill_dtype # type: ignore[assignment] diff --git a/torchdata/datapipes/iter/util/decompressor.py b/torchdata/datapipes/iter/util/decompressor.py deleted file mode 100644 index aafcb7144..000000000 --- a/torchdata/datapipes/iter/util/decompressor.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import bz2 -import gzip -import lzma -import os -import pathlib -import tarfile -import zipfile - -from enum import Enum -from io import IOBase -from typing import Iterator, Optional, Tuple, Union - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.utils import StreamWrapper - - -class CompressionType(Enum): - GZIP = "gzip" - LZMA = "lzma" - TAR = "tar" - ZIP = "zip" - BZIP2 = "bz2" - - -@functional_datapipe("decompress") -class DecompressorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r""" - Takes tuples of path and compressed stream of data, and returns tuples of - path and decompressed stream of data (functional name: ``decompress``). The input compression format can be specified - or automatically detected based on the files' file extensions. - - Args: - source_datapipe: IterDataPipe containing tuples of path and compressed stream of data - file_type: Optional `string` or ``CompressionType`` that represents what compression format of the inputs - - Example: - >>> from torchdata.datapipes.iter import FileLister, FileOpener - >>> tar_file_dp = FileLister(self.temp_dir.name, "*.tar") - >>> tar_load_dp = FileOpener(tar_file_dp, mode="b") - >>> tar_decompress_dp = Decompressor(tar_load_dp, file_type="tar") - >>> for _, stream in tar_decompress_dp: - >>> print(stream.read()) - b'0123456789abcdef' - """ - - types = CompressionType - - _DECOMPRESSORS = { - types.GZIP: lambda file: gzip.GzipFile(fileobj=file), - types.LZMA: lambda file: lzma.LZMAFile(file), - types.TAR: lambda file: tarfile.open(fileobj=file, mode="r:*"), - types.ZIP: lambda file: zipfile.ZipFile(file=file), - types.BZIP2: lambda file: bz2.BZ2File(filename=file), - } - - def __init__( - self, source_datapipe: IterDataPipe[Tuple[str, IOBase]], file_type: Optional[Union[str, CompressionType]] = None - ) -> None: - self.source_datapipe: IterDataPipe[Tuple[str, IOBase]] = source_datapipe - if isinstance(file_type, str): - file_type = self.types(file_type.lower()) - self.file_type: Optional[CompressionType] = file_type - - def _detect_compression_type(self, path: str) -> CompressionType: - if self.file_type: - return self.file_type - - ext = "".join(pathlib.Path(path).suffixes) - if ext in {".tar.gz", ".tar.xz"}: - return self.types.TAR - else: - ext = os.path.splitext(path)[1] - if ext == ".tar": - return self.types.TAR - elif ext == ".xz": - return self.types.LZMA - elif ext == ".gz": - return self.types.GZIP - elif ext == ".zip": - return self.types.ZIP - elif ext == ".bz2": - return self.types.BZIP2 - else: - raise RuntimeError( - f"File at {path} has file extension {ext}, which does not match what are supported by" - f"ExtractorIterDataPipe." - ) - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for path, file in self.source_datapipe: - try: - file_type = self._detect_compression_type(path) - decompressor = self._DECOMPRESSORS[file_type] - yield path, StreamWrapper(decompressor(file), file, name=path) - finally: - if isinstance(file, StreamWrapper): - file.autoclose() - - -@functional_datapipe("extract") -class ExtractorIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]): - r""" - Please use ``Decompressor`` or ``.decompress`` instead. 
- """ - - def __new__( - cls, source_datapipe: IterDataPipe[Tuple[str, IOBase]], file_type: Optional[Union[str, CompressionType]] = None - ): - return DecompressorIterDataPipe(source_datapipe, file_type) diff --git a/torchdata/datapipes/iter/util/distributed.py b/torchdata/datapipes/iter/util/distributed.py deleted file mode 100644 index 3fc310875..000000000 --- a/torchdata/datapipes/iter/util/distributed.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import threading - -import time - -from collections import deque -from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError -from dataclasses import dataclass -from functools import partial -from typing import Callable, Deque, final, Iterator, Optional, TypeVar - -import torch -import torch.distributed as dist - -from torchdata._constants import default_timeout_in_s -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.iter.util.prefetcher import PRODUCER_SLEEP_INTERVAL - - -T_co = TypeVar("T_co", covariant=True) - - -__all__ = ["Expected", "FullSyncIterDataPipe", "PrefetchTimeoutError"] - - -class PrefetchTimeoutError(RuntimeError): - def __init__(self, timeout: int) -> None: - super().__init__(f"Fail to fetch data within {timeout} seconds") - self.timeout = timeout - - -class _EndOfPrefetch: - ... - - -@dataclass -class Expected: - r""" - Expected data provided to callback function in ``_PrefetchExecutor``. - """ - index: int - error: Optional[BaseException] = None - - def has_error(self) -> bool: - return self.error is not None - - -class _PrefetchExecutor: - # TODO: Improvement - merge with the `_PrefetchData` class of prefetcher.py - # May not be possible right now due to circular import - def __init__( - self, - datapipe_iterator: Iterator, - prefetch_size: int = 1, - callback_fn: Optional[Callable[[Expected], None]] = None, - timeout: int = default_timeout_in_s, - ) -> None: - self.datapipe_iterator = datapipe_iterator - self.prefetch_size = prefetch_size - self.callback_fn = callback_fn - self.timeout = timeout - # Use max_workers as 1 to guarantee the order of data fetched from iterator - self._executor = ThreadPoolExecutor(max_workers=1) - self._futures: Deque[Future] = deque() - self._lock = threading.RLock() - # `_end_flag` indicates the end of epoch or an exception has been raised, - # with the exception being handled by `callback_fn` - self._end_flag: bool = False - self._paused: bool = False - self._is_shutdown: bool = False # indicates if `_executor` has been shutdown by `shutdown` method - self._idx = 0 - for _ in range(prefetch_size): - with self._lock: - if self._end_flag: - break - fetch_future: Future = self._executor.submit(self.fetch_next) - fetch_future.add_done_callback(partial(self._done_callback_fn, self._idx)) - self._futures.append(fetch_future) - with self._lock: - self._idx += 1 - - def fetch_next(self): - while self._paused: - time.sleep(PRODUCER_SLEEP_INTERVAL * 10) - return next(self.datapipe_iterator) - - def _done_callback_fn(self, index: int, f: Future): - if f.exception(): - with self._lock: - self._end_flag = True - if self.callback_fn is not None: - # Invoke `callback_fn` in order to set `FullSyncDP._done_callback` to `True` - self.callback_fn(Expected(index, f.exception())) - - def return_next(self): - if 
self._futures: - fetch_future = self._futures.popleft() - try: - data = fetch_future.result(timeout=self.timeout) - except TimeoutError: - raise PrefetchTimeoutError(self.timeout) - with self._lock: - if not self._end_flag and not self._is_shutdown: - next_future = self._executor.submit(self.fetch_next) - next_future.add_done_callback(partial(self._done_callback_fn, self._idx)) - self._futures.append(next_future) - self._idx += 1 - else: - data = _EndOfPrefetch() - return data - - def shutdown(self): - self._paused = False - self._is_shutdown = True - while self._futures: - self._futures.popleft().cancel() - self._executor.shutdown(wait=True) - - def pause(self): - self._paused = True - - def resume(self): - self._paused = False - - -@functional_datapipe("fullsync") -class FullSyncIterDataPipe(IterDataPipe[T_co]): - r""" - Synchronizes data across distributed processes to prevent hanging during training, - which is caused by uneven sharded data (functional name: ``fullsync``). It stops - when the shortest distributed shard is exhausted. It would be appended at the end - of the graph of ``DataPipe`` by ``DistributedReadingService`` automatically. - - Args: - datapipe: IterDataPipe that needs to be synchronized - timeout: Timeout for prefetching data in seconds. Default value equals to 30 minutes - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> # Distributed training with world size 2 - >>> world_size = 2 - >>> dp = IterableWrapper(list(range(23))).sharding_filter() - >>> torch.utils.data.graph_settings.apply_sharding(dp, world_size, rank) - >>> # Rank 0 has 12 elements; Rank 1 has 11 elements - >>> for d in dp: - ... model(d) # Hanging at the end of epoch due to uneven sharding - >>> dp = dp.fullsync() - >>> # Both ranks have 11 elements - >>> for d in dp: - ... 
model(d) # Not hanging anymore - """ - - def __init__(self, datapipe: IterDataPipe, timeout=default_timeout_in_s): - if not dist.is_available(): - raise RuntimeError("Torch Distributed is required to be available") - self.datapipe = datapipe - self.timeout: int = timeout - - self._process_group: Optional[dist.ProcessGroup] = None - self._world_size: int = 1 - - self._lock = threading.RLock() - self._cv = threading.Condition(lock=self._lock) - self._executor: Optional[_PrefetchExecutor] = None - # Use single values rather than deques for the following variables - # because fullsync only prefetches 1 element - self._error = None - self._sync_counter = torch.tensor([0], dtype=torch.int32) - self._done_callback = False - - def _callback_fn(self, exp: Expected) -> None: - with self._cv: - if exp.has_error(): - if not isinstance(exp.error, StopIteration): - self._error = exp.error # type: ignore[assignment] - self._sync_counter = torch.tensor([0], dtype=torch.int32) - else: - self._sync_counter = torch.tensor([1], dtype=torch.int32) - dist.all_reduce( - tensor=self._sync_counter, - op=dist.ReduceOp.SUM, - group=self._process_group, - ) - self._done_callback = True - self._cv.notify() - - def __iter__(self) -> Iterator[T_co]: - assert self._executor is None - if not (dist.is_available() and dist.is_initialized()): - raise RuntimeError("Torch Distributed is required to be initialized to use `FullSync`.") - - if self._process_group is None: - self._process_group = dist.new_group(backend="gloo") - self._world_size = dist.get_world_size() - - if self._world_size == 1: # The below functionalities are not needed if `_world_size == 1` - yield from self.datapipe - return - - self._executor = _PrefetchExecutor(iter(self.datapipe), 1, self._callback_fn, self.timeout) - while True: - with self._cv: - is_success = self._cv.wait_for( - lambda: self._done_callback is True, - self.timeout, - ) - if not is_success: - raise PrefetchTimeoutError(self.timeout) - if self._error is not None: - raise self._error - if bool(self._sync_counter < self._world_size): - break - self._done_callback = False - data = self._executor.return_next() # type: ignore[attr-defined] - if isinstance(data, _EndOfPrefetch): - break - yield data - - @final - def reset(self): - if self._executor is not None: - self._executor.shutdown() - self._executor = None - self._world_size = 1 - with self._cv: - self._error = None - self._sync_counter = torch.tensor([0], dtype=torch.int32) - self._done_callback = False - - def is_replicable(self): - return False - - def __getstate__(self): - state = ( - self.datapipe, - self.timeout, - ) - if IterDataPipe.getstate_hook is not None: - return IterDataPipe.getstate_hook(state) - return state - - def __setstate__(self, state): - self.datapipe, self.timeout = state - self._process_group = None - self._world_size = 1 - self._lock = threading.RLock() - self._cv = threading.Condition(lock=self._lock) - self._executor = None - self._error = None - self._sync_counter = torch.tensor([0], dtype=torch.int32) - self._done_callback = False - - @final - def pause(self): - if self._executor is not None: - self._executor.pause() - - @final - def resume(self): - if self._executor is not None: - self._executor.resume() - - @final - def shutdown(self): - if self._executor is not None: - self._executor.shutdown() - self._executor = None - - def __del__(self): - self.shutdown() diff --git a/torchdata/datapipes/iter/util/hashchecker.py b/torchdata/datapipes/iter/util/hashchecker.py deleted file mode 100644 index 
9cb32ac6a..000000000 --- a/torchdata/datapipes/iter/util/hashchecker.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import hashlib - -from io import IOBase -from typing import Dict, Iterator, Tuple, Union - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.utils import StreamWrapper - - -D_type = Union[str, bytes, bytearray] -U = Union[D_type, StreamWrapper] - - -@functional_datapipe("check_hash") -class HashCheckerIterDataPipe(IterDataPipe[Tuple[str, U]]): - r""" - Computes and checks the hash of each file, from an input DataPipe of tuples of file name and - data/stream (functional name: ``check_hash``). If the hashes match the given hash - in the dictionary, it yields a tuple of file name and data/stream. Otherwise, it will raise an error. - - Args: - source_datapipe: IterDataPipe with tuples of file name and data/stream - hash_dict: Dictionary that maps file names to their corresponding hashes - hash_type: The type of hash function to apply - rewind: Rewind the stream after using the stream to compute the hash (this - does not work with non-seekable streams, e.g. HTTP) - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, FileOpener - >>> expected_MD5_hash = "bb9675028dd39d2dd2bf71002b93e66c" - >>> # File is from "https://raw.githubusercontent.com/pytorch/data/main/LICENSE" - >>> file_dp = FileOpener(IterableWrapper(["LICENSE.txt"]), mode='rb') - >>> # An exception is only raised when the hash doesn't match, otherwise (path, stream) is returned - >>> check_hash_dp = file_dp.check_hash({"LICENSE.txt": expected_MD5_hash}, "md5", rewind=True) - >>> reader_dp = check_hash_dp.readlines() - >>> it = iter(reader_dp) - >>> path, line = next(it) - >>> path - LICENSE.txt - >>> line - b'BSD 3-Clause License' - """ - - def __init__( - self, - source_datapipe: IterDataPipe[Tuple[str, IOBase]], - hash_dict: Dict[str, str], - hash_type: str = "sha256", - rewind: bool = True, - ) -> None: - self.source_datapipe: IterDataPipe[Tuple[str, IOBase]] = source_datapipe - self.hash_dict: Dict[str, str] = hash_dict - self.hash_type: str = hash_type - self.rewind: bool = rewind - - if self.hash_type not in ["sha256", "md5"]: - raise ValueError("Invalid hash_type requested, should be one of {}".format(["sha256", "md5"])) - - def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]: - for file_name, data in self.source_datapipe: - if self.hash_type == "sha256": - hash_func = hashlib.sha256() - else: - hash_func = hashlib.md5() - - if isinstance(data, (str, bytes, bytearray)): - if isinstance(data, str): - data = data.encode() - hash_func.update(data) - # File Stream - else: - # Not all streams have `read(bytes)` method. - # `__iter__` method is chosen because it is a common interface for IOBase. - for d in data: - hash_func.update(d) - - # TODO(133): this will not work (or will work poorly) for non-seekable streams like HTTP - if self.rewind: - data.seek(0) - - if file_name not in self.hash_dict: - raise RuntimeError(f"Unspecified hash for file {file_name}") - - if hash_func.hexdigest() != self.hash_dict[file_name]: - raise RuntimeError( - f"The computed hash {hash_func.hexdigest()} of {file_name} does not match the expected " - f"hash {self.hash_dict[file_name]}. Delete the file manually and retry."
- ) - - if isinstance(data, (str, bytes, bytearray)): - yield file_name, data - else: - yield file_name, StreamWrapper(data) - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/util/header.py b/torchdata/datapipes/iter/util/header.py deleted file mode 100644 index b16efe564..000000000 --- a/torchdata/datapipes/iter/util/header.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Iterator, Optional, TypeVar -from warnings import warn - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -T_co = TypeVar("T_co", covariant=True) - - -@functional_datapipe("header") -class HeaderIterDataPipe(IterDataPipe[T_co]): - r""" - Yields elements from the source DataPipe from the start, up to the specified limit (functional name: ``header``). - - If you would like to manually set the length of a DataPipe to a certain value, we recommend you - use :class:`.LengthSetter`. - - Args: - source_datapipe: the DataPipe from which elements will be yielded - limit: the number of elements to yield before stopping - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper(range(10)) - >>> header_dp = dp.header(3) - >>> list(header_dp) - [0, 1, 2] - """ - - def __init__(self, source_datapipe: IterDataPipe[T_co], limit: Optional[int] = 10) -> None: - self.source_datapipe: IterDataPipe[T_co] = source_datapipe - self.limit: Optional[int] = limit - - def __iter__(self) -> Iterator[T_co]: - i: int = 0 - for value in self.source_datapipe: - i += 1 - if self.limit is None or i <= self.limit: - yield value - else: - break - - def __len__(self) -> int: - try: - source_len = len(self.source_datapipe) - return source_len if self.limit is None else min(source_len, self.limit) - except TypeError as error: - if self.limit is None: - raise TypeError("The length of this HeaderIterDataPipe cannot be determined.") from error - - warn( - "The length of this HeaderIterDataPipe is inferred to be equal to its limit. " - "The actual value may be smaller if the actual length of source_datapipe is smaller than the limit." - ) - return self.limit - - -@functional_datapipe("set_length") -class LengthSetterIterDataPipe(IterDataPipe[T_co]): - r""" - Sets the length attribute of the DataPipe, which is returned by ``__len__`` (functional name: ``set_length``). - This can be used after DataPipes whose final length cannot be known in advance (e.g. ``filter``). If you - know the final length with certainty, you can manually set it, which can then be used by - DataLoader or other DataPipes. - - Note: - This DataPipe differs from :class:`.Header` in that this doesn't restrict the number of elements that - can be yielded from the DataPipe; this is strictly used for setting an attribute so that it can be used later.
- - Args: - source_datapipe: a DataPipe - length: the integer value that will be set as the length - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper(range(10)).filter(lambda x: x < 5).set_length(3) - >>> list(dp) # Notice that the number of elements yielded is unchanged - [0, 1, 2, 3, 4] - >>> len(dp) - 3 - >>> header_dp = IterableWrapper(range(10)).filter(lambda x: x < 5).header(3) - >>> list(header_dp) # Use `.header()` if you want to limit the number of elements yielded - [0, 1, 2] - >>> len(header_dp) - 3 - """ - - def __init__(self, source_datapipe: IterDataPipe[T_co], length: int) -> None: - self.source_datapipe: IterDataPipe[T_co] = source_datapipe - assert length >= 0 - self.length: int = length - - def __iter__(self) -> Iterator[T_co]: - yield from self.source_datapipe - - def __len__(self) -> int: - return self.length diff --git a/torchdata/datapipes/iter/util/indexadder.py b/torchdata/datapipes/iter/util/indexadder.py deleted file mode 100644 index 6b4f70a85..000000000 --- a/torchdata/datapipes/iter/util/indexadder.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Dict, Iterator, Tuple, TypeVar - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -K = TypeVar("K") - - -@functional_datapipe("enumerate") -class EnumeratorIterDataPipe(IterDataPipe[Tuple[int, K]]): - r""" - Adds an index to an existing DataPipe through enumeration, with - the index starting from 0 by default (functional name: ``enumerate``). - - Args: - source_datapipe: Iterable DataPipe being indexed - starting_index: Index from which enumeration will start - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper(['a', 'b', 'c']) - >>> enum_dp = dp.enumerate() - >>> list(enum_dp) - [(0, 'a'), (1, 'b'), (2, 'c')] - """ - - def __init__(self, source_datapipe: IterDataPipe[K], starting_index: int = 0) -> None: - self.source_datapipe: IterDataPipe[K] = source_datapipe - self.starting_index = starting_index - - def __iter__(self): - yield from enumerate(self.source_datapipe, self.starting_index) - - def __len__(self): - return len(self.source_datapipe) - - -@functional_datapipe("add_index") -class IndexAdderIterDataPipe(IterDataPipe[Dict]): - r""" - Adds an index to an existing Iterable DataPipe with (functional name: ``add_index``). The row or batch - within the DataPipe must have the type `Dict`; otherwise, a `NotImplementedError` will be thrown. The index - of the data is set to the provided ``index_name``. 
- - Args: - source_datapipe: Iterable DataPipe being indexed, its row/batch must be of type `Dict` - index_name: Name of the key to store data index - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper([{'a': 1, 'b': 2}, {'c': 3, 'a': 1}]) - >>> index_dp = dp.add_index("order") - >>> list(index_dp) - [{'a': 1, 'b': 2, 'order': 0}, {'c': 3, 'a': 1, 'order': 1}] - """ - - def __init__(self, source_datapipe: IterDataPipe[Dict], index_name: str = "index") -> None: - self.source_datapipe = source_datapipe - self.index_name = index_name - - def __iter__(self) -> Iterator[Dict]: - for i, row_or_batch in enumerate(self.source_datapipe): - if isinstance(row_or_batch, dict): - row_or_batch[self.index_name] = i - yield row_or_batch - else: - raise NotImplementedError("We only support adding index to row or batch in dict type") - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/util/jsonparser.py b/torchdata/datapipes/iter/util/jsonparser.py deleted file mode 100644 index 68a45259a..000000000 --- a/torchdata/datapipes/iter/util/jsonparser.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import json -from typing import Dict, IO, Iterator, Tuple - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - - -@functional_datapipe("parse_json_files") -class JsonParserIterDataPipe(IterDataPipe[Tuple[str, Dict]]): - r""" - Reads from JSON data streams and yields a tuple of file name and JSON data (functional name: ``parse_json_files``). - - Args: - source_datapipe: a DataPipe with tuples of file name and JSON data stream - kwargs: keyword arguments that will be passed through to ``json.loads`` - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, FileOpener - >>> import os - >>> def get_name(path_and_stream): - >>> return os.path.basename(path_and_stream[0]), path_and_stream[1] - >>> datapipe1 = IterableWrapper(["empty.json", "1.json", "2.json"]) - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> datapipe3 = datapipe2.map(get_name) - >>> json_dp = datapipe3.parse_json_files() - >>> list(json_dp) - [('1.json', ['foo', {'bar': ['baz', None, 1.0, 2]}]), ('2.json', {'__complex__': True, 'real': 1, 'imag': 2})] - """ - - def __init__(self, source_datapipe: IterDataPipe[Tuple[str, IO]], **kwargs) -> None: - self.source_datapipe: IterDataPipe[Tuple[str, IO]] = source_datapipe - self.kwargs = kwargs - - def __iter__(self) -> Iterator[Tuple[str, Dict]]: - for file_name, stream in self.source_datapipe: - data = stream.read() - stream.close() - yield file_name, json.loads(data, **self.kwargs) - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/util/mux_longest.py b/torchdata/datapipes/iter/util/mux_longest.py deleted file mode 100644 index 2b6e328b9..000000000 --- a/torchdata/datapipes/iter/util/mux_longest.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from typing import Set, Sized - -from torch.utils.data.datapipes._decorator import functional_datapipe -from torch.utils.data.datapipes.datapipe import IterDataPipe - - -@functional_datapipe("mux_longest") -class MultiplexerLongestIterDataPipe(IterDataPipe): - r""" - Yields one element at a time from each of the input Iterable DataPipes (functional name: ``mux_longest``). As in, - one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration, - and so on. It skips over DataPipes that are exhausted, and ends when all input DataPipes are exhausted. - - Args: - datapipes: Iterable DataPipes that will take turn to yield their elements, until they are all exhausted - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp1, dp2, dp3 = IterableWrapper(range(5)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) - >>> list(dp1.mux_longest(dp2, dp3)) - [0, 10, 20, 1, 11, 21, 2, 12, 22, 3, 13, 23, 4, 14, 24] - """ - - def __init__(self, *datapipes): - self.datapipes = datapipes - - def __iter__(self): - iterators = [iter(x) for x in self.datapipes] - finished: Set[int] = set() - while len(finished) < len(iterators): - for i in range(len(iterators)): - if i not in finished: - try: - value = next(iterators[i]) - yield value - except StopIteration: - finished.add(i) - - def __len__(self): - if all(isinstance(dp, Sized) for dp in self.datapipes): - return sum(len(dp) for dp in self.datapipes) - else: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/torchdata/datapipes/iter/util/paragraphaggregator.py b/torchdata/datapipes/iter/util/paragraphaggregator.py deleted file mode 100644 index f2ec7bacd..000000000 --- a/torchdata/datapipes/iter/util/paragraphaggregator.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Callable, final, Iterator, List, Tuple, TypeVar - -from torch.utils.data.datapipes.utils.common import _check_unpickable_fn - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - - -T_co = TypeVar("T_co", covariant=True) - - -def _default_line_join(lines: List[str]) -> str: - return "\n".join(lines) - - -@functional_datapipe("lines_to_paragraphs") -class ParagraphAggregatorIterDataPipe(IterDataPipe[Tuple[str, str]]): - r""" - Aggregates lines of text from the same file into a single paragraph (functional name: ``lines_to_paragraphs``). - Specifically, this accepts a DataPipe consisting of tuples of a file name and a line. For each tuple, - it checks if the file name matches the file name from the previous tuple. If yes, it joins the current line - with existing paragraph. If the file names do not match, the existing paragraph is yielded and a new - paragraph starts. 
- - Args: - source_datapipe: a DataPipe with tuples of a file name and a line - joiner: a function that joins a list of lines together - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper( - >>> [("file1", "Line1"), ("file1", "Line2"), ("file2", "Line2,1"), ("file2", "Line2,2"), ("file2", "Line2,3")] - >>> ) - >>> para_agg_dp = source_dp.lines_to_paragraphs(joiner=lambda ls: " ".join(ls)) - >>> list(para_agg_dp) - [('file1', 'Line1 Line2'), ('file2', 'Line2,1 Line2,2 Line2,3')] - """ - - def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Callable = _default_line_join) -> None: - self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe - _check_unpickable_fn(joiner) - self.joiner: Callable = joiner - self.buffer: List = [] - - def __iter__(self) -> Iterator[Tuple[str, str]]: - prev_filename = None - for filename, line in self.source_datapipe: - if prev_filename is None: - prev_filename = filename - if line and prev_filename == filename: - self.buffer.append(line) - else: - if self.buffer: - yield prev_filename, self.joiner(self.buffer) # type: ignore[misc] - if line: - self.buffer = [line] - else: - self.buffer = [] - prev_filename = filename - if self.buffer: - yield prev_filename, self.joiner(self.buffer) # type: ignore[misc] - - @final - def reset(self) -> None: - self.buffer = [] - - def __getstate__(self): - state = (self.source_datapipe, self.joiner) - if IterDataPipe.getstate_hook is not None: - return IterDataPipe.getstate_hook(state) - return state - - def __setstate__(self, state): - (self.source_datapipe, self.joiner) = state - self.buffer = [] - - def __del__(self): - self.buffer.clear() diff --git a/torchdata/datapipes/iter/util/plain_text_reader.py b/torchdata/datapipes/iter/util/plain_text_reader.py deleted file mode 100644 index b5d876bd3..000000000 --- a/torchdata/datapipes/iter/util/plain_text_reader.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -import contextlib -import csv -from typing import IO, Iterator, Tuple, TypeVar, Union - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -D = TypeVar("D") -Str_Or_Bytes = Union[str, bytes] - - -class PlainTextReaderHelper: - def __init__( - self, - *, - skip_lines: int = 0, - strip_newline: bool = True, - decode: bool = True, - encoding="utf-8", - errors: str = "ignore", - return_path: bool = False, - as_tuple: bool = False, - ) -> None: - if skip_lines < 0: - raise ValueError("'skip_lines' is required to be a positive integer.") - self._skip_lines = skip_lines - self._strip_newline = strip_newline - self._decode = decode - self._encoding = encoding - self._errors = errors - self._return_path = return_path - self._as_tuple = as_tuple - - def skip_lines(self, file: IO) -> Union[Iterator[bytes], Iterator[str]]: - with contextlib.suppress(StopIteration): - for _ in range(self._skip_lines): - next(file) - try: - yield from file - finally: - file.close() - - def strip_newline(self, stream: Union[Iterator[bytes], Iterator[str]]) -> Union[Iterator[bytes], Iterator[str]]: - if not self._strip_newline: - yield from stream - return - - for line in stream: - if isinstance(line, str): - yield line.strip("\r\n") - else: - yield line.strip(b"\r\n") - - def decode(self, stream: Union[Iterator[bytes], Iterator[str]]) -> Union[Iterator[bytes], Iterator[str]]: - if not self._decode: - yield from stream - else: - for line in stream: - yield line.decode(self._encoding, self._errors) if isinstance(line, bytes) else line - - def return_path(self, stream: Iterator[D], *, path: str) -> Iterator[Union[D, Tuple[str, D]]]: - if not self._return_path: - yield from stream - return - for data in stream: - yield path, data - - def as_tuple(self, stream: Iterator[D]) -> Iterator[Union[D, Tuple]]: - if not self._as_tuple: - yield from stream - return - for data in stream: - if isinstance(data, list): - yield tuple(data) - else: - yield data - - -@functional_datapipe("readlines") -class LineReaderIterDataPipe(IterDataPipe[Union[Str_Or_Bytes, Tuple[str, Str_Or_Bytes]]]): - r""" - Accepts a DataPipe consisting of tuples of file name and string data stream, and for each line in the - stream, yields a tuple of file name and the line (functional name: ``readlines``). 
- - Args: - source_datapipe: a DataPipe with tuples of file name and string data stream - skip_lines: number of lines to skip at the beginning of each file - strip_newline: if ``True``, the new line character will be stripped - decode: if ``True``, this will decode the contents of the file based on the specified ``encoding`` - encoding: the character encoding of the files (`default='utf-8'`) - errors: the error handling scheme used while decoding - return_path: if ``True``, each line will return a tuple of path and contents, rather - than just the contents - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> import io - >>> text1 = "Line1\nLine2" - >>> text2 = "Line2,1\r\nLine2,2\r\nLine2,3" - >>> source_dp = IterableWrapper([("file1", io.StringIO(text1)), ("file2", io.StringIO(text2))]) - >>> line_reader_dp = source_dp.readlines() - >>> list(line_reader_dp) - [('file1', 'Line1'), ('file1', 'Line2'), ('file2', 'Line2,1'), ('file2', 'Line2,2'), ('file2', 'Line2,3')] - """ - - def __init__( - self, - source_datapipe: IterDataPipe[Tuple[str, IO]], - *, - skip_lines: int = 0, - strip_newline: bool = True, - decode: bool = False, - encoding="utf-8", - errors: str = "ignore", - return_path: bool = True, - ) -> None: - self.source_datapipe = source_datapipe - self._helper = PlainTextReaderHelper( - skip_lines=skip_lines, - strip_newline=strip_newline, - decode=decode, - encoding=encoding, - errors=errors, - return_path=return_path, - ) - - def __iter__(self) -> Iterator[Union[Str_Or_Bytes, Tuple[str, Str_Or_Bytes]]]: - for path, file in self.source_datapipe: - stream = self._helper.skip_lines(file) - stream = self._helper.strip_newline(stream) - stream = self._helper.decode(stream) - yield from self._helper.return_path(stream, path=path) # type: ignore[misc] - - -class _CSVBaseParserIterDataPipe(IterDataPipe): - def __init__( - self, - source_datapipe, - csv_reader, - *, - skip_lines: int = 0, - decode: bool = False, - encoding="utf-8", - errors: str = "ignore", - return_path: bool = True, - as_tuple: bool = False, - **fmtparams, - ) -> None: - self.source_datapipe = source_datapipe - self._csv_reader = csv_reader - self._helper = PlainTextReaderHelper( - skip_lines=skip_lines, - decode=decode, - encoding=encoding, - errors=errors, - return_path=return_path, - as_tuple=as_tuple, - ) - self.fmtparams = fmtparams - - def __iter__(self) -> Iterator[Union[D, Tuple[str, D]]]: - for path, file in self.source_datapipe: - stream = self._helper.skip_lines(file) - stream = self._helper.decode(stream) - stream = self._csv_reader(stream, **self.fmtparams) - stream = self._helper.as_tuple(stream) # type: ignore[assignment] - yield from self._helper.return_path(stream, path=path) # type: ignore[misc] - - -@functional_datapipe("parse_csv") -class CSVParserIterDataPipe(_CSVBaseParserIterDataPipe): - r""" - Accepts a DataPipe consists of tuples of file name and CSV data stream, - reads and returns the contents within the CSV files one row at a time (functional name: ``parse_csv``). - Each output is a `List` by default, but it depends on ``fmtparams``. 
- - Args: - source_datapipe: source DataPipe with tuples of file name and CSV data stream - skip_lines: number of lines to skip at the beginning of each file - strip_newline: if ``True``, the new line character will be stripped - decode: if ``True``, this will decode the contents of the file based on the specified ``encoding`` - encoding: the character encoding of the files (`default='utf-8'`) - errors: the error handling scheme used while decoding - return_path: if ``True``, each line will return a tuple of path and contents, rather - than just the contents - as_tuple: if ``True``, each line will return a tuple instead of a list - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, FileOpener - >>> import os - >>> def get_name(path_and_stream): - >>> return os.path.basename(path_and_stream[0]), path_and_stream[1] - >>> datapipe1 = IterableWrapper(["1.csv", "empty.csv", "empty2.csv"]) - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> datapipe3 = datapipe2.map(get_name) - >>> csv_parser_dp = datapipe3.parse_csv() - >>> list(csv_parser_dp) - [['key', 'item'], ['a', '1'], ['b', '2'], []] - """ - - def __init__( - self, - source_datapipe: IterDataPipe[Tuple[str, IO]], - *, - skip_lines: int = 0, - decode: bool = True, - encoding: str = "utf-8", - errors: str = "ignore", - return_path: bool = False, - as_tuple: bool = False, - **fmtparams, - ) -> None: - super().__init__( - source_datapipe, - csv.reader, - skip_lines=skip_lines, - decode=decode, - encoding=encoding, - errors=errors, - return_path=return_path, - as_tuple=as_tuple, - **fmtparams, - ) - - -@functional_datapipe("parse_csv_as_dict") -class CSVDictParserIterDataPipe(_CSVBaseParserIterDataPipe): - r""" - Accepts a DataPipe consists of tuples of file name and CSV data stream, reads and returns the contents - within the CSV files one row at a time (functional name: ``parse_csv_as_dict``). - - Each output is a `Dict` by default, but it depends on ``fmtparams``. The first row of each file, unless skipped, - will be used as the header; the contents of the header row will be used as keys for the `Dict`\s - generated from the remaining rows. 
- - Args: - source_datapipe: source DataPipe with tuples of file name and CSV data stream - skip_lines: number of lines to skip at the beginning of each file - strip_newline: if ``True``, the new line character will be stripped - decode: if ``True``, this will decode the contents of the file based on the specified ``encoding`` - encoding: the character encoding of the files (`default='utf-8'`) - errors: the error handling scheme used while decoding - return_path: if ``True``, each line will return a tuple of path and contents, rather - than just the contents - - Example: - >>> from torchdata.datapipes.iter import FileLister, FileOpener - >>> import os - >>> def get_name(path_and_stream): - >>> return os.path.basename(path_and_stream[0]), path_and_stream[1] - >>> datapipe1 = FileLister(".", "*.csv") - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> datapipe3 = datapipe2.map(get_name) - >>> csv_dict_parser_dp = datapipe3.parse_csv_as_dict() - >>> list(csv_dict_parser_dp) - [{'key': 'a', 'item': '1'}, {'key': 'b', 'item': '2'}] - """ - - def __init__( - self, - source_datapipe: IterDataPipe[Tuple[str, IO]], - *, - skip_lines: int = 0, - decode: bool = True, - encoding: str = "utf-8", - errors: str = "ignore", - return_path: bool = False, - **fmtparams, - ) -> None: - super().__init__( - source_datapipe, - csv.DictReader, - skip_lines=skip_lines, - decode=decode, - encoding=encoding, - errors=errors, - return_path=return_path, - **fmtparams, - ) diff --git a/torchdata/datapipes/iter/util/prefetcher.py b/torchdata/datapipes/iter/util/prefetcher.py deleted file mode 100644 index c12e5444a..000000000 --- a/torchdata/datapipes/iter/util/prefetcher.py +++ /dev/null @@ -1,268 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import threading -import time - -from collections import deque -from typing import Deque, final, Optional, Sized - -import torch - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe -from torchdata.datapipes.utils import pin_memory_fn - -PRODUCER_SLEEP_INTERVAL = 0.0001 # Interval between buffer fulfillment checks -CONSUMER_SLEEP_INTERVAL = 0.0001 # Interval between checking items availability in buffer - - -class _PrefetchData: - def __init__(self, source_datapipe, buffer_size: int): - self.run_prefetcher: bool = True - self.prefetch_buffer: Deque = deque() - self.buffer_size: int = buffer_size - self.source_datapipe = source_datapipe - self.stop_iteration: bool = False - self.paused: bool = False - - -@functional_datapipe("prefetch") -class PrefetcherIterDataPipe(IterDataPipe): - r""" - Prefetches elements from the source DataPipe and puts them into a buffer (functional name: ``prefetch``). - Prefetching performs the operations (e.g. I/O, computations) of the DataPipes up to this one ahead of time - and stores the result in the buffer, ready to be consumed by the subsequent DataPipe. It has no effect aside - from getting the sample ready ahead of time. - - This is used by ``MultiProcessingReadingService`` when the arguments - ``worker_prefetch_cnt`` (for prefetching at each worker process) or - ``main_prefetch_cnt`` (for prefetching at the main loop) are greater than 0. - - Beyond the built-in use cases, this can be useful to put after I/O DataPipes that have - expensive I/O operations (e.g. 
takes a long time to request a file from a remote server). - - Args: - source_datapipe: IterDataPipe from which samples are prefetched - buffer_size: the size of the buffer which stores the prefetched samples - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper(file_paths).open_files().prefetch(5) - """ - - def __init__(self, source_datapipe, buffer_size: int = 10): - self.source_datapipe = source_datapipe - if buffer_size <= 0: - raise ValueError("'buffer_size' is required to be a positive integer.") - self.buffer_size = buffer_size - self.thread: Optional[threading.Thread] = None - self.prefetch_data: Optional[_PrefetchData] = None - - @staticmethod - def thread_worker(prefetch_data: _PrefetchData): - itr = iter(prefetch_data.source_datapipe) - while not prefetch_data.stop_iteration: - # Run if not paused - while prefetch_data.run_prefetcher: - if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size: - try: - item = next(itr) - prefetch_data.prefetch_buffer.append(item) - except Exception as e: - prefetch_data.run_prefetcher = False - prefetch_data.stop_iteration = True - prefetch_data.prefetch_buffer.append(e) - else: # Buffer is full, waiting for main thread to consume items - # TODO: Calculate sleep interval based on previous consumption speed - time.sleep(PRODUCER_SLEEP_INTERVAL) - prefetch_data.paused = True - # Sleep longer when this prefetcher thread is paused - time.sleep(PRODUCER_SLEEP_INTERVAL * 10) - - def __iter__(self): - try: - prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size) - self.prefetch_data = prefetch_data - thread = threading.Thread(target=PrefetcherIterDataPipe.thread_worker, args=(prefetch_data,), daemon=True) - thread.start() - self.thread = thread - - # Lazily import to prevent circular import - from torchdata.dataloader2 import communication - - while not prefetch_data.stop_iteration or len(prefetch_data.prefetch_buffer) > 0: - if len(prefetch_data.prefetch_buffer) > 0: - data = prefetch_data.prefetch_buffer.popleft() - if isinstance(data, Exception): - if isinstance(data, (StopIteration, communication.iter.TerminateRequired)): - break - raise data - yield data - else: - time.sleep(CONSUMER_SLEEP_INTERVAL) - finally: - if "prefetch_data" in locals(): - prefetch_data.run_prefetcher = False - prefetch_data.stop_iteration = True - prefetch_data.paused = False - if "thread" in locals(): - thread.join() - - def __getstate__(self): - """ - Getting state in threading environment requires next operations: - 1) Stopping of the producer thread. - 2) Saving buffer. - 3) Adding lazy restart of producer thread when __next__ is called again - (this will guarantee that you only change state of the source_datapipe - after entire state of the graph is saved). 
- """ - # TODO: Update __getstate__ and __setstate__ to support snapshotting and restoration - return {"source_datapipe": self.source_datapipe, "buffer_size": self.buffer_size} - - def __setstate__(self, state): - self.source_datapipe = state["source_datapipe"] - self.buffer_size = state["buffer_size"] - self.thread = None - - @final - def reset(self): - self.shutdown() - - def pause(self): - if self.thread is not None: - assert self.prefetch_data is not None - self.prefetch_data.run_prefetcher = False - if self.thread.is_alive(): - # Blocking until the thread is paused - while not self.prefetch_data.paused: - time.sleep(PRODUCER_SLEEP_INTERVAL * 10) - - @final - def resume(self): - if ( - self.thread is not None - and self.prefetch_data is not None - and (not self.prefetch_data.stop_iteration or len(self.prefetch_data.prefetch_buffer) > 0) - ): - self.prefetch_data.run_prefetcher = True - self.prefetch_data.paused = False - - @final - def shutdown(self): - if hasattr(self, "prefetch_data") and self.prefetch_data is not None: - self.prefetch_data.run_prefetcher = False - self.prefetch_data.stop_iteration = True - self.prefetch_data.paused = False - self.prefetch_data = None - if hasattr(self, "thread") and self.thread is not None: - self.thread.join() - self.thread = None - - def __del__(self): - self.shutdown() - - def __len__(self) -> int: - if isinstance(self.source_datapipe, Sized): - return len(self.source_datapipe) - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - - -@functional_datapipe("pin_memory") -class PinMemoryIterDataPipe(PrefetcherIterDataPipe): - r""" - Prefetches one element from the source DataPipe and moves it to pinned memory (functional name: ``pin_memory``). - When used with ``MultiProcessingReadingService``, this DataPipe would be kept in the main process to prevent - duplicated CUDA context creation. - - Args: - source_datapipe: IterDataPipe from which samples are moved to pinned memory. - device: The device to pin samples. - pin_memory_fn: Optional callable function to move data to pinned memory. - A ``pin_memory_fn`` to handle general objects is provided by default. 
- - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper(file_paths).open_files().readlines().map(tokenize_fn).pin_memory() - """ - - def __init__(self, source_datapipe, device=None, pin_memory_fn=pin_memory_fn): - if not torch.cuda.is_available(): - raise RuntimeError("``pin_memory`` can only be used when CUDA is available.") - # TODO: Add support for dynamic buffer based on the available size of pinned memory - super().__init__(source_datapipe, buffer_size=2) - if device is None: - device = torch.cuda.current_device() - self.device = device - self.pin_memory_fn = pin_memory_fn - - def is_replicable(self) -> bool: - return False - - @staticmethod - def thread_worker(prefetch_data: _PrefetchData, pin_memory_fn, device): # type: ignore[override] - itr = iter(prefetch_data.source_datapipe) - while not prefetch_data.stop_iteration: - # Run if not paused - while prefetch_data.run_prefetcher: - if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size: - try: - item = pin_memory_fn(next(itr), device) - prefetch_data.prefetch_buffer.append(item) - except Exception as e: - prefetch_data.run_prefetcher = False - prefetch_data.stop_iteration = True - prefetch_data.prefetch_buffer.append(e) - else: # Buffer is full, waiting for main thread to consume items - # TODO: Calculate sleep interval based on previous consumption speed - time.sleep(PRODUCER_SLEEP_INTERVAL) - # Sleep longer when this prefetcher thread is paused - time.sleep(PRODUCER_SLEEP_INTERVAL * 10) - - def __iter__(self): - try: - prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size) - self.prefetch_data = prefetch_data - thread = threading.Thread( - target=PinMemoryIterDataPipe.thread_worker, - args=(prefetch_data, self.pin_memory_fn, self.device), - daemon=True, - ) - thread.start() - self.thread = thread - - # Lazily import to prevent circular import - from torchdata.dataloader2 import communication - - while not prefetch_data.stop_iteration or len(prefetch_data.prefetch_buffer) > 0: - if len(prefetch_data.prefetch_buffer) > 0: - data = prefetch_data.prefetch_buffer.popleft() - if isinstance(data, Exception): - if isinstance(data, (StopIteration, communication.iter.TerminateRequired)): - break - raise data - yield data - else: - time.sleep(CONSUMER_SLEEP_INTERVAL) - finally: - if "prefetch_data" in locals(): - prefetch_data.run_prefetcher = False - prefetch_data.stop_iteration = True - prefetch_data.paused = False - if "thread" in locals(): - thread.join() - - def __getstate__(self): - state = super().__getstate__() - state["pin_memory_fn"] = self.pin_memory_fn - state["device"] = self.device - return state - - def __setstate__(self, state): - super().__setstate__(state) - self.pin_memory_fn = state["pin_memory_fn"] - self.device = state["device"] diff --git a/torchdata/datapipes/iter/util/protobuf_template/__init__.py b/torchdata/datapipes/iter/util/protobuf_template/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/torchdata/datapipes/iter/util/protobuf_template/_tfrecord_example_pb2.py b/torchdata/datapipes/iter/util/protobuf_template/_tfrecord_example_pb2.py deleted file mode 100644 index 972c61ff4..000000000 --- a/torchdata/datapipes/iter/util/protobuf_template/_tfrecord_example_pb2.py +++ /dev/null @@ -1,699 +0,0 @@ -# type: ignore -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: example.proto - -import sys - -_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode("latin1")) -from google.protobuf import ( - descriptor as _descriptor, - descriptor_pb2, - message as _message, - reflection as _reflection, - symbol_database as _symbol_database, -) - -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor.FileDescriptor( - name="example.proto", - package="tfrecord", - syntax="proto3", - serialized_pb=_b( - '\n\rexample.proto\x12\x08tfrecord"\x1a\n\tBytesList\x12\r\n\x05value\x18\x01 \x03(\x0c"\x1e\n\tFloatList\x12\x11\n\x05value\x18\x01 \x03(\x02\x42\x02\x10\x01"\x1e\n\tInt64List\x12\x11\n\x05value\x18\x01 \x03(\x03\x42\x02\x10\x01"\x92\x01\n\x07\x46\x65\x61ture\x12)\n\nbytes_list\x18\x01 \x01(\x0b\x32\x13.tfrecord.BytesListH\x00\x12)\n\nfloat_list\x18\x02 \x01(\x0b\x32\x13.tfrecord.FloatListH\x00\x12)\n\nint64_list\x18\x03 \x01(\x0b\x32\x13.tfrecord.Int64ListH\x00\x42\x06\n\x04kind"\x7f\n\x08\x46\x65\x61tures\x12\x30\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x1f.tfrecord.Features.FeatureEntry\x1a\x41\n\x0c\x46\x65\x61tureEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.tfrecord.Feature:\x02\x38\x01"1\n\x0b\x46\x65\x61tureList\x12"\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x11.tfrecord.Feature"\x98\x01\n\x0c\x46\x65\x61tureLists\x12=\n\x0c\x66\x65\x61ture_list\x18\x01 \x03(\x0b\x32\'.tfrecord.FeatureLists.FeatureListEntry\x1aI\n\x10\x46\x65\x61tureListEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12$\n\x05value\x18\x02 \x01(\x0b\x32\x15.tfrecord.FeatureList:\x02\x38\x01"/\n\x07\x45xample\x12$\n\x08\x66\x65\x61tures\x18\x01 \x01(\x0b\x32\x12.tfrecord.Features"e\n\x0fSequenceExample\x12#\n\x07\x63ontext\x18\x01 \x01(\x0b\x32\x12.tfrecord.Features\x12-\n\rfeature_lists\x18\x02 \x01(\x0b\x32\x16.tfrecord.FeatureListsB\x03\xf8\x01\x01\x62\x06proto3' - ), -) -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - - -_BYTESLIST = _descriptor.Descriptor( - name="BytesList", - full_name="tfrecord.BytesList", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="value", - full_name="tfrecord.BytesList.value", - index=0, - number=1, - type=12, - cpp_type=9, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=27, - serialized_end=53, -) - - -_FLOATLIST = _descriptor.Descriptor( - name="FloatList", - full_name="tfrecord.FloatList", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="value", - full_name="tfrecord.FloatList.value", - index=0, - number=1, - type=2, - cpp_type=6, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=55, - serialized_end=85, -) - - -_INT64LIST = _descriptor.Descriptor( - name="Int64List", - full_name="tfrecord.Int64List", - filename=None, - file=DESCRIPTOR, - containing_type=None, 
- fields=[ - _descriptor.FieldDescriptor( - name="value", - full_name="tfrecord.Int64List.value", - index=0, - number=1, - type=3, - cpp_type=2, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")), - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=87, - serialized_end=117, -) - - -_FEATURE = _descriptor.Descriptor( - name="Feature", - full_name="tfrecord.Feature", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="bytes_list", - full_name="tfrecord.Feature.bytes_list", - index=0, - number=1, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - _descriptor.FieldDescriptor( - name="float_list", - full_name="tfrecord.Feature.float_list", - index=1, - number=2, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - _descriptor.FieldDescriptor( - name="int64_list", - full_name="tfrecord.Feature.int64_list", - index=2, - number=3, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name="kind", full_name="tfrecord.Feature.kind", index=0, containing_type=None, fields=[] - ), - ], - serialized_start=120, - serialized_end=266, -) - - -_FEATURES_FEATUREENTRY = _descriptor.Descriptor( - name="FeatureEntry", - full_name="tfrecord.Features.FeatureEntry", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="key", - full_name="tfrecord.Features.FeatureEntry.key", - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=_b("").decode("utf-8"), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - _descriptor.FieldDescriptor( - name="value", - full_name="tfrecord.Features.FeatureEntry.value", - index=1, - number=2, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=330, - serialized_end=395, -) - -_FEATURES = _descriptor.Descriptor( - name="Features", - full_name="tfrecord.Features", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="feature", - full_name="tfrecord.Features.feature", - index=0, - number=1, - type=11, - cpp_type=10, - 
label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], - extensions=[], - nested_types=[ - _FEATURES_FEATUREENTRY, - ], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=268, - serialized_end=395, -) - - -_FEATURELIST = _descriptor.Descriptor( - name="FeatureList", - full_name="tfrecord.FeatureList", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="feature", - full_name="tfrecord.FeatureList.feature", - index=0, - number=1, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=397, - serialized_end=446, -) - - -_FEATURELISTS_FEATURELISTENTRY = _descriptor.Descriptor( - name="FeatureListEntry", - full_name="tfrecord.FeatureLists.FeatureListEntry", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="key", - full_name="tfrecord.FeatureLists.FeatureListEntry.key", - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=_b("").decode("utf-8"), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - _descriptor.FieldDescriptor( - name="value", - full_name="tfrecord.FeatureLists.FeatureListEntry.value", - index=1, - number=2, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=_descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")), - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=528, - serialized_end=601, -) - -_FEATURELISTS = _descriptor.Descriptor( - name="FeatureLists", - full_name="tfrecord.FeatureLists", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="feature_list", - full_name="tfrecord.FeatureLists.feature_list", - index=0, - number=1, - type=11, - cpp_type=10, - label=3, - has_default_value=False, - default_value=[], - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], - extensions=[], - nested_types=[ - _FEATURELISTS_FEATURELISTENTRY, - ], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=449, - serialized_end=601, -) - - -_EXAMPLE = _descriptor.Descriptor( - name="Example", - full_name="tfrecord.Example", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="features", - full_name="tfrecord.Example.features", - index=0, - number=1, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], 
- extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=603, - serialized_end=650, -) - - -_SEQUENCEEXAMPLE = _descriptor.Descriptor( - name="SequenceExample", - full_name="tfrecord.SequenceExample", - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name="context", - full_name="tfrecord.SequenceExample.context", - index=0, - number=1, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - _descriptor.FieldDescriptor( - name="feature_lists", - full_name="tfrecord.SequenceExample.feature_lists", - index=1, - number=2, - type=11, - cpp_type=10, - label=1, - has_default_value=False, - default_value=None, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None, - ), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - syntax="proto3", - extension_ranges=[], - oneofs=[], - serialized_start=652, - serialized_end=753, -) - -_FEATURE.fields_by_name["bytes_list"].message_type = _BYTESLIST -_FEATURE.fields_by_name["float_list"].message_type = _FLOATLIST -_FEATURE.fields_by_name["int64_list"].message_type = _INT64LIST -_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["bytes_list"]) -_FEATURE.fields_by_name["bytes_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] -_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["float_list"]) -_FEATURE.fields_by_name["float_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] -_FEATURE.oneofs_by_name["kind"].fields.append(_FEATURE.fields_by_name["int64_list"]) -_FEATURE.fields_by_name["int64_list"].containing_oneof = _FEATURE.oneofs_by_name["kind"] -_FEATURES_FEATUREENTRY.fields_by_name["value"].message_type = _FEATURE -_FEATURES_FEATUREENTRY.containing_type = _FEATURES -_FEATURES.fields_by_name["feature"].message_type = _FEATURES_FEATUREENTRY -_FEATURELIST.fields_by_name["feature"].message_type = _FEATURE -_FEATURELISTS_FEATURELISTENTRY.fields_by_name["value"].message_type = _FEATURELIST -_FEATURELISTS_FEATURELISTENTRY.containing_type = _FEATURELISTS -_FEATURELISTS.fields_by_name["feature_list"].message_type = _FEATURELISTS_FEATURELISTENTRY -_EXAMPLE.fields_by_name["features"].message_type = _FEATURES -_SEQUENCEEXAMPLE.fields_by_name["context"].message_type = _FEATURES -_SEQUENCEEXAMPLE.fields_by_name["feature_lists"].message_type = _FEATURELISTS -DESCRIPTOR.message_types_by_name["BytesList"] = _BYTESLIST -DESCRIPTOR.message_types_by_name["FloatList"] = _FLOATLIST -DESCRIPTOR.message_types_by_name["Int64List"] = _INT64LIST -DESCRIPTOR.message_types_by_name["Feature"] = _FEATURE -DESCRIPTOR.message_types_by_name["Features"] = _FEATURES -DESCRIPTOR.message_types_by_name["FeatureList"] = _FEATURELIST -DESCRIPTOR.message_types_by_name["FeatureLists"] = _FEATURELISTS -DESCRIPTOR.message_types_by_name["Example"] = _EXAMPLE -DESCRIPTOR.message_types_by_name["SequenceExample"] = _SEQUENCEEXAMPLE - -BytesList = _reflection.GeneratedProtocolMessageType( - "BytesList", - (_message.Message,), - dict( - DESCRIPTOR=_BYTESLIST, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.BytesList) - ), -) -_sym_db.RegisterMessage(BytesList) - -FloatList = 
_reflection.GeneratedProtocolMessageType( - "FloatList", - (_message.Message,), - dict( - DESCRIPTOR=_FLOATLIST, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.FloatList) - ), -) -_sym_db.RegisterMessage(FloatList) - -Int64List = _reflection.GeneratedProtocolMessageType( - "Int64List", - (_message.Message,), - dict( - DESCRIPTOR=_INT64LIST, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.Int64List) - ), -) -_sym_db.RegisterMessage(Int64List) - -Feature = _reflection.GeneratedProtocolMessageType( - "Feature", - (_message.Message,), - dict( - DESCRIPTOR=_FEATURE, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.Feature) - ), -) -_sym_db.RegisterMessage(Feature) - -Features = _reflection.GeneratedProtocolMessageType( - "Features", - (_message.Message,), - dict( - FeatureEntry=_reflection.GeneratedProtocolMessageType( - "FeatureEntry", - (_message.Message,), - dict( - DESCRIPTOR=_FEATURES_FEATUREENTRY, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.Features.FeatureEntry) - ), - ), - DESCRIPTOR=_FEATURES, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.Features) - ), -) -_sym_db.RegisterMessage(Features) -_sym_db.RegisterMessage(Features.FeatureEntry) - -FeatureList = _reflection.GeneratedProtocolMessageType( - "FeatureList", - (_message.Message,), - dict( - DESCRIPTOR=_FEATURELIST, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.FeatureList) - ), -) -_sym_db.RegisterMessage(FeatureList) - -FeatureLists = _reflection.GeneratedProtocolMessageType( - "FeatureLists", - (_message.Message,), - dict( - FeatureListEntry=_reflection.GeneratedProtocolMessageType( - "FeatureListEntry", - (_message.Message,), - dict( - DESCRIPTOR=_FEATURELISTS_FEATURELISTENTRY, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.FeatureLists.FeatureListEntry) - ), - ), - DESCRIPTOR=_FEATURELISTS, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.FeatureLists) - ), -) -_sym_db.RegisterMessage(FeatureLists) -_sym_db.RegisterMessage(FeatureLists.FeatureListEntry) - -Example = _reflection.GeneratedProtocolMessageType( - "Example", - (_message.Message,), - dict( - DESCRIPTOR=_EXAMPLE, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.Example) - ), -) -_sym_db.RegisterMessage(Example) - -SequenceExample = _reflection.GeneratedProtocolMessageType( - "SequenceExample", - (_message.Message,), - dict( - DESCRIPTOR=_SEQUENCEEXAMPLE, - __module__="example_pb2" - # @@protoc_insertion_point(class_scope:tfrecord.SequenceExample) - ), -) -_sym_db.RegisterMessage(SequenceExample) - - -DESCRIPTOR.has_options = True -DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b("\370\001\001")) -_FLOATLIST.fields_by_name["value"].has_options = True -_FLOATLIST.fields_by_name["value"]._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")) -_INT64LIST.fields_by_name["value"].has_options = True -_INT64LIST.fields_by_name["value"]._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b("\020\001")) -_FEATURES_FEATUREENTRY.has_options = True -_FEATURES_FEATUREENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")) -_FEATURELISTS_FEATURELISTENTRY.has_options = True -_FEATURELISTS_FEATURELISTENTRY._options = _descriptor._ParseOptions(descriptor_pb2.MessageOptions(), _b("8\001")) -# 
@@protoc_insertion_point(module_scope) diff --git a/torchdata/datapipes/iter/util/protobuf_template/example.proto b/torchdata/datapipes/iter/util/protobuf_template/example.proto deleted file mode 100644 index 9f762fb51..000000000 --- a/torchdata/datapipes/iter/util/protobuf_template/example.proto +++ /dev/null @@ -1,301 +0,0 @@ -// Protocol messages for describing input data Examples for machine learning -// model training or inference. -syntax = "proto3"; - -package tensorflow; - -import "tensorflow/core/example/feature.proto"; - -option cc_enable_arenas = true; -option java_outer_classname = "ExampleProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.example"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example/example_protos_go_proto"; - -// An Example is a mostly-normalized data format for storing data for -// training and inference. It contains a key-value store (features); where -// each key (string) maps to a Feature message (which is oneof packed BytesList, -// FloatList, or Int64List). This flexible and compact format allows the -// storage of large amounts of typed data, but requires that the data shape -// and use be determined by the configuration files and parsers that are used to -// read and write this format. That is, the Example is mostly *not* a -// self-describing format. In TensorFlow, Examples are read in row-major -// format, so any configuration that describes data with rank-2 or above -// should keep this in mind. If you flatten a matrix into a FloatList it should -// be stored as [ row 0 ... row 1 ... row M-1 ] -// -// An Example for a movie recommendation application: -// features { -// feature { -// key: "age" -// value { float_list { -// value: 29.0 -// }} -// } -// feature { -// key: "movie" -// value { bytes_list { -// value: "The Shawshank Redemption" -// value: "Fight Club" -// }} -// } -// feature { -// key: "movie_ratings" -// value { float_list { -// value: 9.0 -// value: 9.7 -// }} -// } -// feature { -// key: "suggestion" -// value { bytes_list { -// value: "Inception" -// }} -// } -// # Note that this feature exists to be used as a label in training. -// # E.g., if training a logistic regression model to predict purchase -// # probability in our learning tool we would set the label feature to -// # "suggestion_purchased". -// feature { -// key: "suggestion_purchased" -// value { float_list { -// value: 1.0 -// }} -// } -// # Similar to "suggestion_purchased" above this feature exists to be used -// # as a label in training. -// # E.g., if training a linear regression model to predict purchase -// # price in our learning tool we would set the label feature to -// # "purchase_price". -// feature { -// key: "purchase_price" -// value { float_list { -// value: 9.99 -// }} -// } -// } -// -// A conformant Example data set obeys the following conventions: -// - If a Feature K exists in one example with data type T, it must be of -// type T in all other examples when present. It may be omitted. -// - The number of instances of Feature K list data may vary across examples, -// depending on the requirements of the model. -// - If a Feature K doesn't exist in an example, a K-specific default will be -// used, if configured. -// - If a Feature K exists in an example but contains no items, the intent -// is considered to be an empty tensor and no default will be used. 
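For context on the generated protobuf module deleted above: its message classes follow the standard protobuf Python API, so a tfrecord Example can be built and serialized directly from them. The snippet below is a minimal sketch, not part of the change itself; it assumes the pre-removal import path `torchdata.datapipes.iter.util.protobuf_template._tfrecord_example_pb2` (the same path the deleted TFRecord loader uses) is still available.

    # Minimal sketch (pre-removal layout assumed): build and round-trip a tfrecord Example.
    from torchdata.datapipes.iter.util.protobuf_template import _tfrecord_example_pb2 as example_pb2

    example = example_pb2.Example(
        features=example_pb2.Features(
            feature={
                "age": example_pb2.Feature(float_list=example_pb2.FloatList(value=[29.0])),
                "movie": example_pb2.Feature(
                    bytes_list=example_pb2.BytesList(value=[b"The Shawshank Redemption", b"Fight Club"])
                ),
            }
        )
    )
    payload = example.SerializeToString()  # bytes ready to be framed into a .tfrecord file
    restored = example_pb2.Example.FromString(payload)
    assert restored.features.feature["age"].float_list.value[0] == 29.0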
- -message Example { - Features features = 1; -} - -// A SequenceExample is an Example representing one or more sequences, and -// some context. The context contains features which apply to the entire -// example. The feature_lists contain a key, value map where each key is -// associated with a repeated set of Features (a FeatureList). -// A FeatureList thus represents the values of a feature identified by its key -// over time / frames. -// -// Below is a SequenceExample for a movie recommendation application recording a -// sequence of ratings by a user. The time-independent features ("locale", -// "age", "favorites") describing the user are part of the context. The sequence -// of movies the user rated are part of the feature_lists. For each movie in the -// sequence we have information on its name and actors and the user's rating. -// This information is recorded in three separate feature_list(s). -// In the example below there are only two movies. All three feature_list(s), -// namely "movie_ratings", "movie_names", and "actors" have a feature value for -// both movies. Note, that "actors" is itself a bytes_list with multiple -// strings per movie. -// -// context: { -// feature: { -// key : "locale" -// value: { -// bytes_list: { -// value: [ "pt_BR" ] -// } -// } -// } -// feature: { -// key : "age" -// value: { -// float_list: { -// value: [ 19.0 ] -// } -// } -// } -// feature: { -// key : "favorites" -// value: { -// bytes_list: { -// value: [ "Majesty Rose", "Savannah Outen", "One Direction" ] -// } -// } -// } -// } -// feature_lists: { -// feature_list: { -// key : "movie_ratings" -// value: { -// feature: { -// float_list: { -// value: [ 4.5 ] -// } -// } -// feature: { -// float_list: { -// value: [ 5.0 ] -// } -// } -// } -// } -// feature_list: { -// key : "movie_names" -// value: { -// feature: { -// bytes_list: { -// value: [ "The Shawshank Redemption" ] -// } -// } -// feature: { -// bytes_list: { -// value: [ "Fight Club" ] -// } -// } -// } -// } -// feature_list: { -// key : "actors" -// value: { -// feature: { -// bytes_list: { -// value: [ "Tim Robbins", "Morgan Freeman" ] -// } -// } -// feature: { -// bytes_list: { -// value: [ "Brad Pitt", "Edward Norton", "Helena Bonham Carter" ] -// } -// } -// } -// } -// } -// -// A conformant SequenceExample data set obeys the following conventions: -// -// Context: -// - All conformant context features K must obey the same conventions as -// a conformant Example's features (see above). -// Feature lists: -// - A FeatureList L may be missing in an example; it is up to the -// parser configuration to determine if this is allowed or considered -// an empty list (zero length). -// - If a FeatureList L exists, it may be empty (zero length). -// - If a FeatureList L is non-empty, all features within the FeatureList -// must have the same data type T. Even across SequenceExamples, the type T -// of the FeatureList identified by the same key must be the same. An entry -// without any values may serve as an empty feature. -// - If a FeatureList L is non-empty, it is up to the parser configuration -// to determine if all features within the FeatureList must -// have the same size. The same holds for this FeatureList across multiple -// examples. -// - For sequence modeling, e.g.: -// http://colah.github.io/posts/2015-08-Understanding-LSTMs/ -// https://github.com/tensorflow/nmt -// the feature lists represent a sequence of frames. 
-// In this scenario, all FeatureLists in a SequenceExample have the same -// number of Feature messages, so that the ith element in each FeatureList -// is part of the ith frame (or time step). -// Examples of conformant and non-conformant examples' FeatureLists: -// -// Conformant FeatureLists: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// -// Non-conformant FeatureLists (mismatched types): -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { int64_list: { value: [ 5 ] } } } -// } } -// -// Conditionally conformant FeatureLists, the parser configuration determines -// if the feature sizes must match: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0, 6.0 ] } } } -// } } -// -// Conformant pair of SequenceExample -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } -// feature: { float_list: { value: [ 2.0 ] } } } -// } } -// -// Conformant pair of SequenceExample -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { } -// } } -// -// Conditionally conformant pair of SequenceExample, the parser configuration -// determines if the second feature_lists is consistent (zero-length) or -// invalid (missing "movie_ratings"): -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { } -// -// Non-conformant pair of SequenceExample (mismatched types) -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { int64_list: { value: [ 4 ] } } -// feature: { int64_list: { value: [ 5 ] } } -// feature: { int64_list: { value: [ 2 ] } } } -// } } -// -// Conditionally conformant pair of SequenceExample; the parser configuration -// determines if the feature sizes must match: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.5 ] } } -// feature: { float_list: { value: [ 5.0 ] } } } -// } } -// and: -// feature_lists: { feature_list: { -// key: "movie_ratings" -// value: { feature: { float_list: { value: [ 4.0 ] } } -// feature: { float_list: { value: [ 5.0, 3.0 ] } } -// } } - -message SequenceExample { - Features context = 1; - FeatureLists feature_lists = 2; -} diff --git a/torchdata/datapipes/iter/util/protobuf_template/feature.proto b/torchdata/datapipes/iter/util/protobuf_template/feature.proto deleted file mode 100644 index defc73e8a..000000000 --- a/torchdata/datapipes/iter/util/protobuf_template/feature.proto +++ /dev/null @@ -1,110 +0,0 @@ -// Protocol messages 
for describing features for machine learning model -// training or inference. -// -// There are three base Feature types: -// - bytes -// - float -// - int64 -// -// A Feature contains Lists which may hold zero or more values. These -// lists are the base values BytesList, FloatList, Int64List. -// -// Features are organized into categories by name. The Features message -// contains the mapping from name to Feature. -// -// Example Features for a movie recommendation application: -// feature { -// key: "age" -// value { float_list { -// value: 29.0 -// }} -// } -// feature { -// key: "movie" -// value { bytes_list { -// value: "The Shawshank Redemption" -// value: "Fight Club" -// }} -// } -// feature { -// key: "movie_ratings" -// value { float_list { -// value: 9.0 -// value: 9.7 -// }} -// } -// feature { -// key: "suggestion" -// value { bytes_list { -// value: "Inception" -// }} -// } -// feature { -// key: "suggestion_purchased" -// value { int64_list { -// value: 1 -// }} -// } -// feature { -// key: "purchase_price" -// value { float_list { -// value: 9.99 -// }} -// } -// - -syntax = "proto3"; - -package tensorflow; - -option cc_enable_arenas = true; -option java_outer_classname = "FeatureProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.example"; -option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/example/example_protos_go_proto"; - -// LINT.IfChange -// Containers to hold repeated fundamental values. -message BytesList { - repeated bytes value = 1; -} -message FloatList { - repeated float value = 1 [ packed = true ]; -} -message Int64List { - repeated int64 value = 1 [ packed = true ]; -} - -// Containers for non-sequential data. -message Feature { - // Each feature can be exactly one kind. - oneof kind { - BytesList bytes_list = 1; - FloatList float_list = 2; - Int64List int64_list = 3; - } -} - -message Features { - // Map from feature name to feature. - map<string, Feature> feature = 1; -} - -// Containers for sequential data. -// -// A FeatureList contains lists of Features. These may hold zero or more -// Feature values. -// -// FeatureLists are organized into categories by name. The FeatureLists message -// contains the mapping from name to FeatureList. -// -message FeatureList { - repeated Feature feature = 1; -} - -message FeatureLists { - // Map from feature name to feature list. - map<string, FeatureList> feature_list = 1; -} -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/python/training/training.py) diff --git a/torchdata/datapipes/iter/util/randomsplitter.py b/torchdata/datapipes/iter/util/randomsplitter.py deleted file mode 100644 index 2972122f9..000000000 --- a/torchdata/datapipes/iter/util/randomsplitter.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import random -from typing import Dict, final, List, Optional, TypeVar, Union - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -T = TypeVar("T") - - -@functional_datapipe("random_split") -class RandomSplitterIterDataPipe(IterDataPipe): - r""" - Randomly split samples from a source DataPipe into groups (functional name: ``random_split``). - Since there is no buffer, only ONE group of samples (i.e. one child DataPipe) can be iterated through - at any time.
Attempts to iterate through multiple of them simultaneously will fail. - - Note that by default, multiple iterations of this DataPipe will yield the same split for consistency across epochs. - You can invoke ``override_seed`` on the output(s) to update the seed whenever needed (such as per epoch to - get a different split per epoch). - - Args: - source_datapipe: Iterable DataPipe being split - weights: Dict of weights; the length of this list determines how many output DataPipes there will be. - It is recommended to provide integer weights that sum up to ``total_length``, which allows - resulting DataPipes' length values to be known in advance. - seed: random _seed used to determine the randomness of the split - total_length: Length of the ``source_datapipe``, optional but providing an integer is highly encouraged, - because not all ``IterDataPipe`` has ``len``, espeically ones that can be easily known in advance. - target: Optional key (that must exist in ``weights``) to indicate the specific group to return. - If set to the default ``None``, returns ``List[IterDataPipe]``. - If target is specified, returns ``IterDataPipe``. - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper(range(10)) - >>> train, valid = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0) - >>> list(train) - [2, 3, 5, 7, 8] - >>> list(valid) - [0, 1, 4, 6, 9] - >>> # You can also specify a target key if you only need a specific group of samples - >>> train = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0, target='train') - >>> list(train) - [2, 3, 5, 7, 8] - >>> # Be careful to use the same seed as before when specifying `target` to get the correct split. - >>> valid = dp.random_split(total_length=10, weights={"train": 0.5, "valid": 0.5}, seed=0, target='valid') - >>> list(valid) - [0, 1, 4, 6, 9] - """ - - def __new__( - cls, - source_datapipe: IterDataPipe, - weights: Dict[T, Union[int, float]], - seed, - total_length: Optional[int] = None, - target: Optional[T] = None, - ): - if total_length is None: - try: - # TODO: This is an issue for DataPipes which only have runtime lengths. Revisit to see if this - # is problematic. - total_length = len(source_datapipe) - except TypeError: - raise TypeError( - "RandomSplitter needs `total_length`, but it is unable to infer it from " - f"the `source_datapipe`: {source_datapipe}." 
- ) - container = _RandomSplitterIterDataPipe(source_datapipe, total_length, weights, seed) # type: ignore - if target is None: - return [SplitterIterator(container, k) for k in list(weights.keys())] - else: - if target in weights.keys(): - return SplitterIterator(container, target) - else: - raise KeyError(f"`target={target}` does not match any key in `weights`.") - - -class _RandomSplitterIterDataPipe(IterDataPipe): - def __init__( - self, - source_datapipe: IterDataPipe, - total_length: int, - weights: Dict[T, Union[int, float]], - seed, - ): - self.source_datapipe: IterDataPipe = source_datapipe - self.total_length: int = total_length - self.remaining_length: int = total_length - self._seed = seed - self.keys: List[T] = list(weights.keys()) - self.key_to_index: Dict[T, int] = {k: i for i, k in enumerate(self.keys)} - self.norm_weights: List[float] = self.normalize_weights([weights[k] for k in self.keys], total_length) - self.weights: List[float] = self.norm_weights.copy() - self._rng = random.Random(self._seed) - self._lengths: List[int] = [] - - def draw(self) -> T: # type: ignore - selected_key = self._rng.choices(self.keys, self.weights)[0] - index = self.key_to_index[selected_key] - self.weights[index] -= 1 - self.remaining_length -= 1 - if self.weights[index] < 0: - self.weights[index] = 0 - self.weights = self.normalize_weights(self.weights, self.remaining_length) - return selected_key - - @staticmethod - def normalize_weights(weights: List[float], total_length: int) -> List[float]: - """ - Given a ``List`` of weights, normalize them according to ``total_length``. - """ - total_weight = sum(weights) - return [float(w) * total_length / total_weight for w in weights] - - @final - def reset(self) -> None: - self._rng = random.Random(self._seed) - self.weights = self.norm_weights.copy() - self.remaining_length = self.total_length - - def override_seed(self, seed): - """ - Update the `seed`. The new `seed` will be used in the next iteration. - """ - self._seed = seed - return self - - def __getstate__(self): - state = ( - self.source_datapipe, - self.total_length, - self._seed, - self.norm_weights, - self.keys, - self.key_to_index, - self.weights, - self._rng.getstate(), - ) - if IterDataPipe.getstate_hook is not None: - return IterDataPipe.getstate_hook(state) - return state - - def __setstate__(self, state): - ( - self.source_datapipe, - self.total_length, - self._seed, - self.norm_weights, - self.keys, - self.key_to_index, - self.weights, - rng_state, - ) = state - self._rng = random.Random() - self._rng.setstate(rng_state) - - def get_length(self, target: T) -> int: - if not self._lengths: - if all(w.is_integer() for w in self.norm_weights) and sum(self.norm_weights) == self.total_length: - self._lengths = [int(w) for w in self.norm_weights] - else: - raise TypeError( - "Lengths of the split cannot be known in advance. Please supply " - "integer `weights` that sum up to `total_length`.\nAlternatively, " - "use `datapipe.set_length(LENGTH)` to manually set the desired length." - ) - index = self.key_to_index[target] - return self._lengths[index] - - -class SplitterIterator(IterDataPipe): - def __init__(self, main_datapipe: _RandomSplitterIterDataPipe, target: T): - self.main_datapipe = main_datapipe - self.target = target - - def __iter__(self): - self.main_datapipe.reset() - for sample in self.main_datapipe.source_datapipe: - if self.main_datapipe.draw() == self.target: - yield sample - - def override_seed(self, seed): - """ - Update the `seed`. 
The new `seed` will be used in the next iteration. For use cases that require a different - split for each epoch, you call this method before or after the epoch as necessary. - """ - self.main_datapipe.override_seed(seed) - return self - - def __len__(self): - return self.main_datapipe.get_length(self.target) diff --git a/torchdata/datapipes/iter/util/rararchiveloader.py b/torchdata/datapipes/iter/util/rararchiveloader.py deleted file mode 100644 index da9610a54..000000000 --- a/torchdata/datapipes/iter/util/rararchiveloader.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import io -import os.path -from typing import Iterator, Tuple -from unittest.mock import patch - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -from torchdata.datapipes.utils import StreamWrapper -from torchdata.datapipes.utils.common import validate_pathname_binary_tuple - - -class RarfilePatcher: - def __init__(self): - from rarfile import DirectReader - - unpatched_read = DirectReader._read - - def patched_read(self, cnt=-1): - self._fd.seek(self._inf.header_offset, 0) - self._cur = self._parser._parse_header(self._fd) - self._cur_avail = self._cur.add_size - return unpatched_read(self, cnt) - - self._patch = patch("rarfile.DirectReader._read", new=patched_read) - - def start(self): - self._patch.start() - - def stop(self): - self._patch.stop() - - -_PATCHED = False - - -@functional_datapipe("load_from_rar") -class RarArchiveLoaderIterDataPipe(IterDataPipe[Tuple[str, io.BufferedIOBase]]): - r""" - Decompresses rar binary streams from input Iterable Datapipes which contains tuples of path name and rar - binary stream, and yields a tuple of path name and extracted binary stream (functional name: ``load_from_rar``). - - Note: - The nested RAR archive is not supported by this DataPipe - due to the limitation of the archive type. Please extract - outer RAR archive before reading the inner archive. - - Args: - datapipe: Iterable DataPipe that provides tuples of path name and rar binary stream - length: Nominal length of the DataPipe - - Example: - >>> from torchdata.datapipes.iter import FileLister, FileOpener - >>> datapipe1 = FileLister(".", "*.rar") - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> rar_loader_dp = datapipe2.load_from_rar() - >>> for _, stream in rar_loader_dp: - >>> print(stream.read()) - b'0123456789abcdef' - """ - - def __init__(self, datapipe: IterDataPipe[Tuple[str, io.BufferedIOBase]], *, length: int = -1): - try: - import rarfile - except ImportError as error: - raise ModuleNotFoundError( - "Package `rarfile` is required to be installed to use this datapipe. " - "Please use `pip install rarfile` or `conda -c conda-forge install rarfile` to install it." 
- ) from error - - # check if at least one system library for reading rar archives is available to be used by rarfile - rarfile.tool_setup() - - self.datapipe = datapipe - self.length = length - - def __iter__(self) -> Iterator[Tuple[str, io.BufferedIOBase]]: - import rarfile - - global _PATCHED - if not _PATCHED: - patcher = RarfilePatcher() - patcher.start() - _PATCHED = True - - for data in self.datapipe: - try: - validate_pathname_binary_tuple(data) - path, stream = data - if isinstance(stream, rarfile.RarExtFile) or ( - isinstance(stream, StreamWrapper) and isinstance(stream.file_obj, rarfile.RarExtFile) - ): - raise ValueError( - f"Nested RAR archive is not supported by {type(self).__name__}. Please extract outer archive first." - ) - - rar = rarfile.RarFile(stream) - for info in rar.infolist(): - if info.is_dir(): - continue - - inner_path = os.path.join(path, info.filename) - file_obj = rar.open(info) - yield inner_path, StreamWrapper(file_obj, stream, name=path) # type: ignore[misc] - finally: - if isinstance(stream, StreamWrapper): - stream.autoclose() - - def __len__(self) -> int: - if self.length == -1: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - return self.length diff --git a/torchdata/datapipes/iter/util/rows2columnar.py b/torchdata/datapipes/iter/util/rows2columnar.py deleted file mode 100644 index 5764a3d60..000000000 --- a/torchdata/datapipes/iter/util/rows2columnar.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from collections import defaultdict -from typing import Dict, Iterator, List, Optional, Union - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - - -@functional_datapipe("rows2columnar") -class Rows2ColumnarIterDataPipe(IterDataPipe[Dict]): - r""" - Accepts an input DataPipe with batches of data, and processes one batch - at a time and yields a Dict for each batch, with ``column_names`` as keys and lists of - corresponding values from each row as values (functional name: ``rows2columnar``). - - Within the input DataPipe, each row within a batch must either be a `Dict` or a `List` - - Note: - If ``column_names`` are not given and each row is a `Dict`, the keys of that Dict will be used as column names. - - Args: - source_datapipe: a DataPipe where each item is a batch. 
Within each batch, - there are rows and each row is a `List` or `Dict` - column_names: if each element in a batch contains `Dict`, ``column_names`` act as a filter for matching keys; - otherwise, these are used as keys to for the generated `Dict` of each batch - - Example: - >>> # Each element in a batch is a `Dict` - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp = IterableWrapper([[{'a': 1}, {'b': 2, 'a': 1}], [{'a': 1, 'b': 200}, {'b': 2, 'c': 3, 'a': 100}]]) - >>> row2col_dp = dp.rows2columnar() - >>> list(row2col_dp) - [defaultdict(, {'a': [1, 1], 'b': [2]}), - defaultdict(, {'a': [1, 100], 'b': [200, 2], 'c': [3]})] - >>> row2col_dp = dp.rows2columnar(column_names=['a']) - >>> list(row2col_dp) - [defaultdict(, {'a': [1, 1]}), - defaultdict(, {'a': [1, 100]})] - >>> # Each element in a batch is a `List` - >>> dp = IterableWrapper([[[0, 1, 2, 3], [4, 5, 6, 7]]]) - >>> row2col_dp = dp.rows2columnar(column_names=["1st_in_batch", "2nd_in_batch", "3rd_in_batch", "4th_in_batch"]) - >>> list(row2col_dp) - [defaultdict(, {'1st_in_batch': [0, 4], '2nd_in_batch': [1, 5], - '3rd_in_batch': [2, 6], '4th_in_batch': [3, 7]})] - """ - column_names: List[str] - - def __init__( - self, source_datapipe: IterDataPipe[List[Union[Dict, List]]], column_names: Optional[List[str]] = None - ) -> None: - self.source_datapipe: IterDataPipe[List[Union[Dict, List]]] = source_datapipe - self.column_names: List[str] = [] if column_names is None else column_names - - def __iter__(self) -> Iterator[Dict]: - for batch in self.source_datapipe: - columnar = defaultdict(list) - for list_or_dict_row in batch: - if isinstance(list_or_dict_row, dict): - # if column_names provided, we use it as a filter - if len(self.column_names) > 0: - for column_name in self.column_names: - # this line will raise a KeyError if column_name - # is not within list_or_dict_row which is the - # expected behavior - columnar[column_name].append(list_or_dict_row[column_name]) - else: - for k, v in list_or_dict_row.items(): - columnar[k].append(v) - else: - for i, v in enumerate(list_or_dict_row): - columnar[self.column_names[i]].append(v) - yield columnar - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/util/samplemultiplexer.py b/torchdata/datapipes/iter/util/samplemultiplexer.py deleted file mode 100644 index 3f55c1acc..000000000 --- a/torchdata/datapipes/iter/util/samplemultiplexer.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import random -from typing import Dict, Iterator, Optional, Sized, TypeVar - -from torchdata.datapipes.iter import IterDataPipe - - -T_co = TypeVar("T_co", covariant=True) - - -class SampleMultiplexerDataPipe(IterDataPipe[T_co]): - """ - Takes a `Dict` of (IterDataPipe, Weight), and yields items by sampling from these - DataPipes with respect to their weights. When individual DataPipes are exhausted, continues to sample from - the remaining DataPipes according to their relative weights. - If you wish to maintain the same ratio of weights indefinitely, you need to ensure that the - inputs are never exhausted, by, for instance, applying ``cycle`` to them. - - Sampling is controlled by the provided random ``seed``. If you don't provide it, the sampling - will not be deterministic. 
- - Args: - pipes_to_weights_dict: a `Dict` of IterDataPipes and Weights. The total weight of - unexhausted DataPipes will be normalized to 1 for the purpose of sampling. - seed: random seed to initialize the random number generator - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper, SampleMultiplexer - >>> source_dp1 = IterableWrapper([0] * 5) - >>> source_dp2 = IterableWrapper([1] * 5) - >>> d = {source_dp1: 99999999, source_dp2: 0.0000001} - >>> sample_mul_dp = SampleMultiplexer(pipes_to_weights_dict=d, seed=0) - >>> list(sample_mul_dp) - [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] - """ - - def __init__( - self, - pipes_to_weights_dict: Dict[IterDataPipe[T_co], float], - seed: Optional[int] = None, - ): - if not pipes_to_weights_dict: - raise ValueError("Empty dictionary passed to SampleMultiplexerDataPipe") - total_weight: float = 0 - for v in pipes_to_weights_dict.values(): - if v <= 0: - raise ValueError(f"Expecting a positive and non-zero weight, got {v}") - total_weight += v - - self.pipes_and_weights = [(k, v / total_weight) for k, v in pipes_to_weights_dict.items()] - if seed is None: - self.random = random.Random() - else: - self.random = random.Random(seed) - - def __iter__(self) -> Iterator[T_co]: - pipes_and_weights = [(iter(k), v) for k, v in self.pipes_and_weights] - while len(pipes_and_weights) > 1: - r = self.random.random() - s: float = 0 - for it, weight in pipes_and_weights: - s += weight - if r < s: - try: - item = next(it) - yield item - except StopIteration: - # remove the current stream - new_total = 1 - weight - assert new_total > 0 - pipes_and_weights = [(k, v / new_total) for k, v in pipes_and_weights if k != it] - break - - # only one stream left - for item in pipes_and_weights[0][0]: - yield item - - def __len__(self) -> int: - if all(isinstance(dp, Sized) for dp, _ in self.pipes_and_weights): - return sum(len(dp) for dp, _ in self.pipes_and_weights) - else: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/torchdata/datapipes/iter/util/saver.py b/torchdata/datapipes/iter/util/saver.py deleted file mode 100644 index 0c8332c15..000000000 --- a/torchdata/datapipes/iter/util/saver.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os - -from typing import Any, Callable, Iterator, Optional, Tuple, Union - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -U = Union[bytes, bytearray, str] - - -@functional_datapipe("save_to_disk") -class SaverIterDataPipe(IterDataPipe[str]): - r""" - Takes in a DataPipe of tuples of metadata and data, saves the data - to the target path generated by the ``filepath_fn`` and metadata, and yields file path on local file - system (functional name: ``save_to_disk``). 
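The weighted sampling loop performed by the ``SampleMultiplexer`` removed above can be reproduced over plain Python iterables; a minimal sketch follows (the function name and signature are illustrative, and, as in the original, exhausted sources are simply dropped while the remaining weights keep their ratio).

    import random
    from itertools import islice
    from typing import Iterable, Iterator, List, Tuple, TypeVar

    T = TypeVar("T")

    def sample_multiplex(sources: List[Tuple[Iterable[T], float]], seed: int = 0) -> Iterator[T]:
        """Yield items by picking a source at random in proportion to its weight."""
        rng = random.Random(seed)
        active = [(iter(src), weight) for src, weight in sources]
        while active:
            idx = rng.choices(range(len(active)), weights=[w for _, w in active])[0]
            try:
                yield next(active[idx][0])
            except StopIteration:
                del active[idx]  # drop the exhausted source and keep sampling from the rest

    # Roughly a 3:1 mix of zeros and ones until both sources run dry.
    print(list(islice(sample_multiplex([([0] * 9, 3.0), ([1] * 3, 1.0)]), 12)))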
- - Args: - source_datapipe: Iterable DataPipe with tuples of metadata and data - mode: Node in which the file will be opened for write the data (``"w"`` by default) - filepath_fn: Function that takes in metadata and returns the target path of the new file - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> import os - >>> def filepath_fn(name: str) -> str: - >>> return os.path.join(".", os.path.basename(name)) - >>> name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"} - >>> source_dp = IterableWrapper(sorted(name_to_data.items())) - >>> saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb") - >>> res_file_paths = list(saver_dp) - >>> res_file_paths - ['./1.txt', './2.txt', './3.txt'] - """ - - def __init__( - self, - source_datapipe: IterDataPipe[Tuple[Any, U]], - mode: str = "w", - filepath_fn: Optional[Callable] = None, - ): - self.source_datapipe: IterDataPipe[Tuple[Any, U]] = source_datapipe - self.mode: str = mode if "w" in mode else "w" + mode - self.fn: Optional[Callable] = filepath_fn - - def __iter__(self) -> Iterator[str]: - for filepath, data in self.source_datapipe: - if self.fn is not None: - filepath = self.fn(filepath) - dirname = os.path.dirname(filepath) - if not os.path.exists(dirname): - os.makedirs(dirname) - # with portalocker.Lock(filepath, self.mode, flags=portalocker.LockFlags.EXCLUSIVE) as f: - # TODO(639): Enabling line above will require all read sites to be updated (Win). - with open(filepath, self.mode) as f: - f.write(data) - yield filepath - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/util/shardexpander.py b/torchdata/datapipes/iter/util/shardexpander.py deleted file mode 100644 index 8a3c1fc82..000000000 --- a/torchdata/datapipes/iter/util/shardexpander.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import re -from typing import Iterator, List - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - - -def _shard_expand(s: str) -> List[str]: - expansion = r"[{][0-9]+[.][.][0-9]+[}]" - m = re.search(expansion, s) - if not m: - return [s] - prefix = s[: m.start()] - rest = _shard_expand(s[m.end() :]) - rng = s[m.start() + 1 : m.end() - 1] - lohi = rng.split("..") - if len(lohi[0]) == len(lohi[1]) and lohi[0].startswith("0"): - fmt = "{prefix}{i:0>{l}d}{r}" - elif len(lohi[0]) <= len(lohi[1]): - if lohi[0].startswith("0") and lohi[0] != "0": - raise ValueError("shard_expand: low bound must not start with 0 if low bound is shorter") - fmt = "{prefix}{i}{r}" - else: - raise ValueError("shard_expand: low bound must be shorter than high bound") - lo, hi = (int(x) for x in lohi) - if lo >= hi: - raise ValueError(f"shard_expand: bad range in in shard spec {s}.") - result = [] - for i in range(lo, hi + 1): - for r in rest: - expanded: str = fmt.format(prefix=prefix, i=i, r=r, l=len(lohi[1])) - result.append(expanded) - return result - - -@functional_datapipe("shard_expand") -class ShardExpanderIterDataPipe(IterDataPipe[str]): - r""" - Expands incoming shard strings into shards. - - Sharded data files are named using shell-like brace notation. For example, - an ImageNet dataset sharded into 1200 shards and stored on a web server - might be named `imagenet-{000000..001199}.tar`. 
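For readers replacing the ``shard_expand`` functionality above, the brace-range expansion can be reproduced in a few lines; the sketch below handles only a single zero-padded range (the deleted ``_shard_expand`` also supports multiple ranges and unpadded bounds).

    import re
    from typing import Iterator

    def expand_shards(pattern: str) -> Iterator[str]:
        """Expand one '{low..high}' brace range with fixed-width, zero-padded bounds."""
        m = re.search(r"\{(\d+)\.\.(\d+)\}", pattern)
        if m is None:
            yield pattern
            return
        lo, hi = m.group(1), m.group(2)
        width = len(hi)
        for i in range(int(lo), int(hi) + 1):
            yield pattern[: m.start()] + f"{i:0{width}d}" + pattern[m.end():]

    print(list(expand_shards("ds-{00..05}.tar")))
    # ['ds-00.tar', 'ds-01.tar', 'ds-02.tar', 'ds-03.tar', 'ds-04.tar', 'ds-05.tar']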
- - Note that shard names can be expanded without any server transactions; - this makes `shard_expand` reproducible and storage system independent - (unlike :class `.FileLister` etc.). - - Args: - source_datapipe: a DataPipe yielding a stream of pairs - - Returns: - a DataPipe yielding a stream of expanded pathnames. - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> source_dp = IterableWrapper(["ds-{00..05}.tar"]) - >>> expand_dp = source_dp.shard_expand() - >>> list(expand_dp) - ['ds-00.tar', 'ds-01.tar', 'ds-02.tar', 'ds-03.tar', 'ds-04.tar', 'ds-05.tar'] - >>> source_dp = IterableWrapper(["imgs_{00..05}.tar", "labels_{00..05}.tar"]) - >>> expand_dp = source_dp.shard_expand() - >>> list(expand_dp) - ['imgs_00.tar', 'imgs_01.tar', 'imgs_02.tar', 'labels_00.tar', 'labels_01.tar', 'labels_02.tar'] - """ - - def __init__(self, source_datapipe: IterDataPipe[str]) -> None: - super().__init__() - self.source_datapipe: IterDataPipe[str] = source_datapipe - - def __iter__(self) -> Iterator[str]: - for path in self.source_datapipe: - yield from _shard_expand(path) diff --git a/torchdata/datapipes/iter/util/sharding.py b/torchdata/datapipes/iter/util/sharding.py deleted file mode 100644 index 9d6f51927..000000000 --- a/torchdata/datapipes/iter/util/sharding.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Iterator, Optional, TypeVar - -from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -T_co = TypeVar("T_co", covariant=True) - - -@functional_datapipe("sharding_round_robin_dispatch") -class ShardingRoundRobinDispatcherIterDataPipe(IterDataPipe): - r""" - Wrapper that indicates the prior section of ``DataPipe`` graph is non-replicable and will be - iterated in a separate, single dispatching process to distribute data to worker processes - in a round-robin manner when multiprocessing is being used. - (functional name: ``sharding_round_robin_dispatch``). - - Args: - source_datapipe: Iterable DataPipe that will be sharded - sharding_group_filter: Optional ``SHARDING_PRIORITIES`` value - - Note: - - ``sharding_group_filter`` only accepts ``SHARDING_PRIORITIES.MULTIPROCESSING`` for now - - When using distributed training, you can add a ``sharding_filter()`` prior to this DataPipe - to distribute samples among worker nodes. 
- - Examples: - >>> # xdoctest: +SKIP - >>> from torchdata.datapipes.iter import IterableWrapper - >>> from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES - >>> dp = IterableWrapper(range(10)) - >>> # `.shuffle()` will be executed in a single dispatching processing, then the samples are distributed - >>> # to worker processes - >>> dp = dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING) - >>> # `.map()` will be executed within each worker process - >>> dp = dp.map(lambda x: x + 1) - >>> # Distributed case: the 10 samples will be distributed among the nodes - >>> dp = IterableWrapper(range(10)).sharding_filter() - >>> # `.map()` will be executed in a single dispatching processing in each node - >>> # You may apply further transformation after within each worker process - >>> dp = dp.map(lambda x: x + 1).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING) - """ - - def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter: Optional[SHARDING_PRIORITIES] = None): - self.source_datapipe = source_datapipe - if sharding_group_filter != SHARDING_PRIORITIES.MULTIPROCESSING: - raise NotImplementedError( - "`sharding_round_robin_dispatch` currently only supports `SHARDING_PRIORITIES.MULTIPROCESSING`." - "Please open issue on github for your feature request." - ) - self.sharding_group_filter = sharding_group_filter - - def __iter__(self) -> Iterator[T_co]: - yield from self.source_datapipe - - def __len__(self) -> int: - return len(self.source_datapipe) diff --git a/torchdata/datapipes/iter/util/tararchiveloader.py b/torchdata/datapipes/iter/util/tararchiveloader.py deleted file mode 100644 index e7cff2b9d..000000000 --- a/torchdata/datapipes/iter/util/tararchiveloader.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import tarfile -import warnings -from io import BufferedIOBase -from typing import cast, IO, Iterable, Iterator, Optional, Tuple - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -from torchdata.datapipes.utils import StreamWrapper -from torchdata.datapipes.utils.common import validate_pathname_binary_tuple - - -@functional_datapipe("load_from_tar") -class TarArchiveLoaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): - r""" - Opens/decompresses tar binary streams from an Iterable DataPipe which contains tuples of path name and - tar binary stream, and yields a tuple of path name and extracted binary stream (functional name: ``load_from_tar``). - - Args: - datapipe: Iterable DataPipe that provides tuples of path name and tar binary stream - mode: File mode used by `tarfile.open` to read file object. - Mode has to be a string of the form `'filemode[:compression]'` - length: a nominal length of the DataPipe - - Note: - The opened file handles will be closed automatically if the default ``DecoderDataPipe`` - is attached. Otherwise, user should be responsible to close file handles explicitly - or let Python's GC close them periodically. 
- - Example: - >>> from torchdata.datapipes.iter import FileLister, FileOpener - >>> datapipe1 = FileLister(".", "*.tar") - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> tar_loader_dp = datapipe2.load_from_tar() - >>> for _, stream in tar_loader_dp: - >>> print(stream.read()) - b'0123456789abcdef' - """ - - def __init__(self, datapipe: Iterable[Tuple[str, BufferedIOBase]], mode: str = "r:*", length: int = -1) -> None: - super().__init__() - self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe - self.mode: str = mode - self.length: int = length - - def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: - for data in self.datapipe: - validate_pathname_binary_tuple(data) - pathname, data_stream = data - try: - if isinstance(data_stream, StreamWrapper) and isinstance(data_stream.file_obj, tarfile.TarFile): - tar = data_stream.file_obj - else: - reading_mode = ( - self.mode - if hasattr(data_stream, "seekable") and data_stream.seekable() - else self.mode.replace(":", "|") - ) - # typing.cast is used here to silence mypy's type checker - tar = tarfile.open(fileobj=cast(Optional[IO[bytes]], data_stream), mode=reading_mode) - for tarinfo in tar: - if not tarinfo.isfile(): - continue - extracted_fobj = tar.extractfile(tarinfo) - if extracted_fobj is None: - warnings.warn(f"failed to extract file {tarinfo.name} from source tarfile {pathname}") - raise tarfile.ExtractError - inner_pathname = os.path.normpath(os.path.join(pathname, tarinfo.name)) - yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc] - except Exception as e: - warnings.warn(f"Unable to extract files from corrupted tarfile stream {pathname} due to: {e}, abort!") - raise e - finally: - if isinstance(data_stream, StreamWrapper): - data_stream.autoclose() - - def __len__(self) -> int: - if self.length == -1: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - return self.length diff --git a/torchdata/datapipes/iter/util/tfrecordloader.py b/torchdata/datapipes/iter/util/tfrecordloader.py deleted file mode 100644 index 8ad5ea101..000000000 --- a/torchdata/datapipes/iter/util/tfrecordloader.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
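As a stdlib-only reference for the ``load_from_tar`` DataPipe removed above: its per-archive loop reduces to plain ``tarfile`` iteration. A sketch, with the caveat that the yielded file objects are only valid while the archive stays open (i.e. while the generator is being consumed):

    import os
    import tarfile
    from typing import IO, Iterator, Tuple

    def iter_tar_members(path: str) -> Iterator[Tuple[str, IO[bytes]]]:
        """Yield (inner_path, file_object) for every regular file in a tar archive."""
        with tarfile.open(path, mode="r:*") as tar:  # "r:*" auto-detects compression
            for info in tar:
                if not info.isfile():
                    continue
                fobj = tar.extractfile(info)
                if fobj is None:
                    continue
                yield os.path.normpath(os.path.join(path, info.name)), fobj

    # for inner_path, stream in iter_tar_members("archive.tar"):
    #     print(inner_path, stream.read()[:16])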
- -import struct -import warnings -from functools import partial -from io import BufferedIOBase -from typing import Any, cast, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union - -import torch - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -from torchdata.datapipes.utils.common import validate_pathname_binary_tuple - -try: - from math import prod # type: ignore -except ImportError: - # Implementation for older Python - # NOTE: this is not supported by mypy yet - # https://github.com/python/mypy/issues/1393 - import operator - from functools import reduce - - def prod(xs): # type: ignore[no-redef] - return reduce(operator.mul, xs, 1) - - -try: - import google.protobuf as _protobuf - - del _protobuf - HAS_PROTOBUF = True -except ImportError: - HAS_PROTOBUF = False - -U = Union[bytes, bytearray, str] -TFRecordFeatureSpec = Tuple[Tuple[int, ...], torch.dtype] -TFRecordExampleSpec = Dict[str, TFRecordFeatureSpec] - -# Note, reccursive types not supported by mypy at the moment -# TODO(640): uncomment as soon as it becomes supported -# https://github.com/python/mypy/issues/731 -# BinaryData = Union[str, List['BinaryData']] -TFRecordBinaryData = Union[str, List[str], List[List[str]], List[List[List[Any]]]] -TFRecordExampleFeature = Union[torch.Tensor, List[torch.Tensor], TFRecordBinaryData] -TFRecordExample = Dict[str, TFRecordExampleFeature] - - -class SequenceExampleSpec(NamedTuple): - context: TFRecordExampleSpec - feature_lists: TFRecordExampleSpec - - -def _assert_protobuf() -> None: - if not HAS_PROTOBUF: - raise ModuleNotFoundError( - "Package `protobuf` is required to be installed to use this datapipe." - "Please use `pip install protobuf` or `conda install -c conda-forge protobuf`" - "to install the package" - ) - - -def iterate_tfrecord_file(data: BufferedIOBase) -> Iterator[memoryview]: - length_bytes = bytearray(8) - crc_bytes = bytearray(4) - data_bytes = bytearray(1024) - - while True: - bytes_read = data.readinto(length_bytes) - if bytes_read == 0: - break - elif bytes_read != 8: - raise RuntimeError("Invalid tfrecord file: failed to read the record size.") - if data.readinto(crc_bytes) != 4: - raise RuntimeError("Invalid tfrecord file: failed to read the start token.") - (length,) = struct.unpack("<Q", length_bytes) - if length > len(data_bytes): - data_bytes = data_bytes.zfill(int(length * 1.5)) - data_bytes_view = memoryview(data_bytes)[:length] - if data.readinto(data_bytes_view) != length: - raise RuntimeError("Invalid tfrecord file: failed to read the record.") - if data.readinto(crc_bytes) != 4: - raise RuntimeError("Invalid tfrecord file: failed to read the end token.") - - # TODO(641): check CRC - yield data_bytes_view - - -def process_feature(feature) -> torch.Tensor: - # NOTE: We assume that each key in the example has only one field - # (either "bytes_list", "float_list", or "int64_list")!
- field = feature.ListFields()[0] - inferred_typename, value = field[0].name, field[1].value - if inferred_typename == "bytes_list": - pass - elif inferred_typename == "float_list": - value = torch.tensor(value, dtype=torch.float32) - elif inferred_typename == "int64_list": - value = torch.tensor(value, dtype=torch.int64) - return value - - -def _reshape_list(value, shape): - # Flatten list - flat_list = [] - - def flatten(value): - if isinstance(value, (str, bytes)): - flat_list.append(value) - else: - for x in value: - flatten(x) - - flatten(value) - - # Compute correct shape - common_divisor = prod(x for x in shape if x != -1) - if sum(1 for x in shape if x == -1) > 1: - raise RuntimeError("Shape can contain at most one dynamic dimension (-1).") - if len(flat_list) % max(common_divisor, 1) != 0: - raise RuntimeError(f"Cannot reshape {len(flat_list)} values into shape {shape}") - shape = [x if x != -1 else (len(flat_list) // common_divisor) for x in shape] - - # Reshape list into the correct shape - def _reshape(value, shape): - if len(shape) == 0: - assert len(value) == 1 - return value[0] - elif len(shape) == 1: # To make the reccursion faster - assert len(value) == shape[0] - return value - dim_size = len(value) // shape[0] - return [_reshape(value[i * dim_size : (i + 1) * dim_size], shape[1:]) for i in range(dim_size)] - - return _reshape(flat_list, shape) - - -def _apply_feature_spec(value, feature_spec): - if feature_spec is not None: - shape, dtype = feature_spec - if isinstance(dtype, torch.dtype): - if shape is not None: - value = value.reshape(shape) - value = value.to(dtype) - elif shape is not None: - # Manual list reshape - value = _reshape_list(value, shape) - return value - - -def _parse_tfrecord_features(features, spec: Optional[TFRecordExampleSpec]) -> Dict[str, torch.Tensor]: - result = dict() - features = features.feature - for key in features.keys(): - if spec is not None and key not in spec: - continue - feature_spec = None if spec is None else spec[key] - feature = features[key] - result[key] = _apply_feature_spec(process_feature(feature), feature_spec) - return result - - -def parse_tfrecord_sequence_example(example, spec: Optional[TFRecordExampleSpec]) -> TFRecordExample: - # Parse context features - result = cast(TFRecordExample, _parse_tfrecord_features(example.context, spec)) - - # Parse feature lists - feature_lists_keys = None if spec is None else set(spec.keys()) - set(result.keys()) - features = example.feature_lists.feature_list - for key in features.keys(): - if feature_lists_keys is not None and key not in feature_lists_keys: - continue - feature_spec = None if spec is None else spec[key] - feature = features[key].feature - if key in result: - raise RuntimeError( - "TFRecord example's key {key} is contained in both the context and feature lists. This is not supported." 
- ) - - value: Union[torch.Tensor, List[Any]] = list(map(partial(process_feature), feature)) - - # For known torch dtypes, we stack the list features - if feature_spec is not None and isinstance(feature_spec[1], torch.dtype): - value = torch.stack(cast(List[torch.Tensor], value), 0) - value = _apply_feature_spec(value, feature_spec) - result[key] = value - if spec is not None and len(result.keys()) != len(spec.keys()): - raise RuntimeError(f"Example is missing some required keys: {sorted(result.keys())} != {sorted(spec.keys())}") - return result - - -@functional_datapipe("load_from_tfrecord") -class TFRecordLoaderIterDataPipe(IterDataPipe[TFRecordExample]): - r""" - Opens/decompresses tfrecord binary streams from an Iterable DataPipe which contains tuples of path name and - tfrecord binary stream, and yields the stored records (functional name: ``load_from_tfrecord``). - - Args: - datapipe: Iterable DataPipe that provides tuples of path name and tfrecord binary stream - length: a nominal length of the DataPipe - - Note: - The opened file handles will be closed automatically if the default ``DecoderDataPipe`` - is attached. Otherwise, user should be responsible to close file handles explicitly - or let Python's GC close them periodically. - - Example: - >>> from torchdata.datapipes.iter import FileLister, FileOpener - >>> datapipe1 = FileLister(".", "*.tfrecord") - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> tfrecord_loader_dp = datapipe2.load_from_tfrecord() - >>> for example in tfrecord_loader_dp: - >>> print(example) - """ - - def __init__( - self, - datapipe: Iterable[Tuple[str, BufferedIOBase]], - spec: Optional[TFRecordExampleSpec] = None, - length: int = -1, - ) -> None: - super().__init__() - _assert_protobuf() - - self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe - self.length: int = length - self.spec = spec - - def __iter__(self) -> Iterator[TFRecordExample]: - # We assume that the "example.proto" and "feature.proto" - # stays the same for future TensorFlow versions. - # If it changed, newer TensorFlow versions would - # not be able to load older tfrecord datasets. - from .protobuf_template import _tfrecord_example_pb2 as example_pb2 - - for data in self.datapipe: - validate_pathname_binary_tuple(data) - pathname, data_stream = data - try: - for example_bytes in iterate_tfrecord_file(data_stream): - example = example_pb2.SequenceExample() # type: ignore - example.ParseFromString(example_bytes) # type: ignore - yield parse_tfrecord_sequence_example(example, self.spec) - except RuntimeError as e: - warnings.warn(f"Unable to read from corrupted tfrecord stream {pathname} due to: {e}, abort!") - raise e - - def __len__(self) -> int: - if self.length == -1: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - return self.length diff --git a/torchdata/datapipes/iter/util/webdataset.py b/torchdata/datapipes/iter/util/webdataset.py deleted file mode 100644 index 309f2312c..000000000 --- a/torchdata/datapipes/iter/util/webdataset.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import re -from typing import Any, Dict, Iterator, List, Union - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - - -def pathsplit(p): - """Split a path into a WebDataset prefix and suffix. 
- - The prefix is used for grouping files into samples, - the suffix is used as key in the output dictionary. - The suffix consists of all components after the first - "." in the filename. - - In torchdata, the prefix consists of the .tar file - path followed by the file name inside the archive. - - Any backslash in the prefix is replaced by a forward - slash to make Windows prefixes consistent with POSIX - paths. - """ - - # convert Windows pathnames to UNIX pathnames, otherwise - # we get an inconsistent mix of the Windows path to the tar - # file followed by the POSIX path inside that tar file - p = p.replace("\\", "/") - if "." not in p: - return p, "" - # we need to use a regular expression because os.path is - # platform specific, but tar files always contain POSIX paths - match = re.search(r"^(.*?)(\.[^/]*)$", p) - if not match: - return p, "" - prefix, suffix = match.groups() - return prefix, suffix - - -@functional_datapipe("webdataset") -class WebDatasetIterDataPipe(IterDataPipe[Dict]): - r""" - Iterable DataPipe that accepts stream of (path, data) tuples, usually, - representing the pathnames and files of a tar archive (functional name: - ``webdataset``). This aggregates consecutive items with the same basename - into a single dictionary, using the extensions as keys (WebDataset file - convention). Any text after the first "." in the filename is used as - a key/extension. - - File names that do not have an extension are ignored. - - Args: - source_datapipe: a DataPipe yielding a stream of (path, data) pairs - - Returns: - a DataPipe yielding a stream of dictionaries - - Examples: - >>> from torchdata.datapipes.iter import FileLister, FileOpener - >>> - >>> def decode(item): - >>> key, value = item - >>> if key.endswith(".txt"): - >>> return key, value.read().decode("utf-8") - >>> if key.endswith(".bin"): - >>> return key, value.read().decode("utf-8") - >>> - >>> datapipe1 = FileLister("test/_fakedata", "wds*.tar") - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> dataset = datapipe2.load_from_tar().map(decode).webdataset() - >>> for obj in dataset: - >>> print(obj) - """ - - def __init__(self, source_datapipe: IterDataPipe[List[Union[Dict, List]]]) -> None: - self.source_datapipe: IterDataPipe[List[Union[Dict, List]]] = source_datapipe - - def __iter__(self) -> Iterator[Dict]: - sample: Dict[str, Any] = {} - current = "" - for path, data in self.source_datapipe: - assert isinstance(path, str), path - prefix, suffix = pathsplit(path) - if suffix == "": - # files with empty suffixes can be used for metadata - # they cannot be used for data since they wouldn't have a key - continue - if prefix != current: - if current != "": - yield sample - sample = {} - current = prefix - sample["__key__"] = current - sample[suffix] = data - if sample != {}: - yield sample diff --git a/torchdata/datapipes/iter/util/xzfileloader.py b/torchdata/datapipes/iter/util/xzfileloader.py deleted file mode 100644 index f44dfaf3d..000000000 --- a/torchdata/datapipes/iter/util/xzfileloader.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
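The webdataset datapipe deleted above boils down to two steps: split each path into a grouping prefix and a key suffix, then merge consecutive entries that share a prefix into one sample dictionary. Below is a minimal sketch of that grouping logic, independent of the DataPipe machinery; group_by_prefix and the sample shard paths are illustrative names only.

import re
from typing import Any, Dict, Iterable, Iterator, Tuple


def pathsplit(p: str) -> Tuple[str, str]:
    # Same idea as the deleted helper above: everything before the first "." in the
    # file name is the grouping prefix, the rest is the key/extension.
    p = p.replace("\\", "/")
    match = re.search(r"^(.*?)(\.[^/]*)$", p)
    if not match:
        return p, ""
    return match.group(1), match.group(2)


def group_by_prefix(pairs: Iterable[Tuple[str, Any]]) -> Iterator[Dict[str, Any]]:
    # Aggregate consecutive (path, data) pairs sharing a prefix into one sample dict.
    sample: Dict[str, Any] = {}
    current = ""
    for path, data in pairs:
        prefix, suffix = pathsplit(path)
        if suffix == "":
            continue  # extension-less entries carry no key
        if prefix != current:
            if sample:
                yield sample
            sample = {"__key__": prefix}
            current = prefix
        sample[suffix] = data
    if sample:
        yield sample


# Hypothetical input resembling what load_from_tar() would yield:
pairs = [("shard/000.jpg", b"..."), ("shard/000.cls", b"0"), ("shard/001.jpg", b"..."), ("shard/001.cls", b"1")]
print(list(group_by_prefix(pairs)))
# -> [{'__key__': 'shard/000', '.jpg': b'...', '.cls': b'0'}, {'__key__': 'shard/001', '.jpg': b'...', '.cls': b'1'}]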
-
-import lzma
-import warnings
-from io import BufferedIOBase
-from typing import Iterable, Iterator, Tuple
-
-from torchdata.datapipes import functional_datapipe
-from torchdata.datapipes.iter import IterDataPipe
-
-from torchdata.datapipes.utils import StreamWrapper
-from torchdata.datapipes.utils.common import validate_pathname_binary_tuple
-
-
-@functional_datapipe("load_from_xz")
-class XzFileLoaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]):
-    r"""
-    Decompresses xz (lzma) binary streams from an Iterable DataPipe which contains tuples of
-    path name and xz binary streams, and yields a tuple of path name and extracted binary
-    stream (functional name: ``load_from_xz``).
-
-    Args:
-        datapipe: Iterable DataPipe that provides tuples of path name and xz binary stream
-        length: Nominal length of the DataPipe
-
-    Note:
-        The opened file handles will be closed automatically if the default ``DecoderDataPipe``
-        is attached. Otherwise, user should be responsible to close file handles explicitly
-        or let Python's GC close them periodically.
-
-    Example:
-        >>> from torchdata.datapipes.iter import FileLister, FileOpener
-        >>> datapipe1 = FileLister(".", "*.xz")
-        >>> datapipe2 = FileOpener(datapipe1, mode="b")
-        >>> xz_loader_dp = datapipe2.load_from_xz()
-        >>> for _, stream in xz_loader_dp:
-        >>>     print(stream.read())
-        b'0123456789abcdef'
-    """
-
-    def __init__(self, datapipe: Iterable[Tuple[str, BufferedIOBase]], length: int = -1) -> None:
-        super().__init__()
-        self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe
-        self.length: int = length
-
-    def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]:
-        for data in self.datapipe:
-            validate_pathname_binary_tuple(data)
-            pathname, data_stream = data
-            try:
-                extracted_fobj = lzma.open(data_stream, mode="rb")  # type: ignore[call-overload]
-                new_pathname = pathname.rstrip(".xz")  # https://github.com/pytorch/data/issues/1240
-                yield new_pathname, StreamWrapper(extracted_fobj, data_stream, name=pathname)  # type: ignore[misc]
-            except Exception as e:
-                warnings.warn(f"Unable to extract files from corrupted xz/lzma stream {pathname} due to: {e}, abort!")
-                raise e
-            finally:
-                if isinstance(data_stream, StreamWrapper):
-                    data_stream.autoclose()
-
-    def __len__(self) -> int:
-        if self.length == -1:
-            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
-        return self.length
diff --git a/torchdata/datapipes/iter/util/zip_longest.py b/torchdata/datapipes/iter/util/zip_longest.py
deleted file mode 100644
index 3a8f4b982..000000000
--- a/torchdata/datapipes/iter/util/zip_longest.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Any, Iterator, List, Optional, Set, Sized, Tuple
-
-from torch.utils.data.datapipes._decorator import functional_datapipe
-from torch.utils.data.datapipes.datapipe import IterDataPipe
-
-
-@functional_datapipe("zip_longest")
-class ZipperLongestIterDataPipe(IterDataPipe):
-    r"""
-    Aggregates elements into a tuple from each of the input DataPipes (functional name: ``zip_longest``).
-    Iteration stops only once every input DataPipe has been exhausted; when an individual input DataPipe
-    runs out of elements early, its missing values are filled in with `fill_value` (default value is None).
- - Args: - *datapipes: Iterable DataPipes being aggregated - *fill_value: Value that user input to fill in the missing values from DataPipe. Default value is None. - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> dp1, dp2, dp3 = IterableWrapper(range(3)), IterableWrapper(range(10, 15)), IterableWrapper(range(20, 25)) - >>> list(dp1.zip_longest(dp2, dp3)) - [(0, 10, 20), (1, 11, 21), (2, 12, 22), (None, 13, 23), (None, 14, 24)] - >>> list(dp1.zip_longest(dp2, dp3, -1)) - [(0, 10, 20), (1, 11, 21), (2, 12, 22), (-1, 13, 23), (-1, 14, 24)] - """ - datapipes: Tuple[IterDataPipe] - length: Optional[int] - fill_value: Any - - def __init__( - self, - *datapipes: IterDataPipe, - fill_value: Any = None, - ): - if not all(isinstance(dp, IterDataPipe) for dp in datapipes): - raise TypeError("All inputs are required to be `IterDataPipe` " "for `ZipperLongestIterDataPipe`.") - super().__init__() - self.datapipes = datapipes # type: ignore[assignment] - self.fill_value = fill_value - - def __iter__(self) -> Iterator[Tuple]: - iterators = [iter(x) for x in self.datapipes] - finished: Set[int] = set() - while len(finished) < len(iterators): - values: List[Any] = [] - for i in range(len(iterators)): - value = self.fill_value - if i not in finished: - try: - value = next(iterators[i]) - except StopIteration: - finished.add(i) - if len(finished) == len(iterators): - return - values.append(value) - yield tuple(values) - - def __len__(self) -> int: - if all(isinstance(dp, Sized) for dp in self.datapipes): - return max(len(dp) for dp in self.datapipes) - else: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/torchdata/datapipes/iter/util/ziparchiveloader.py b/torchdata/datapipes/iter/util/ziparchiveloader.py deleted file mode 100644 index d70a902d3..000000000 --- a/torchdata/datapipes/iter/util/ziparchiveloader.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import sys -import warnings -import zipfile -from io import BufferedIOBase -from typing import cast, IO, Iterable, Iterator, Tuple - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.iter import IterDataPipe - -from torchdata.datapipes.utils import StreamWrapper -from torchdata.datapipes.utils.common import validate_pathname_binary_tuple - - -@functional_datapipe("load_from_zip") -class ZipArchiveLoaderIterDataPipe(IterDataPipe[Tuple[str, BufferedIOBase]]): - r""" - Opens/decompresses zip binary streams from an Iterable DataPipe which contains a tuple of path name and - zip binary stream, and yields a tuple of path name and extracted binary stream (functional name: ``load_from_zip``). - - Args: - datapipe: Iterable DataPipe that provides tuples of path name and zip binary stream - length: Nominal length of the DataPipe - - Note: - The opened file handles will be closed automatically if the default ``DecoderDataPipe`` - is attached. Otherwise, user should be responsible to close file handles explicitly - or let Python's GC close them periodically. Due to how `zipfiles` implements its ``open()`` method, - the data_stream variable below cannot be closed within the scope of this function. 
- - Example: - >>> from torchdata.datapipes.iter import FileLister, FileOpener - >>> datapipe1 = FileLister(".", "*.zip") - >>> datapipe2 = FileOpener(datapipe1, mode="b") - >>> zip_loader_dp = datapipe2.load_from_zip() - >>> for _, stream in zip_loader_dp: - >>> print(stream.read()) - b'0123456789abcdef' - """ - - def __init__(self, datapipe: Iterable[Tuple[str, BufferedIOBase]], length: int = -1) -> None: - super().__init__() - self.datapipe: Iterable[Tuple[str, BufferedIOBase]] = datapipe - self.length: int = length - - def __iter__(self) -> Iterator[Tuple[str, BufferedIOBase]]: - for data in self.datapipe: - validate_pathname_binary_tuple(data) - pathname, data_stream = data - try: - # typing.cast is used here to silence mypy's type checker - zips = zipfile.ZipFile(cast(IO[bytes], data_stream)) - for zipinfo in zips.infolist(): - # major version should always be 3 here. - if sys.version_info[1] >= 6: - if zipinfo.is_dir(): - continue - elif zipinfo.filename.endswith("/"): - continue - extracted_fobj = zips.open(zipinfo) - inner_pathname = os.path.normpath(os.path.join(pathname, zipinfo.filename)) - yield inner_pathname, StreamWrapper(extracted_fobj, data_stream, name=inner_pathname) # type: ignore[misc] - except Exception as e: - warnings.warn(f"Unable to extract files from corrupted zipfile stream {pathname} due to: {e}, abort!") - raise e - finally: - if isinstance(data_stream, StreamWrapper): - data_stream.autoclose() - # We are unable to close 'data_stream' here, because it needs to be available to use later - - def __len__(self) -> int: - if self.length == -1: - raise TypeError(f"{type(self).__name__} instance doesn't have valid length") - return self.length diff --git a/torchdata/datapipes/map/__init__.py b/torchdata/datapipes/map/__init__.py deleted file mode 100644 index a9e9428bd..000000000 --- a/torchdata/datapipes/map/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torch.utils.data import MapDataPipe - -from torch.utils.data.datapipes.map import Batcher, Concater, Mapper, SequenceWrapper, Shuffler, Zipper - -from torchdata.datapipes.iter.util.converter import IterToMapConverterMapDataPipe as IterToMapConverter -from torchdata.datapipes.map.util.cacheholder import InMemoryCacheHolderMapDataPipe as InMemoryCacheHolder -from torchdata.datapipes.map.util.unzipper import UnZipperMapDataPipe as UnZipper - -__all__ = [ - "Batcher", - "Concater", - "InMemoryCacheHolder", - "IterToMapConverter", - "MapDataPipe", - "Mapper", - "SequenceWrapper", - "Shuffler", - "UnZipper", - "Zipper", -] - -# Please keep this list sorted -assert __all__ == sorted(__all__) diff --git a/torchdata/datapipes/map/__init__.pyi.in b/torchdata/datapipes/map/__init__.pyi.in deleted file mode 100644 index e93ffa79a..000000000 --- a/torchdata/datapipes/map/__init__.pyi.in +++ /dev/null @@ -1,27 +0,0 @@ -${init_base} - -######################################################################################################################## -# The part below is generated by parsing through the Python files where MapDataPipes are defined. -# This base template ("__init__.pyi.in") is generated from mypy stubgen with minimal editing for code injection -# The output file will be "__init__.pyi". The generation function is called by "setup.py". 
-# Note that, for mypy, .pyi file takes precedent over .py file, such that we must define the interface for other -# classes/objects here, even though we are not injecting extra code into them at the moment. - -from torchdata.datapipes.iter import IterDataPipe -from torch.utils.data import DataChunk, Dataset -from torch.utils.data.datapipes._typing import _DataPipeMeta - -from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union - -T_co = TypeVar('T_co', covariant=True) -T = TypeVar('T') -UNTRACABLE_DATAFRAME_PIPES: Any - -class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta): - functions: Dict[str, Callable] = ... - def __getattr__(self, attribute_name: Any): ... - @classmethod - def register_function(cls, function_name: Any, function: Any) -> None: ... - @classmethod - def register_datapipe_as_function(cls, function_name: Any, cls_to_register: Any): ... - ${MapDataPipeMethods} diff --git a/torchdata/datapipes/map/load/__init__.py b/torchdata/datapipes/map/load/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/torchdata/datapipes/map/load/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torchdata/datapipes/map/load/transform.py b/torchdata/datapipes/map/load/transform.py deleted file mode 100644 index 2e41cd717..000000000 --- a/torchdata/datapipes/map/load/transform.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torchdata/datapipes/map/transform/__init__.py b/torchdata/datapipes/map/transform/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/torchdata/datapipes/map/transform/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torchdata/datapipes/map/util/__init__.py b/torchdata/datapipes/map/util/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/torchdata/datapipes/map/util/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/torchdata/datapipes/map/util/cacheholder.py b/torchdata/datapipes/map/util/cacheholder.py deleted file mode 100644 index 8d53ffec0..000000000 --- a/torchdata/datapipes/map/util/cacheholder.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any, Dict, TypeVar - -from torchdata.datapipes import functional_datapipe -from torchdata.datapipes.map import MapDataPipe - - -T_co = TypeVar("T_co", covariant=True) - - -@functional_datapipe("in_memory_cache") -class InMemoryCacheHolderMapDataPipe(MapDataPipe[T_co]): - r""" - Stores elements from the source DataPipe in memory (functional name: ``in_memory_cache``). 
Once an item is
-    stored, it will remain unchanged and subsequent retrievals will return the same element. Since items from
-    ``MapDataPipe`` are lazily computed, this can be used to store the results from previous ``MapDataPipe`` and
-    reduce the number of duplicate computations.
-
-    Note:
-        The default ``cache`` is a ``Dict``. If another data structure is more suitable as cache for your use
-        case, you can subclass this DataPipe and swap in a custom cache container.
-
-    Args:
-        source_dp: source DataPipe from which elements are read and stored in memory
-
-    Example:
-        >>> from torchdata.datapipes.map import SequenceWrapper
-        >>> source_dp = SequenceWrapper(range(10))
-        >>> cache_dp = source_dp.in_memory_cache()
-        >>> list(cache_dp)
-        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-    """
-
-    def __init__(self, source_dp: MapDataPipe[T_co]) -> None:
-        self.source_dp: MapDataPipe[T_co] = source_dp
-        self.cache: Dict[Any, T_co] = {}
-
-    def __getitem__(self, index) -> T_co:
-        if index not in self.cache:
-            self.cache[index] = self.source_dp[index]  # type: ignore[index]
-        return self.cache[index]  # type: ignore[index]
-        # We can potentially remove `self.source_dp` to save memory once `len(self.cache) == len(self.source_dp)`
-        # Be careful about how that may interact with graph traversal and other features
-
-    def __len__(self) -> int:
-        return len(self.source_dp)
diff --git a/torchdata/datapipes/map/util/converter.py b/torchdata/datapipes/map/util/converter.py
deleted file mode 100644
index ad319b9d1..000000000
--- a/torchdata/datapipes/map/util/converter.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import List, Optional
-
-from torch.utils.data import IterDataPipe, MapDataPipe
-
-
-# @functional_datapipe("to_iter_datapipe")  # This line must be kept for .pyi signature parser
-class MapToIterConverterIterDataPipe(IterDataPipe):
-    """
-    Convert a ``MapDataPipe`` to an ``IterDataPipe`` (functional name: ``to_iter_datapipe``). It uses ``indices`` to
-    iterate through the ``MapDataPipe``, defaults to ``range(len(mapdatapipe))`` if not given.
-
-    For the opposite converter, use :class:`.IterToMapConverter`.
-
-    Args:
-        datapipe: source MapDataPipe with data
-        indices: optional list of indices that will dictate how the datapipe will be iterated over
-
-    Example:
-        >>> from torchdata.datapipes.map import SequenceWrapper
-        >>> source_dp = SequenceWrapper(range(10))
-        >>> iter_dp = source_dp.to_iter_datapipe()
-        >>> list(iter_dp)
-        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-        >>> source_dp2 = SequenceWrapper({'a': 1, 'b': 2, 'c': 3})
-        >>> iter_dp2 = source_dp2.to_iter_datapipe(indices=['a', 'b', 'c'])
-        >>> list(iter_dp2)
-        [1, 2, 3]
-    """
-
-    # Note that ``indices`` has ``Optional[List]`` instead of ``Optional[Iterable]`` as type because a generator
-    # can be passed in as an iterable, which will complicate the serialization process as we will have
-    # to materialize ``indices`` and store it.
-    def __init__(self, datapipe: MapDataPipe, indices: Optional[List] = None):
-        if not isinstance(datapipe, MapDataPipe):
-            raise TypeError(f"MapToIterConverter can only apply on MapDataPipe, but found {type(datapipe)}")
-        self.datapipe: MapDataPipe = datapipe
-        self.indices = indices if indices else range(len(datapipe))
-
-    def __iter__(self):
-        for idx in self.indices:
-            yield self.datapipe[idx]
-
-    def __len__(self):
-        return len(self.indices)
-
-
-MapDataPipe.register_datapipe_as_function("to_iter_datapipe", MapToIterConverterIterDataPipe)
diff --git a/torchdata/datapipes/map/util/unzipper.py b/torchdata/datapipes/map/util/unzipper.py
deleted file mode 100644
index 99c011ae1..000000000
--- a/torchdata/datapipes/map/util/unzipper.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Optional, Sequence, TypeVar
-
-from torchdata.datapipes import functional_datapipe
-from torchdata.datapipes.map import MapDataPipe
-
-
-T = TypeVar("T")
-
-
-@functional_datapipe("unzip")
-class UnZipperMapDataPipe(MapDataPipe):
-    """
-    Takes in a DataPipe of Sequences, unpacks each Sequence, and returns the elements in separate DataPipes
-    based on their position in the Sequence (functional name: ``unzip``). The number of instances produced
-    equals ``sequence_length`` minus the number of columns to skip.
-
-    Note:
-        Each sequence within the DataPipe should have the same length, specified by
-        the input argument `sequence_length`.
-
-    Args:
-        source_datapipe: Iterable DataPipe with sequences of data
-        sequence_length: Length of the sequence within the source_datapipe. All elements should have the same length.
-        columns_to_skip: optional indices of columns that the DataPipe should skip (each index should be
-            an integer from 0 to sequence_length - 1)
-
-    Example:
-        >>> from torchdata.datapipes.map import SequenceWrapper
-        >>> source_dp = SequenceWrapper([(i, i + 10, i + 20) for i in range(3)])
-        >>> dp1, dp2, dp3 = source_dp.unzip(sequence_length=3)
-        >>> list(dp1)
-        [0, 1, 2]
-        >>> list(dp2)
-        [10, 11, 12]
-        >>> list(dp3)
-        [20, 21, 22]
-    """
-
-    def __new__(
-        cls,
-        source_datapipe: MapDataPipe[Sequence[T]],
-        sequence_length: int,
-        columns_to_skip: Optional[Sequence[int]] = None,
-    ):
-        if sequence_length < 1:
-            raise ValueError(f"Expected `sequence_length` larger than 0, but {sequence_length} is found")
-        if columns_to_skip is None:
-            instance_ids = list(range(sequence_length))
-        else:
-            skips = set(columns_to_skip)
-            instance_ids = [i for i in range(sequence_length) if i not in skips]
-
-        if len(instance_ids) == 0:
-            raise RuntimeError(
-                f"All instances are being filtered out in {cls.__name__}. Please check "
-                "the input `sequence_length` and `columns_to_skip`."
- ) - return [_UnZipperMapDataPipe(source_datapipe, i) for i in instance_ids] - - -class _UnZipperMapDataPipe(MapDataPipe[T]): - def __init__(self, main_datapipe: MapDataPipe[Sequence[T]], instance_id: int): - self.main_datapipe = main_datapipe - self.instance_id = instance_id - - def __getitem__(self, index) -> T: - return self.main_datapipe[index][self.instance_id] - - def __len__(self) -> int: - return len(self.main_datapipe) diff --git a/torchdata/datapipes/utils/__init__.py b/torchdata/datapipes/utils/__init__.py deleted file mode 100644 index 3ce401c27..000000000 --- a/torchdata/datapipes/utils/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torch.utils.data.datapipes.utils.common import StreamWrapper - -from torchdata.datapipes.utils._visualization import to_graph -from torchdata.datapipes.utils.janitor import janitor -from torchdata.datapipes.utils.pin_memory import pin_memory_fn - -__all__ = [ - "StreamWrapper", - "janitor", - "pin_memory_fn", - "to_graph", -] - -# Please keep this list sorted -assert __all__ == sorted(__all__) diff --git a/torchdata/datapipes/utils/_visualization.py b/torchdata/datapipes/utils/_visualization.py deleted file mode 100644 index b08781507..000000000 --- a/torchdata/datapipes/utils/_visualization.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import itertools -from collections import defaultdict - -from typing import Optional, Set, TYPE_CHECKING - -from torch.utils.data.datapipes.iter.combining import _ChildDataPipe, IterDataPipe -from torch.utils.data.graph import traverse_dps - -if TYPE_CHECKING: - import graphviz - - -class Node: - def __init__(self, dp, *, name=None): - self.dp = dp - self.name = name or type(dp).__name__.replace("IterDataPipe", "") - self.childs = set() - self.parents = set() - - def add_child(self, child): - self.childs.add(child) - child.parents.add(self) - - def remove_child(self, child): - self.childs.remove(child) - child.parents.remove(self) - - def add_parent(self, parent): - self.parents.add(parent) - parent.childs.add(self) - - def remove_parent(self, parent): - self.parents.remove(parent) - parent.childs.remove(self) - - def __eq__(self, other): - if not isinstance(other, Node): - return NotImplemented - - return hash(self) == hash(other) - - def __hash__(self): - return hash(self.dp) - - def __str__(self): - return self.name - - def __repr__(self): - return f"{self}-{hash(self)}" - - -def to_nodes(dp, *, debug: bool) -> Set[Node]: - def recurse(dp_graph, child=None): - for _dp_id, (dp_node, dp_parents) in dp_graph.items(): - node = Node(dp_node) - if child is not None: - node.add_child(child) - yield node - yield from recurse(dp_parents, child=node) - - def aggregate(nodes): - groups = defaultdict(list) - for node in nodes: - groups[node].append(node) - - nodes = set() - for node, group in groups.items(): - if len(group) == 1: - nodes.add(node) - continue - - aggregated_node = Node(node.dp) - - for duplicate_node in group: - for child in duplicate_node.childs.copy(): - duplicate_node.remove_child(child) - aggregated_node.add_child(child) - - for parent in duplicate_node.parents.copy(): - 
duplicate_node.remove_parent(parent) - aggregated_node.add_parent(parent) - - nodes.add(aggregated_node) - - if debug: - return nodes - - child_dp_nodes = set( - itertools.chain.from_iterable(node.parents for node in nodes if isinstance(node.dp, _ChildDataPipe)) - ) - - if not child_dp_nodes: - return nodes - - for node in child_dp_nodes: - fixed_parent_node = Node( - type(str(node).lstrip("_"), (IterDataPipe,), dict(dp=node.dp, childs=node.childs))() - ) - nodes.remove(node) - nodes.add(fixed_parent_node) - - for parent in node.parents.copy(): - node.remove_parent(parent) - fixed_parent_node.add_parent(parent) - - for child in node.childs: - nodes.remove(child) - for actual_child in child.childs.copy(): - actual_child.remove_parent(child) - actual_child.add_parent(fixed_parent_node) - - return nodes - - return aggregate(recurse(traverse_dps(dp))) - - -def to_graph(dp, *, debug: bool = False) -> "graphviz.Digraph": - """Visualizes a DataPipe by returning a :class:`graphviz.Digraph`, which is a graph of the data pipeline. - This allows you to visually inspect all the transformation that takes place in your DataPipes. - - .. note:: - - The package :mod:`graphviz` is required to use this function. - - .. note:: - - The most common interfaces for the returned graph object are: - - - :meth:`~graphviz.Digraph.render`: Save the graph to a file. - - :meth:`~graphviz.Digraph.view`: Open the graph in a viewer. - - Args: - dp: DataPipe that you would like to visualize (generally the last one in a chain of DataPipes). - debug (bool): If ``True``, renders internal datapipes that are usually hidden from the user - (such as ``ChildDataPipe`` of `demux` and `fork`). Defaults to ``False``. - - Example: - >>> from torchdata.datapipes.iter import IterableWrapper - >>> from torchdata.datapipes.utils import to_graph - >>> dp = IterableWrapper(range(10)) - >>> dp1, dp2 = dp.demux(num_instances=2, classifier_fn=lambda x: x % 2) - >>> dp1 = dp1.map(lambda x: x + 1) - >>> dp2 = dp2.filter(lambda _: True) - >>> dp3 = dp1.zip(dp2).map(lambda t: t[0] + t[1]) - >>> g = to_graph(dp3) - >>> g.view() # This will open the graph in a viewer - """ - try: - import graphviz - except ModuleNotFoundError: - raise ModuleNotFoundError( - "The package `graphviz` is required to be installed to use this function. " - "Please `pip install graphviz` or `conda install -c conda-forge graphviz`." - ) from None - - # The graph style as well as the color scheme below was copied from https://github.com/szagoruyko/pytorchviz/ - # https://github.com/szagoruyko/pytorchviz/blob/0adcd83af8aa7ab36d6afd139cabbd9df598edb7/torchviz/dot.py#L78-L85 - node_attr = dict( - style="filled", - shape="box", - align="left", - fontsize="10", - ranksep="0.1", - height="0.2", - fontname="monospace", - ) - graph = graphviz.Digraph(node_attr=node_attr, graph_attr=dict(size="12,12")) - - for node in to_nodes(dp, debug=debug): - fillcolor: Optional[str] - if not node.parents: - fillcolor = "lightblue" - elif not node.childs: - fillcolor = "darkolivegreen1" - else: - fillcolor = None - - graph.node(name=repr(node), label=str(node), fillcolor=fillcolor) - - for child in node.childs: - graph.edge(repr(node), repr(child)) - - return graph diff --git a/torchdata/datapipes/utils/common.py b/torchdata/datapipes/utils/common.py deleted file mode 100644 index eea1f785b..000000000 --- a/torchdata/datapipes/utils/common.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from io import IOBase -from typing import Tuple - -from torchdata.datapipes.utils import StreamWrapper - - -def validate_pathname_binary_tuple(data: Tuple[str, IOBase]): - if not isinstance(data, tuple): - raise TypeError(f"pathname binary data should be tuple type, but it is type {type(data)}") - if len(data) != 2: - raise TypeError(f"pathname binary stream tuple length should be 2, but got {len(data)}") - if not isinstance(data[0], str): - raise TypeError(f"pathname within the tuple should have string type pathname, but it is type {type(data[0])}") - if not isinstance(data[1], IOBase) and not isinstance(data[1], StreamWrapper): - raise TypeError( - f"binary stream within the tuple should have IOBase or" - f"its subclasses as type, but it is type {type(data[1])}" - ) diff --git a/torchdata/datapipes/utils/janitor.py b/torchdata/datapipes/utils/janitor.py deleted file mode 100644 index 3dd248408..000000000 --- a/torchdata/datapipes/utils/janitor.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from torchdata.datapipes.utils import StreamWrapper - - -def janitor(obj): - """ - Invokes various `obj` cleanup procedures such as: - - Closing streams - """ - # TODO(632): We can also release caching locks here to allow filtering - StreamWrapper.close_streams(obj) diff --git a/torchdata/datapipes/utils/pin_memory.py b/torchdata/datapipes/utils/pin_memory.py deleted file mode 100644 index 997c24042..000000000 --- a/torchdata/datapipes/utils/pin_memory.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import collections - - -def pin_memory_fn(data, device=None): - r""" - Utility function to move data to pinned memory. If special treatment is needed to move - the input data to pinned memory, please attach a ``pin_memory`` method to the expected - data class. - """ - if hasattr(data, "pin_memory"): # Including torch.Tensor - return data.pin_memory(device) - elif isinstance(data, (str, bytes)): - return data - elif isinstance(data, collections.abc.Mapping): - pinned_data = {k: pin_memory_fn(sample, device) for k, sample in data.items()} - try: - return type(data)(pinned_data) # type: ignore[call-arg] - except TypeError: - # The mapping type may not support `__init__(iterable)`. - return pinned_data - elif isinstance(data, collections.abc.Sequence): - pinned_data = [pin_memory_fn(sample, device) for sample in data] # type: ignore[assignment] - try: - return type(data)(pinned_data) # type: ignore[call-arg] - except TypeError: - # The sequence type may not support `__init__(iterable)` (e.g., `range`). - return pinned_data - else: - return data
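pin_memory_fn above dispatches on a pin_memory method before falling back to recursive container handling, so user types can opt in simply by defining that method. Below is a small usage sketch, assuming a torchdata build that still ships torchdata.datapipes.utils.pin_memory_fn (this very diff removes it); PinnableRecord is a hypothetical user-defined type.

from torchdata.datapipes.utils import pin_memory_fn


class PinnableRecord:
    def __init__(self, payload):
        self.payload = payload
        self.pinned = False

    def pin_memory(self, device=None):
        # A real record would call `.pin_memory()` on its tensor fields here.
        self.pinned = True
        return self


batch = {"record": PinnableRecord([1, 2, 3]), "label": 7, "name": "sample-0"}
pinned = pin_memory_fn(batch)

assert pinned["record"].pinned       # the custom hook was invoked
assert pinned["label"] == 7          # ints pass through unchanged
assert pinned["name"] == "sample-0"  # str is returned as-is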