diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 50549920..59ada084 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -25,7 +25,7 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
 
@@ -35,6 +35,11 @@ jobs:
       - name: PyTest
         run: make pytest
 
+      # Pylint is not compatible with Apache Beam in Python 3.11:
+      # https://github.com/pylint-dev/pylint/blob/02616372282fd84862636d58071e6f3c62b53559/pyproject.toml#L38
+      - name: Install Pylint separately.
+        run: pip install pylint
+
       - name: PyLint
         run: make pylint
 
@@ -62,7 +67,7 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
 
@@ -117,14 +122,14 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: '3.10'
 
       - name: Install dependencies except mlcroissant itself
         run: >
           pip install absl-py \
-            datasets \
+            apache-beam \
             etils[epath] \
             GitPython \
             jsonpath_rw \
@@ -152,14 +157,14 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: '3.10'
 
       - name: Install dependencies except mlcroissant itself
         run: >
           pip install absl-py \
-            datasets \
+            apache-beam \
             etils[epath] \
             GitPython \
             jsonpath_rw \
@@ -188,7 +193,7 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: '3.11'
 
@@ -217,7 +222,7 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: '3.11'
 
@@ -240,7 +245,7 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: '3.11'
 
@@ -264,7 +269,7 @@ jobs:
       - uses: actions/checkout@v4
 
       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: '3.11'
 
diff --git a/python/mlcroissant/mlcroissant/_src/datasets.py b/python/mlcroissant/mlcroissant/_src/datasets.py
index 2fffcbfd..67516e27 100644
--- a/python/mlcroissant/mlcroissant/_src/datasets.py
+++ b/python/mlcroissant/mlcroissant/_src/datasets.py
@@ -5,6 +5,7 @@
 from collections.abc import Mapping
 import dataclasses
 import re
+import typing
 from typing import Any
 
 from absl import logging
@@ -18,6 +19,7 @@
 from mlcroissant._src.operation_graph.base_operation import Operation
 from mlcroissant._src.operation_graph.base_operation import Operations
 from mlcroissant._src.operation_graph.execute import execute_downloads
+from mlcroissant._src.operation_graph.execute import execute_operations_in_beam
 from mlcroissant._src.operation_graph.execute import execute_operations_in_streaming
 from mlcroissant._src.operation_graph.execute import execute_operations_sequentially
 from mlcroissant._src.operation_graph.operations import FilterFiles
@@ -27,6 +29,9 @@
 from mlcroissant._src.structure_graph.nodes.metadata import Metadata
 from mlcroissant._src.structure_graph.nodes.source import FileProperty
 
+if typing.TYPE_CHECKING:
+    import apache_beam as beam
+
 Filters = Mapping[str, Any]
 
 
@@ -152,8 +157,7 @@ def __iter__(self):
         # that all operations lie on a single straight line, i.e. have an
         # in-degree of 0 or 1. That means that the operation graph is a single line
         # (without external joins for example).
-        can_stream_dataset = all(d == 1 or d == 2 for _, d in operations.degree())
-        if can_stream_dataset:
+        if _is_streamable_dataset(operations):
             yield from execute_operations_in_streaming(
                 record_set=self.record_set,
                 operations=operations,
@@ -163,6 +167,52 @@ def __iter__(self):
                 record_set=self.record_set, operations=operations
             )
 
+    def beam_reader(self, pipeline: beam.Pipeline):
+        """Returns an Apache Beam reader to generate the dataset using e.g. Spark.
+
+        Example of usage:
+
+        ```python
+        import apache_beam as beam
+        from apache_beam.options.pipeline_options import PipelineOptions
+        import mlcroissant as mlc
+
+        dataset = mlc.Dataset(
+            jsonld="https://huggingface.co/api/datasets/ylecun/mnist/croissant",
+        )
+        pipeline_options = PipelineOptions()
+        with beam.Pipeline(options=pipeline_options) as pipeline:
+            _ = dataset.records("mnist").beam_reader(pipeline)
+        ```
+
+        Only streamable datasets can be used with Beam. A streamable dataset is a
+        dataset that can be generated by a linear sequence of operations, without
+        joins for example. This is the case for Hugging Face datasets. If there are
+        branches, we would need a more complex Beam pipeline.
+
+        The sharding is done on the filtered files. This is currently optimized for
+        Hugging Face datasets, so it raises an error if the dataset is not a Hugging
+        Face dataset.
+
+        Args:
+            pipeline: A Beam pipeline.
+
+        Returns:
+            A Beam PCollection with all the records.
+
+        Raises:
+            ValueError: If the dataset is not streamable.
+        """
+        operations = self._filter_interesting_operations(self.filters)
+        execute_downloads(operations)
+        if not _is_streamable_dataset(operations):
+            raise ValueError("only streamable datasets can be used with Beam.")
+        return execute_operations_in_beam(
+            pipeline=pipeline,
+            record_set=self.record_set,
+            operations=operations,
+        )
+
     def _filter_interesting_operations(self, filters: Filters | None) -> Operations:
         """Filters connected operations to `ReadFields(self.record_set)`.
@@ -307,3 +357,12 @@ def _validate_filters(filters: Filters):
             " to keep all records whose field `data/split` takes the value `train`."
             f" Instead, we got: {filters=}"
         )
+
+
+def _is_streamable_dataset(operations: Operations) -> bool:
+    """Whether the operations define a streamable dataset.
+
+    A streamable dataset is a dataset that results from executing a linear sequence of
+    operations without branching (for example, no join).
+    """
+    return all(d == 1 or d == 2 for _, d in operations.degree())
diff --git a/python/mlcroissant/mlcroissant/_src/datasets_test.py b/python/mlcroissant/mlcroissant/_src/datasets_test.py
index 71349d87..d76e58fb 100644
--- a/python/mlcroissant/mlcroissant/_src/datasets_test.py
+++ b/python/mlcroissant/mlcroissant/_src/datasets_test.py
@@ -3,6 +3,9 @@
 import json
 from typing import Any
 
+from apache_beam.testing import test_pipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import BeamAssertException
 from etils import epath
 import pytest
 
@@ -114,6 +117,46 @@ def load_records_and_test_equality(
     assert len(expected_records) == length
 
 
+def _equal_to_set(expected):
+    """Checks whether 2 beam.PCollections are equal as sets."""
+
+    def matcher_fn(actual):
+        expected_set = set([
+            json.dumps(record_to_python(element)) for element in list(expected)
+        ])
+        actual_set = set([
+            json.dumps(record_to_python(element)) for element in list(actual)
+        ])
+        if expected_set != actual_set:
+            raise BeamAssertException(
+                f"sets are not equal: {expected_set - actual_set}"
+            )
+
+    return matcher_fn
+
+
+def load_records_with_beam_and_test_equality(
+    version: str,
+    dataset_name: str,
+    record_set_name: str,
+):
+    config = (
+        epath.Path(__file__).parent.parent.parent.parent.parent
+        / "datasets"
+        / version
+        / dataset_name
+    )
+    output_file = config.parent / "output" / f"{record_set_name}.jsonl"
+    with output_file.open("rb") as f:
+        lines = f.readlines()
+    expected_records = [json.loads(line) for line in lines]
+    dataset = datasets.Dataset(config)
+
+    with test_pipeline.TestPipeline() as pipeline:
+        result = dataset.records(record_set_name).beam_reader(pipeline=pipeline)
+        assert_that(result, _equal_to_set(expected_records))
+
+
 # IF (NON)-HERMETIC TESTS FAIL, OR A NEW DATASET IS ADDED:
 # You can regenerate .pkl files by launching
 # ```bash
 #
@@ -150,6 +193,17 @@ def test_hermetic_loading(version, dataset_name, record_set_name, num_records):
     load_records_and_test_equality(version, dataset_name, record_set_name, num_records)
 
 
+@parametrize_version()
+@pytest.mark.parametrize(
+    ["dataset_name", "record_set_name"],
+    [
+        ["simple-parquet/metadata.json", "persons"],
+    ],
+)
+def test_beam_hermetic_loading(version, dataset_name, record_set_name):
+    load_records_with_beam_and_test_equality(version, dataset_name, record_set_name)
+
+
 # Hermetic test cases for croissant >=1.0 only.
 @pytest.mark.parametrize(
     ["dataset_name", "record_set_name", "num_records", "filters"],
diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py b/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
index 9d180f93..25bed5e1 100644
--- a/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
+++ b/python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
@@ -1,6 +1,10 @@
 """Module to execute operations."""
 
+from __future__ import annotations
+
+import collections
 import concurrent.futures
+import typing
 from typing import Any
 
 from absl import logging
@@ -10,10 +14,14 @@
 from mlcroissant._src.core.issues import GenerationError
 from mlcroissant._src.operation_graph.base_operation import Operation
 from mlcroissant._src.operation_graph.base_operation import Operations
+from mlcroissant._src.operation_graph.operations import FilterFiles
 from mlcroissant._src.operation_graph.operations import ReadFields
 from mlcroissant._src.operation_graph.operations.download import Download
 from mlcroissant._src.operation_graph.operations.read import Read
 
+if typing.TYPE_CHECKING:
+    import apache_beam as beam
+
 
 def execute_downloads(operations: Operations):
     """Executes all the downloads in the graph of operations."""
@@ -124,3 +132,29 @@ def read_all_files():
                 "An error occured during the streaming generation of the dataset, more"
                 f" specifically during the operation {operation}"
             ) from e
+
+
+def execute_operations_in_beam(
+    pipeline: beam.Pipeline, record_set: str, operations: Operations
+):
+    """See beam_reader docstring."""
+    import apache_beam as beam
+
+    list_of_operations = _order_relevant_operations(operations, record_set)
+    queue_of_operations = collections.deque(list_of_operations)
+    files = None
+    operation = None
+    while queue_of_operations:
+        operation = queue_of_operations.popleft()
+        files = operation(files)
+        if isinstance(operation, FilterFiles):
+            break
+    pipeline = pipeline | "Shard by files" >> beam.Create(files)
+    while queue_of_operations:
+        operation = queue_of_operations.popleft()
+        if isinstance(operation, ReadFields):
+            beam_operation = beam.ParDo(operation)
+        else:
+            beam_operation = beam.Map(operation)
+        pipeline |= beam_operation
+    return pipeline
diff --git a/python/mlcroissant/mlcroissant/_src/structure_graph/base_node.py b/python/mlcroissant/mlcroissant/_src/structure_graph/base_node.py
index 0bb5a23a..e57f37d7 100644
--- a/python/mlcroissant/mlcroissant/_src/structure_graph/base_node.py
+++ b/python/mlcroissant/mlcroissant/_src/structure_graph/base_node.py
@@ -344,7 +344,9 @@ def __getstate__(self):
         state = {}
         for field in dataclasses.fields(self):
             if field.name == "ctx":
-                state[field.name] = Context()
+                ctx = Context()
+                ctx.graph = self.ctx.graph
+                state[field.name] = ctx
             else:
                 state[field.name] = getattr(self, field.name)
         return state
diff --git a/python/mlcroissant/pyproject.toml b/python/mlcroissant/pyproject.toml
index ef919b38..fd5f8b72 100644
--- a/python/mlcroissant/pyproject.toml
+++ b/python/mlcroissant/pyproject.toml
@@ -42,16 +42,15 @@ documentation = "https://mlcommons.org/working-groups/data/croissant/"
 # Installed through `pip install -e .[dev]`
 dev = [
     "black==23.11.0",
-    "datasets",
     "flake8-docstrings",
     "mlcroissant[audio]",
+    "mlcroissant[beam]",
     "mlcroissant[git]",
     "mlcroissant[image]",
     "mlcroissant[parquet]",
     "mypy",
     "pyflakes",
     "pygraphviz",
-    "pylint",
     "pytest",
     "pytype",
     "torchdata",
@@ -61,6 +60,9 @@ audio = [
     "librosa",
     "soxr==0.4.0b1",
 ]
+beam = [
+    "apache-beam",
+]
 git = ["GitPython"]
 image = ["Pillow"]
 parquet = ["pyarrow"]
 
@@ -89,7 +91,7 @@ disable_error_code = "attr-defined"
 [[tool.mypy.overrides]]
 module = [
     "absl",
-    "datasets",
+    "apache_beam",
     "etils.*",
     "jsonpath_rw",
     "librosa",
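
Usage sketch (reviewer addition, not part of the diff above): this is one way the new `beam_reader` introduced in this PR could be exercised end to end. The MNIST Croissant URL is taken from the `beam_reader` docstring; the output path, the `str()` conversion of records, and relying on the default local DirectRunner are illustrative assumptions rather than part of the change.

```python
# Illustrative only: drives the beam_reader API added in this PR with the
# default DirectRunner and writes the records to text files.
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

import mlcroissant as mlc

# URL reused from the beam_reader docstring; the output prefix is a placeholder.
dataset = mlc.Dataset(
    jsonld="https://huggingface.co/api/datasets/ylecun/mnist/croissant",
)

# With no --runner flag, Beam falls back to the local DirectRunner.
options = PipelineOptions()
with beam.Pipeline(options=options) as pipeline:
    records = dataset.records("mnist").beam_reader(pipeline)
    _ = (
        records
        # Records may contain values that are not JSON-serializable (e.g. bytes),
        # so they are stringified here purely for demonstration.
        | "Stringify records" >> beam.Map(str)
        | "Write records" >> beam.io.WriteToText("/tmp/mnist_records")
    )
```

Passing a different runner (for example `--runner=SparkRunner`) through `PipelineOptions` is how the same pipeline would be distributed, which is the scenario the docstring's "using e.g. Spark" refers to.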