Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve speed of AddId module #36

Merged
merged 4 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions nemo_curator/modules/add_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import dask.dataframe as dd
import numpy as np
from dask import delayed

from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.module_utils import count_digits


class AddId:
def __init__(self, id_field, id_prefix="doc_id", start_index=0) -> None:
def __init__(
self, id_field, id_prefix: str = "doc_id", start_index: Optional[int] = None
) -> None:
self.id_field = id_field
self.id_prefix = id_prefix
self.start_index = start_index

def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
    """Return a copy of ``dataset`` with an id column added.

    Dispatches to the ordered scheme when a start index was supplied,
    otherwise to the faster unordered scheme.
    """
    if self.start_index is not None:
        return self._add_id_ordered(dataset)
    return self._add_id_fast(dataset)

def _add_id_fast(self, dataset: DocumentDataset) -> DocumentDataset:
    """Add ids without computing partition lengths up front.

    Each id embeds the document's position inside its partition plus the
    partition number, so uniqueness needs no global ordering pass over
    the dataset.
    """
    output_meta = dataset.df.dtypes.to_dict()
    output_meta[self.id_field] = "string"

    # Zero-pad the partition-number suffix to a fixed width so ids from
    # different partitions cannot collide.
    partition_digits = count_digits(dataset.df.npartitions)

    df_with_ids = dataset.df.map_partitions(
        self._add_id_fast_partition,
        partition_digits,
        meta=output_meta,
    )

    return DocumentDataset(df_with_ids)

def _add_id_fast_partition(self, partition, global_padding, partition_info=None):
    """Assign ids to a single partition.

    Each id has the form ``<prefix>-<local index><partition number>``: the
    local index is zero-padded to the digit count of the partition length and
    the partition number is zero-padded to ``global_padding`` digits, which
    keeps ids unique across partitions.

    Args:
        partition: A pandas DataFrame partition.
        global_padding: Digit count used to zero-pad the partition number.
        partition_info: Supplied by Dask; ``partition_info["number"]`` is
            this partition's index within the dataset.
    """
    num_rows = len(partition)
    # Guard: count_digits raises on 0 (log10 of zero), so skip the padding
    # math entirely for empty partitions.
    local_padding = count_digits(num_rows) if num_rows else 0
    global_id = partition_info["number"]

    id_column = [
        f"{self.id_prefix}-{local_id:0{local_padding}d}{global_id:0{global_padding}d}"
        for local_id in range(num_rows)
    ]
    partition[self.id_field] = id_column
    # Cast to the pandas "string" dtype so the result matches the meta
    # declared in _add_id_fast (a plain list assignment yields object dtype),
    # and stays consistent with _add_id_ordered_partition.
    partition[self.id_field] = partition[self.id_field].astype("string")

    return partition

def _add_id_ordered(self, dataset: DocumentDataset) -> DocumentDataset:
original_meta = dataset.df.dtypes.to_dict()
original_meta[self.id_field] = "object"
original_meta[self.id_field] = "string"
delayed_dataset = dataset.df.to_delayed()

parition_lengths = [0]
Expand All @@ -38,7 +74,7 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
delayed_id_dataset = []
for i, partition in enumerate(delayed_dataset):
delayed_id_dataset.append(
delayed(self._add_id_to_partition)(partition, lower_id_bounds[i])
delayed(self._add_id_ordered_partition)(partition, lower_id_bounds[i])
)

id_dataset = DocumentDataset(
Expand All @@ -47,11 +83,12 @@ def __call__(self, dataset: DocumentDataset) -> DocumentDataset:

return id_dataset

def _add_id_to_partition(self, partition, partition_start_id):
def _add_id_ordered_partition(self, partition, partition_start_id):
id_column = [
f"{self.id_prefix}-{int(i + self.start_index):010d}"
for i in range(partition_start_id, len(partition) + partition_start_id)
]
partition[self.id_field] = id_column
partition[self.id_field] = partition[self.id_field].astype("string")

return partition
6 changes: 4 additions & 2 deletions nemo_curator/scripts/add_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def attach_args(
parser.add_argument(
"--starting-index",
type=int,
default=0,
help="Starting index from which to start indexing the documents",
default=None,
help="If supplied, determines the starting index from which to start "
"indexing the documents. By default, it is unspecified, and uses an id"
" scheme that is fast to calculate and is not guaranteed to be ordered.",
)
parser.add_argument(
"--output-data-dir",
Expand Down
5 changes: 5 additions & 0 deletions nemo_curator/utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math


def is_batched(function):
    """Return the function's ``batched`` attribute, or False if it has none."""
    return getattr(function, "batched", False)


def count_digits(num):
    """Return the number of decimal digits in *num*.

    Implemented via string conversion rather than ``math.log10`` so that
    0 is handled (log10(0) raises ValueError — e.g. for an empty dataset
    partition), negatives count their magnitude's digits instead of
    raising, and there is no float-precision risk near powers of ten.
    """
    return len(str(abs(int(num))))
50 changes: 44 additions & 6 deletions tests/test_add_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pandas as pd
import pytest

import nemo_curator
import nemo_curator as nc
from nemo_curator.datasets import DocumentDataset


Expand All @@ -41,10 +41,10 @@ def two_partition_dataset():
)


class TestPrepareTaskData:
class TestAddId:
def test_basic_id(self, single_partition_dataset):
id_field = "id"
add_id = nemo_curator.AddId(id_field)
add_id = nc.AddId(id_field, start_index=0)
id_dataset = add_id(single_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
Expand All @@ -63,7 +63,7 @@ def test_basic_id(self, single_partition_dataset):

def test_two_partitions(self, two_partition_dataset):
id_field = "id"
add_id = nemo_curator.AddId(id_field)
add_id = nc.AddId(id_field, start_index=0)
id_dataset = add_id(two_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
Expand All @@ -83,7 +83,7 @@ def test_two_partitions(self, two_partition_dataset):
def test_id_prefix(self, two_partition_dataset):
id_field = "id"
id_prefix = "my_id"
add_id = nemo_curator.AddId(id_field, id_prefix=id_prefix)
add_id = nc.AddId(id_field, id_prefix=id_prefix, start_index=0)
id_dataset = add_id(two_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
Expand All @@ -103,7 +103,7 @@ def test_id_prefix(self, two_partition_dataset):
def test_start_index(self, two_partition_dataset):
id_field = "id"
start_index = 13
add_id = nemo_curator.AddId(id_field, start_index=start_index)
add_id = nc.AddId(id_field, start_index=start_index)
id_dataset = add_id(two_partition_dataset)
actual_ids = id_dataset.df[id_field].compute()
expected_ids = pd.Series(
Expand All @@ -119,3 +119,41 @@ def test_start_index(self, two_partition_dataset):
assert all(
expected_ids == actual_ids
), f"Expected: {expected_ids}, got: {actual_ids}"

def test_fast_id_single_partition(self, single_partition_dataset):
    """Fast (unordered) ids on one partition: suffix is always partition 0."""
    id_field = "id"
    id_dataset = nc.AddId(id_field)(single_partition_dataset)
    actual_ids = id_dataset.df[id_field].compute()
    # Five documents -> local index padded to one digit, partition suffix "0".
    expected_ids = pd.Series([f"doc_id-{i}0" for i in range(5)])

    assert all(
        expected_ids == actual_ids
    ), f"Expected: {expected_ids}, got: {actual_ids}"

def test_fast_id_two_partitions(self, two_partition_dataset):
    """Fast (unordered) ids across two partitions stay unique."""
    id_field = "id"
    id_dataset = nc.AddId(id_field)(two_partition_dataset)
    actual_ids = id_dataset.df[id_field].compute()
    # Partition 0 holds three documents, partition 1 holds two; each id is
    # <local index><partition number>.
    expected_ids = pd.Series(
        ["doc_id-00", "doc_id-10", "doc_id-20", "doc_id-01", "doc_id-11"]
    )

    assert all(
        expected_ids == actual_ids
    ), f"Expected: {expected_ids}, got: {actual_ids}"