Skip to content

Commit

Permalink
Preserve features in iterable dataset.filter (#7209)
Browse files Browse the repository at this point in the history
* add is_typed property to example iterables to prevent applying decode_examples multiple times

* Update src/datasets/iterable_dataset.py

Co-authored-by: Quentin Lhoest <[email protected]>

---------

Co-authored-by: Quentin Lhoest <[email protected]>
  • Loading branch information
alex-hh and lhoestq authored Oct 9, 2024
1 parent afe875a commit 16a121d
Showing 1 changed file with 60 additions and 8 deletions.
68 changes: 60 additions & 8 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ def __iter__(self) -> Iterator[Tuple[Key, dict]]:
def iter_arrow(self) -> Optional[Callable[[], Iterator[Tuple[Key, pa.Table]]]]:
return None

@property
def is_typed(self) -> bool:
return False

def shuffle_data_sources(self, generator: np.random.Generator) -> "_BaseExamplesIterable":
"""
Either shuffle the shards/sources of the dataset, or propagate the shuffling to the underlying iterable.
Expand Down Expand Up @@ -393,6 +397,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, batch_size: Optional[int]
def iter_arrow(self):
return self._iter_arrow

@property
def is_typed(self):
return self.ex_iterable.is_typed

def _init_state_dict(self) -> dict:
self._state_dict = {
"ex_iterable": self.ex_iterable._init_state_dict(),
Expand Down Expand Up @@ -518,6 +526,10 @@ def iter_arrow(self):
if self.ex_iterable.iter_arrow:
return self._iter_arrow

@property
def is_typed(self):
return self.ex_iterable.is_typed

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
return self._state_dict
Expand Down Expand Up @@ -550,6 +562,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, step: int, offset: int):
self.offset = offset
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
return self.ex_iterable.is_typed

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
return self._state_dict
Expand Down Expand Up @@ -593,6 +609,10 @@ def __init__(
self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
return self.ex_iterables[0].is_typed

def _get_indices_iterator(self):
# this is an infinite iterator to keep track of which iterator we want to pick examples from
ex_iterable_idx = self._state_dict["ex_iterable_idx"] if self._state_dict else 0
Expand Down Expand Up @@ -687,6 +707,10 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
super().__init__()
self.ex_iterables = ex_iterables

@property
def is_typed(self):
return self.ex_iterables[0].is_typed

@property
def iter_arrow(self):
if all(ex_iterable.iter_arrow is not None for ex_iterable in self.ex_iterables):
Expand Down Expand Up @@ -767,6 +791,10 @@ def __init__(self, ex_iterables: List[_BaseExamplesIterable]):
self.ex_iterables = ex_iterables
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
return self.ex_iterables[0].is_typed

def _init_state_dict(self) -> dict:
self._state_dict = {"ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables]}
return self._state_dict
Expand Down Expand Up @@ -826,6 +854,10 @@ def __init__(
self.probabilities = probabilities
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
return self.ex_iterables[0].is_typed

def _get_indices_iterator(self):
rng = deepcopy(self.generator)
num_sources = len(self.ex_iterables)
Expand Down Expand Up @@ -929,6 +961,10 @@ def iter_arrow(self):
if self.formatting and self.formatting.format_type == "arrow":
return self._iter_arrow

@property
def is_typed(self):
return False

def _init_state_dict(self) -> dict:
self._state_dict = {
"ex_iterable": self.ex_iterable._init_state_dict(),
Expand Down Expand Up @@ -1185,6 +1221,10 @@ def iter_arrow(self):
if self.formatting and self.formatting.format_type == "arrow":
return self._iter_arrow

@property
def is_typed(self):
return self.ex_iterable.is_typed

def _init_state_dict(self) -> dict:
self._state_dict = {
"ex_iterable": self.ex_iterable._init_state_dict(),
Expand Down Expand Up @@ -1365,6 +1405,10 @@ def __init__(self, ex_iterable: _BaseExamplesIterable, buffer_size: int, generat
self.generator = generator
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
return self.ex_iterable.is_typed

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
self._original_state_dict = self.state_dict()
Expand Down Expand Up @@ -1435,6 +1479,10 @@ def __init__(
self.split_when_sharding = split_when_sharding
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
return self.ex_iterable.is_typed

def _init_state_dict(self) -> dict:
self._state_dict = {"skipped": False, "ex_iterable": self.ex_iterable._init_state_dict()}
return self._state_dict
Expand Down Expand Up @@ -1498,6 +1546,10 @@ def __init__(
self.split_when_sharding = split_when_sharding
# TODO(QL): implement iter_arrow

@property
def is_typed(self):
return self.ex_iterable.is_typed

def _init_state_dict(self) -> dict:
self._state_dict = {"num_taken": 0, "ex_iterable": self.ex_iterable._init_state_dict()}
return self._state_dict
Expand Down Expand Up @@ -1600,6 +1652,10 @@ def iter_arrow(self):
if self.ex_iterable.iter_arrow is not None:
return self._iter_arrow

@property
def is_typed(self):
return True

def _init_state_dict(self) -> dict:
self._state_dict = self.ex_iterable._init_state_dict()
return self._state_dict
Expand Down Expand Up @@ -1914,7 +1970,7 @@ def _iter_pytorch(self):
return
else:
for key, example in ex_iterable:
if self.features:
if self.features and not ex_iterable.is_typed:
# `IterableDataset` automatically fills missing columns with None.
# This is done with `_apply_feature_types_on_example`.
example = _apply_feature_types_on_example(
Expand Down Expand Up @@ -2010,7 +2066,7 @@ def __iter__(self):
return

for key, example in ex_iterable:
if self.features:
if self.features and not ex_iterable.is_typed:
# `IterableDataset` automatically fills missing columns with None.
# This is done with `_apply_feature_types_on_example`.
example = _apply_feature_types_on_example(
Expand Down Expand Up @@ -2052,7 +2108,7 @@ def iter(self, batch_size: int, drop_last_batch: bool = False):
if drop_last_batch and len(examples) < batch_size: # ignore last batch
return
batch = _examples_to_batch(examples)
if self.features:
if self.features and not ex_iterable.is_typed:
# `IterableDataset` automatically fills missing columns with None.
# This is done with `_apply_feature_types_on_batch`.
batch = _apply_feature_types_on_batch(batch, self.features, token_per_repo_id=self._token_per_repo_id)
Expand Down Expand Up @@ -2405,10 +2461,6 @@ def filter(
if isinstance(input_columns, str):
input_columns = [input_columns]

# TODO(QL): keep the features (right now if we keep it it would call decode_example again on an already decoded example)
info = copy.deepcopy(self._info)
info.features = None

# We need the examples to be decoded for certain feature types like Image or Audio, so we use TypedExamplesIterable here
ex_iterable = FilteredExamplesIterable(
TypedExamplesIterable(self._ex_iterable, self._info.features, token_per_repo_id=self._token_per_repo_id)
Expand All @@ -2424,7 +2476,7 @@ def filter(
)
return IterableDataset(
ex_iterable=ex_iterable,
info=info,
info=self._info,
split=self._split,
formatting=self._formatting,
shuffling=copy.deepcopy(self._shuffling),
Expand Down

0 comments on commit 16a121d

Please sign in to comment.