Allow parallelizing operations in mlcroissant with Apache Beam. (#730)
marcenacp authored Sep 4, 2024
1 parent a23b44c commit e3f524b
Showing 6 changed files with 172 additions and 16 deletions.
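For orientation before the per-file diffs: the user-facing entry point this commit adds is `Dataset.records(...).beam_reader(pipeline)`, which returns a Beam PCollection of records (see the docstring in `datasets.py` below). A minimal sketch of consuming it, reusing the MNIST Croissant config from that docstring; the output path is a placeholder, and `default=str` hedges against records containing non-JSON-serializable values such as bytes:

```python
import json

import apache_beam as beam
import mlcroissant as mlc

dataset = mlc.Dataset(
    jsonld="https://huggingface.co/api/datasets/ylecun/mnist/croissant",
)
with beam.Pipeline() as pipeline:
    _ = (
        # beam_reader (new in this commit) yields a PCollection of records.
        dataset.records("mnist").beam_reader(pipeline)
        | "Serialize" >> beam.Map(lambda record: json.dumps(record, default=str))
        # Placeholder output prefix; any Beam sink would work here.
        | "Write" >> beam.io.WriteToText("/tmp/mnist", file_name_suffix=".jsonl")
    )
```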
25 changes: 15 additions & 10 deletions .github/workflows/ci.yml
@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

@@ -35,6 +35,11 @@ jobs:
- name: PyTest
run: make pytest

# Pylint is not compatible with Apache Beam in Python 3.11:
# https://github.com/pylint-dev/pylint/blob/02616372282fd84862636d58071e6f3c62b53559/pyproject.toml#L38
- name: Install Pylint separately.
run: pip install pylint

- name: PyLint
run: make pylint

@@ -62,7 +67,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

@@ -117,14 +122,14 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install dependencies except mlcroissant itself
run: >
pip install absl-py \
- datasets \
+ apache-beam \
etils[epath] \
GitPython \
jsonpath_rw \
@@ -152,14 +157,14 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.10'

- name: Install dependencies except mlcroissant itself
run: >
pip install absl-py \
- datasets \
+ apache-beam \
etils[epath] \
GitPython \
jsonpath_rw \
@@ -188,7 +193,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.11'

@@ -217,7 +222,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.11'

@@ -240,7 +245,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.11'

@@ -264,7 +269,7 @@ jobs:
- uses: actions/checkout@v4

- name: Set up Python
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: '3.11'

63 changes: 61 additions & 2 deletions python/mlcroissant/mlcroissant/_src/datasets.py
@@ -5,6 +5,7 @@
from collections.abc import Mapping
import dataclasses
import re
import typing
from typing import Any

from absl import logging
@@ -18,6 +19,7 @@
from mlcroissant._src.operation_graph.base_operation import Operation
from mlcroissant._src.operation_graph.base_operation import Operations
from mlcroissant._src.operation_graph.execute import execute_downloads
from mlcroissant._src.operation_graph.execute import execute_operations_in_beam
from mlcroissant._src.operation_graph.execute import execute_operations_in_streaming
from mlcroissant._src.operation_graph.execute import execute_operations_sequentially
from mlcroissant._src.operation_graph.operations import FilterFiles
@@ -27,6 +29,9 @@
from mlcroissant._src.structure_graph.nodes.metadata import Metadata
from mlcroissant._src.structure_graph.nodes.source import FileProperty

if typing.TYPE_CHECKING:
import apache_beam as beam

Filters = Mapping[str, Any]


@@ -152,8 +157,7 @@ def __iter__(self):
# that all operations lie on a single straight line, i.e. have an
# in-degree of 0 or 1. That means that the operation graph is a single line
# (without external joins for example).
- can_stream_dataset = all(d == 1 or d == 2 for _, d in operations.degree())
- if can_stream_dataset:
+ if _is_streamable_dataset(operations):
yield from execute_operations_in_streaming(
record_set=self.record_set,
operations=operations,
@@ -163,6 +167,52 @@ def __iter__(self):
record_set=self.record_set, operations=operations
)

def beam_reader(self, pipeline: beam.Pipeline):
"""Returns an Apache Beam reader to generate the dataset using e.g. Spark.
Example of usage:
```python
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
import mlcroissant as mlc
dataset = mlc.Dataset(
jsonld="https://huggingface.co/api/datasets/ylecun/mnist/croissant",
)
pipeline_options = PipelineOptions()
with beam.Pipeline(options=pipeline_options) as pipeline:
_ = dataset.records("mnist").beam_reader(pipeline)
```
Only streamable datasets can be used with Beam. A streamable dataset is a
dataset that can be generated by a linear sequence of operations - without joins
for example. This is the case for Hugging Face datasets. If there are branches,
we'd need a more complex Beam pipeline.
The sharding is done on the filtered files. This is currently optimized for
Hugging Face datasets, so it raises an error if the dataset is not a Hugging
Face dataset.
Args:
A Beam pipeline.
Returns:
A Beam PCollection with all the records.
Raises:
A ValueError if the dataset is not streamable.
"""
operations = self._filter_interesting_operations(self.filters)
execute_downloads(operations)
if not _is_streamable_dataset(operations):
raise ValueError("only streamable datasets can be used with Beam.")
return execute_operations_in_beam(
pipeline=pipeline,
record_set=self.record_set,
operations=operations,
)

def _filter_interesting_operations(self, filters: Filters | None) -> Operations:
"""Filters connected operations to `ReadFields(self.record_set)`.
@@ -307,3 +357,12 @@ def _validate_filters(filters: Filters):
" to keep all records whose field `data/split` takes the value `train`."
f" Instead, we got: {filters=}"
)


def _is_streamable_dataset(operations: Operations):
"""Whether the operations define a streamable datasets.
A streamable dataset is a dataset that results from executing a linear sequence of
operations without branching (for example, no join).
"""
return all(d == 1 or d == 2 for _, d in operations.degree())
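An aside on the degree check above: `operations.degree()` follows the networkx convention of yielding `(node, degree)` pairs, where the degree counts incoming plus outgoing edges, so a straight line of operations has end nodes of degree 1 and inner nodes of degree 2. A minimal sketch with stand-in networkx graphs (the node names are hypothetical, and it assumes `Operations` exposes networkx-style `degree()` as the call suggests):

```python
import networkx as nx

# A straight line: end nodes have degree 1, inner nodes have degree 2.
line = nx.DiGraph([("Download", "Read"), ("Read", "ReadFields")])
assert all(d == 1 or d == 2 for _, d in line.degree())

# A join: the "Join" node has degree 3, so the dataset is not streamable.
join = nx.DiGraph([("ReadA", "Join"), ("ReadB", "Join"), ("Join", "ReadFields")])
assert not all(d == 1 or d == 2 for _, d in join.degree())
```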
54 changes: 54 additions & 0 deletions python/mlcroissant/mlcroissant/_src/datasets_test.py
@@ -3,6 +3,9 @@
import json
from typing import Any

from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import BeamAssertException
from etils import epath
import pytest

@@ -114,6 +117,46 @@ def load_records_and_test_equality(
assert len(expected_records) == length


def _equal_to_set(expected):
"""Checks whether 2 beam.PCollections are equal as sets."""

def matcher_fn(actual):
expected_set = set([
json.dumps(record_to_python(element)) for element in list(expected)
])
actual_set = set([
json.dumps(record_to_python(element)) for element in list(actual)
])
if expected_set != actual_set:
raise BeamAssertException(
f"sets are not equal: {expected_set - actual_set}"
)

return matcher_fn


def load_records_with_beam_and_test_equality(
version: str,
dataset_name: str,
record_set_name: str,
):
config = (
epath.Path(__file__).parent.parent.parent.parent.parent
/ "datasets"
/ version
/ dataset_name
)
output_file = config.parent / "output" / f"{record_set_name}.jsonl"
with output_file.open("rb") as f:
lines = f.readlines()
expected_records = [json.loads(line) for line in lines]
dataset = datasets.Dataset(config)

with test_pipeline.TestPipeline() as pipeline:
result = dataset.records(record_set_name).beam_reader(pipeline=pipeline)
assert_that(result, _equal_to_set(expected_records))


# IF (NON)-HERMETIC TESTS FAIL, OR A NEW DATASET IS ADDED:
# You can regenerate .pkl files by launching
# ```bash
@@ -150,6 +193,17 @@ def test_hermetic_loading(version, dataset_name, record_set_name, num_records):
load_records_and_test_equality(version, dataset_name, record_set_name, num_records)


@parametrize_version()
@pytest.mark.parametrize(
["dataset_name", "record_set_name"],
[
["simple-parquet/metadata.json", "persons"],
],
)
def test_beam_hermetic_loading(version, dataset_name, record_set_name):
load_records_with_beam_and_test_equality(version, dataset_name, record_set_name)


# Hermetic test cases for croissant >=1.0 only.
@pytest.mark.parametrize(
["dataset_name", "record_set_name", "num_records", "filters"],
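The new hermetic Beam test can be run on its own with pytest's keyword filter (a standard pytest invocation against the test file shown above):

```bash
pytest python/mlcroissant/mlcroissant/_src/datasets_test.py -k test_beam_hermetic_loading
```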
34 changes: 34 additions & 0 deletions python/mlcroissant/mlcroissant/_src/operation_graph/execute.py
@@ -1,6 +1,10 @@
"""Module to execute operations."""

from __future__ import annotations

import collections
import concurrent.futures
import typing
from typing import Any

from absl import logging
@@ -10,10 +14,14 @@
from mlcroissant._src.core.issues import GenerationError
from mlcroissant._src.operation_graph.base_operation import Operation
from mlcroissant._src.operation_graph.base_operation import Operations
from mlcroissant._src.operation_graph.operations import FilterFiles
from mlcroissant._src.operation_graph.operations import ReadFields
from mlcroissant._src.operation_graph.operations.download import Download
from mlcroissant._src.operation_graph.operations.read import Read

if typing.TYPE_CHECKING:
import apache_beam as beam


def execute_downloads(operations: Operations):
"""Executes all the downloads in the graph of operations."""
@@ -124,3 +132,29 @@ def read_all_files():
"An error occured during the streaming generation of the dataset, more"
f" specifically during the operation {operation}"
) from e


def execute_operations_in_beam(
pipeline: beam.Pipeline, record_set: str, operations: Operations
):
"""See beam_reader docstring."""
import apache_beam as beam

list_of_operations = _order_relevant_operations(operations, record_set)
queue_of_operations = collections.deque(list_of_operations)
files = None
operation = None
while queue_of_operations:
operation = queue_of_operations.popleft()
files = operation(files)
if isinstance(operation, FilterFiles):
break
pipeline = pipeline | "Shard by files" >> beam.Create(files)
while queue_of_operations:
operation = queue_of_operations.popleft()
if isinstance(operation, ReadFields):
beam_operation = beam.ParDo(operation)
else:
beam_operation = beam.Map(operation)
pipeline |= beam_operation
return pipeline
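A side note on the Map-vs-ParDo split above: `ReadFields` fans one filtered file out into many records, so it is wrapped in `beam.ParDo`, while the remaining operations are one-to-one and fit `beam.Map`. A self-contained sketch of that distinction (the DoFn and the data are hypothetical stand-ins, not mlcroissant code):

```python
import apache_beam as beam


class FanOutRecords(beam.DoFn):
    """Stand-in for ReadFields: one input file yields many records."""

    def process(self, file):
        for record in file["records"]:
            yield record


with beam.Pipeline() as pipeline:
    _ = (
        pipeline
        | "Shard by files" >> beam.Create([{"records": [1, 2]}, {"records": [3]}])
        | "Read" >> beam.ParDo(FanOutRecords())  # 1 -> N, like ReadFields
        | "Scale" >> beam.Map(lambda record: record * 10)  # 1 -> 1, like the rest
        | "Print" >> beam.Map(print)
    )
```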
@@ -344,7 +344,9 @@ def __getstate__(self):
state = {}
for field in dataclasses.fields(self):
if field.name == "ctx":
- state[field.name] = Context()
+ ctx = Context()
+ ctx.graph = self.ctx.graph
+ state[field.name] = ctx
else:
state[field.name] = getattr(self, field.name)
return state
8 changes: 5 additions & 3 deletions python/mlcroissant/pyproject.toml
@@ -42,16 +42,15 @@ documentation = "https://mlcommons.org/working-groups/data/croissant/"
# Installed through `pip install -e .[dev]`
dev = [
"black==23.11.0",
"datasets",
"flake8-docstrings",
"mlcroissant[audio]",
"mlcroissant[beam]",
"mlcroissant[git]",
"mlcroissant[image]",
"mlcroissant[parquet]",
"mypy",
"pyflakes",
"pygraphviz",
"pylint",
"pytest",
"pytype",
"torchdata",
@@ -61,6 +60,9 @@ audio = [
"librosa",
"soxr==0.4.0b1",
]
beam = [
"apache-beam",
]
git = ["GitPython"]
image = ["Pillow"]
parquet = ["pyarrow"]
@@ -89,7 +91,7 @@ disable_error_code = "attr-defined"
[[tool.mypy.overrides]]
module = [
"absl",
"datasets",
"apache-beam",
"etils.*",
"jsonpath_rw",
"librosa",
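With the new `beam` extra in place, Beam support installs alongside the package (quoting the extra keeps shells like zsh from globbing the brackets):

```bash
pip install 'mlcroissant[beam]'
```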
