diff --git a/tests/core/test_persistence.py b/tests/core/test_persistence.py index 990717d..40a16c6 100644 --- a/tests/core/test_persistence.py +++ b/tests/core/test_persistence.py @@ -2,14 +2,20 @@ import os import random import uuid -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, Iterable, List, Tuple +from unittest.mock import patch import pyarrow.parquet as papq import pytest from wicker import schema from wicker.core.config import get_config -from wicker.core.persistance import BasicPersistor +from wicker.core.persistance import ( + BasicPersistor, + ColumnBytesFileWriter, + ParsedExample, + PointerParsedExample, +) from wicker.core.storage import S3PathFactory from wicker.schema.schema import DatasetSchema from wicker.testing.storage import FakeS3DataStorage @@ -17,8 +23,9 @@ DATASET_NAME = "dataset" DATASET_VERSION = "0.0.1" SCHEMA = schema.DatasetSchema( - primary_keys=["bar", "foo"], + primary_keys=["global_index", "bar", "foo"], fields=[ + schema.IntField("global_index"), schema.IntField("foo"), schema.StringField("bar"), schema.BytesField("bytescol"), @@ -28,6 +35,7 @@ ( "train" if i % 2 == 0 else "test", { + "global_index": i, "foo": random.randint(0, 10000), "bar": str(uuid.uuid4()), "bytescol": b"0", @@ -53,7 +61,6 @@ def assert_written_correctness(tmpdir: str) -> None: assert DATASET_NAME in os.listdir(os.path.join(tmpdir, prefix)) assert DATASET_VERSION in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME)) for partition in ["train", "test"]: - print(os.listdir(os.path.join(tmpdir, prefix))) columns_path = os.path.join(tmpdir, prefix, "__COLUMN_CONCATENATED_FILES__") all_read_bytes = b"" for filename in os.listdir(columns_path): @@ -65,7 +72,10 @@ def assert_written_correctness(tmpdir: str) -> None: # Load parquet file and assert ordering of primary_key assert f"{partition}.parquet" in os.listdir(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION)) tbl = papq.read_table(os.path.join(tmpdir, prefix, DATASET_NAME, DATASET_VERSION, f"{partition}.parquet")) - foobar = [(barval.as_py(), fooval.as_py()) for fooval, barval in zip(tbl["foo"], tbl["bar"])] + foobar = [ + (glo_idx.as_py(), barval.as_py(), fooval.as_py()) + for glo_idx, fooval, barval in zip(tbl["global_index"], tbl["foo"], tbl["bar"]) + ] assert foobar == sorted(foobar) @@ -74,7 +84,7 @@ def assert_written_correctness(tmpdir: str) -> None: [({}, DATASET_NAME, DATASET_VERSION, SCHEMA, copy.deepcopy(EXAMPLES_DUPES))], indirect=["mock_basic_persistor"], ) -def test_basic_persistor( +def test_basic_persistor_no_shuffle( mock_basic_persistor: Tuple[BasicPersistor, str], dataset_name: str, dataset_version: str, @@ -87,9 +97,146 @@ def test_basic_persistor( Ensure we read the right file locations, the right amount of bytes, and the ordering is correct. 
""" + # in order to assert that we are not shuffling we are going to sub out the + # persist partition function and get average distance on global index + # if it is == 2 (ie: samples are adjacent in partitions) then shuffling has occured + def mock_persist_wicker_partition( + self, + spark_partition_iter: Iterable[Tuple[str, ParsedExample]], + schema: schema.DatasetSchema, + s3_storage: FakeS3DataStorage, + s3_path_factory: S3PathFactory, + target_max_column_file_numrows: int = 50, + ) -> Iterable[Tuple[str, PointerParsedExample]]: + # set up the global sum and counter for calcing mean + global_sum = 0 + global_counter = 0 + # we still have to do all of the regular logic to test writing + column_bytes_file_writers: Dict[str, ColumnBytesFileWriter] = {} + heavy_pointer_columns = schema.get_pointer_columns() + metadata_columns = schema.get_non_pointer_columns() + previous_value, previous_parition = None, None + + for partition, example in spark_partition_iter: + # if the previous value is unset or the parition has changed + if not previous_value or previous_parition != partition: + previous_value = example["global_index"] + previous_parition = partition + # if we can calculate the distance because we are on same parition + # and the previous value is not None + else: + current_diff = abs(example["global_index"] - previous_value) + previous_value = example["global_index"] + previous_parition = partition + global_sum += current_diff + global_counter += 1 + # Create ColumnBytesFileWriter lazily as required, for each partition + if partition not in column_bytes_file_writers: + column_bytes_file_writers[partition] = ColumnBytesFileWriter( + s3_storage, + s3_path_factory, + target_file_rowgroup_size=target_max_column_file_numrows, + ) + + # Write to ColumnBytesFileWriter and return only metadata + heavy-pointers + parquet_metadata: Dict[str, Any] = {col: example[col] for col in metadata_columns} + for col in heavy_pointer_columns: + loc = column_bytes_file_writers[partition].add(col, example[col]) + parquet_metadata[col] = loc.to_bytes() + yield partition, parquet_metadata + + # Flush all writers when finished + for partition in column_bytes_file_writers: + column_bytes_file_writers[partition].close() + # assert that we are at mean 2 and that we have not shuffled + mean = global_sum / global_counter + assert mean == 2.0 + + with patch("wicker.core.persistance.AbstractDataPersistor.persist_wicker_partition", mock_persist_wicker_partition): + # create the mock basic persistor + mock_basic_persistor_obj, tempdir = mock_basic_persistor + # persist the dataset + mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset) + # assert the dataset is correctly written + assert_written_correctness(tempdir) + + +@pytest.mark.parametrize( + "mock_basic_persistor, dataset_name, dataset_version, dataset_schema, dataset", + [({}, DATASET_NAME, DATASET_VERSION, SCHEMA, copy.deepcopy(EXAMPLES_DUPES))], + indirect=["mock_basic_persistor"], +) +def test_basic_persistor_shuffle( + mock_basic_persistor: Tuple[BasicPersistor, str], + dataset_name: str, + dataset_version: str, + dataset_schema: DatasetSchema, + dataset: List[Tuple[str, Dict[str, Any]]], +): + """Test if the basic persistor saves the correct data and shuffles it into different partitions + + Ensure we read the right file locations, the right amount of bytes, + and the ordering is correct. 
+    """
+    # In order to assert that shuffling happened, we substitute the persist-partition
+    # function and compute the mean global_index distance between consecutive samples;
+    # if it is != 2 (i.e. samples are no longer adjacent within their partition) shuffling occurred.
+    def mock_persist_wicker_partition(
+        self,
+        spark_partition_iter: Iterable[Tuple[str, ParsedExample]],
+        schema: schema.DatasetSchema,
+        s3_storage: FakeS3DataStorage,
+        s3_path_factory: S3PathFactory,
+        target_max_column_file_numrows: int = 50,
+    ) -> Iterable[Tuple[str, PointerParsedExample]]:
+        # set up the global sum and counter for calculating the mean
+        global_sum = 0
+        global_counter = 0
+        # we still have to do all of the regular logic to test writing
+        column_bytes_file_writers: Dict[str, ColumnBytesFileWriter] = {}
+        heavy_pointer_columns = schema.get_pointer_columns()
+        metadata_columns = schema.get_non_pointer_columns()
+        previous_value, previous_partition = None, None
+
+        for partition, example in spark_partition_iter:
+            # if the previous value is unset or the partition has changed, just record the value
+            if previous_value is None or previous_partition != partition:
+                previous_value = example["global_index"]
+                previous_partition = partition
+            # otherwise we are on the same partition and have a previous value,
+            # so we can accumulate the distance
+            else:
+                current_diff = abs(example["global_index"] - previous_value)
+                previous_value = example["global_index"]
+                previous_partition = partition
+                global_sum += current_diff
+                global_counter += 1
+            # Create ColumnBytesFileWriter lazily as required, for each partition
+            if partition not in column_bytes_file_writers:
+                column_bytes_file_writers[partition] = ColumnBytesFileWriter(
+                    s3_storage,
+                    s3_path_factory,
+                    target_file_rowgroup_size=target_max_column_file_numrows,
+                )
+
+            # Write to ColumnBytesFileWriter and return only metadata + heavy-pointers
+            parquet_metadata: Dict[str, Any] = {col: example[col] for col in metadata_columns}
+            for col in heavy_pointer_columns:
+                loc = column_bytes_file_writers[partition].add(col, example[col])
+                parquet_metadata[col] = loc.to_bytes()
+            yield partition, parquet_metadata
+
+        # Flush all writers when finished
+        for partition in column_bytes_file_writers:
+            column_bytes_file_writers[partition].close()
+        # assert that the mean distance is not 2, i.e. the data was shuffled successfully
+        mean = global_sum / global_counter
+        assert mean != 2.0
+
     # create the mock basic persistor
-    mock_basic_persistor_obj, tempdir = mock_basic_persistor
-    # persist the dataset
-    mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset)
-    # assert the dataset is correctly written
-    assert_written_correctness(tempdir)
+    with patch("wicker.core.persistance.AbstractDataPersistor.persist_wicker_partition", mock_persist_wicker_partition):
+        mock_basic_persistor_obj, tempdir = mock_basic_persistor
+        # persist and shuffle the dataset
+        mock_basic_persistor_obj.persist_wicker_dataset(dataset_name, dataset_version, dataset_schema, dataset, True)
+        # assert the dataset is correctly written
+        assert_written_correctness(tempdir)
diff --git a/wicker/core/datasets.py b/wicker/core/datasets.py
index 94d04a6..7bc1a98 100644
--- a/wicker/core/datasets.py
+++ b/wicker/core/datasets.py
@@ -92,9 +92,10 @@ def __init__(
         self._partition = DatasetPartition(dataset_id=self._dataset_id, partition=dataset_partition_name)
         self._dataset_definition = DatasetDefinition(
             self._dataset_id,
-            schema=self.schema(),
+            schema=self.schema,
         )
 
+    @property
     def schema(self) -> DatasetSchema:
         if self._schema is None:
             schema_path = self._s3_path_factory.get_dataset_schema_path(self._dataset_id)
@@ -107,20 +108,24 @@ def schema(self) -> DatasetSchema:
             )
         return self._schema
 
+    @property
     def arrow_table(self) -> pyarrow.Table:
-        path = self._s3_path_factory.get_dataset_partition_path(self._partition, s3_prefix=False)
         if not self._arrow_table:
+            path = self._s3_path_factory.get_dataset_partition_path(self._partition, s3_prefix=False)
             self._arrow_table = papq.read_table(path, filesystem=self._pa_filesystem)
         return self._arrow_table
 
     def __len__(self) -> int:
-        return len(self.arrow_table())
+        return len(self.arrow_table)
 
     def __getitem__(self, idx: int) -> Dict[str, Any]:
-        tbl = self.arrow_table()
-        columns = self._columns_to_load if self._columns_to_load is not None else tbl.column_names
-        row = {col: tbl[col][idx].as_py() for col in columns}
         return dataloading.load_example(
-            self._column_bytes_file_cache.resolve_pointers(row, self.schema()),
-            self.schema(),
+            self._column_bytes_file_cache.resolve_pointers(self._get_row_pq_table(idx), self.schema),
+            self.schema,
         )
+
+    def _get_row_pq_table(self, idx: int) -> Dict[str, Any]:
+        tbl = self.arrow_table
+        columns = self._columns_to_load if self._columns_to_load is not None else tbl.column_names
+        row = {col: tbl[col][idx].as_py() for col in columns}
+        return row
diff --git a/wicker/core/persistance.py b/wicker/core/persistance.py
index fac1bc3..0e70422 100644
--- a/wicker/core/persistance.py
+++ b/wicker/core/persistance.py
@@ -1,4 +1,5 @@
 import abc
+import random
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import pyarrow as pa
@@ -155,6 +156,7 @@ def persist_wicker_dataset(
     dataset: Any,
     s3_storage: S3DataStorage = S3DataStorage(),
     s3_path_factory: S3PathFactory = S3PathFactory(),
+    shuffle: bool = False,
 ) -> Optional[Dict[str, int]]:
     """
     Persist wicker dataset public facing api function, for api consistency.
@@ -170,9 +172,15 @@ def persist_wicker_dataset(
     :type s3_storage: S3DataStorage
     :param s3_path_factory: s3 path abstraction
     :type s3_path_factory: S3PathFactory
+    :param shuffle: whether to shuffle the dataset before partitioning it
+    :type shuffle: bool
     """
     return BasicPersistor(s3_storage, s3_path_factory).persist_wicker_dataset(
-        dataset_name, dataset_version, dataset_schema, dataset
+        dataset_name,
+        dataset_version,
+        dataset_schema,
+        dataset,
+        shuffle,
     )
 
 
@@ -189,7 +197,12 @@ def __init__(
         super().__init__(s3_storage, s3_path_factory)
 
     def persist_wicker_dataset(
-        self, dataset_name: str, dataset_version: str, dataset_schema: schema_module.DatasetSchema, dataset: Any
+        self,
+        dataset_name: str,
+        dataset_version: str,
+        dataset_schema: schema_module.DatasetSchema,
+        dataset: Any,
+        shuffle: bool = False,
     ) -> Optional[Dict[str, int]]:
         """
         Persist a user defined dataset, pushing data to s3 in a basic manner
@@ -202,6 +215,8 @@ def persist_wicker_dataset(
         :type dataset_schema: wicker.schema.schema.DatasetSchema
         :param dataset: Data of the dataset
        :type dataset: User defined
+        :param shuffle: whether to shuffle the dataset before partitioning it
+        :type shuffle: bool
         """
         # what needs to be done within this function
         # 1. Check if the variables are set
@@ -223,9 +238,14 @@ def persist_wicker_dataset(
         dataset_0 = [(row[0], self.parse_row(row[1], dataset_schema)) for row in dataset]
 
         # 4. Sort the dataset if not sorted
-        sorted_dataset_0 = sorted(dataset_0, key=lambda tup: tup[0])
+        dataset_1 = sorted(dataset_0, key=lambda tup: tup[0])
+
+        # 5. If shuffling was requested, shuffle the dataset before partitioning
+        # so that the write order is randomized
+        if shuffle:
+            random.shuffle(dataset_1)
 
-        # 5. Partition the dataset into K partitions
+        # 6. Partition the dataset into K partitions
         partitions = []
 
         def divide_chunks(list_to_divide):
@@ -233,9 +253,9 @@ def divide_chunks(list_to_divide):
             for i in range(0, len(list_to_divide), PARTITION_SIZE):
                 partitions.append(list_to_divide[i : i + PARTITION_SIZE])
 
-        divide_chunks(sorted_dataset_0)
+        divide_chunks(dataset_1)
 
-        # 6. Persist the partitions to S3
+        # 7. Persist the partitions to S3
         for partition in partitions:
             # build a persistence iterator for each parition
             iterator = self.persist_wicker_partition(
@@ -244,11 +264,11 @@ def divide_chunks(list_to_divide):
             # make sure all yields get called
             list(iterator)
 
-        # 7. Create the parition table, need to combine keys in a way we can form table
+        # 8. Create the partition tables; combine keys so that we can form a table
         # split into k dicts where k is partition number and the data is a list of values
         # for each key for all the dicts in the partition
         merged_dicts: Dict[str, Dict[str, List[Any]]] = {}
-        for partition_key, row in sorted_dataset_0:
+        for partition_key, row in dataset_1:
             current_dict: Dict[str, List[Any]] = merged_dicts.get(partition_key, {})
             for col in row.keys():
                 if col in current_dict:
@@ -266,7 +286,7 @@ def divide_chunks(list_to_divide):
             pc.sort_indices(data_table, sort_keys=[(pk, "ascending") for pk in dataset_schema.primary_keys]),
         )
 
-        # 8. Persist the partition table to s3
+        # 9. Persist the partition table to S3
         written_dict = {}
         for partition_key, pa_table in arrow_dict.items():
             self.save_partition_tbl(
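
Usage sketch: the snippet below illustrates the new `shuffle` flag on the module-level `persist_wicker_dataset` helper in wicker/core/persistance.py. The toy schema, examples, and dataset name are hypothetical placeholders modelled on the test fixtures above, not part of this change, and the call goes through the default `S3DataStorage`/`S3PathFactory` unless those arguments are overridden.

    from wicker import schema
    from wicker.core.persistance import persist_wicker_dataset

    # Hypothetical toy dataset: (partition_name, example_dict) tuples, as in the tests above.
    toy_schema = schema.DatasetSchema(
        primary_keys=["foo"],
        fields=[schema.IntField("foo"), schema.BytesField("bytescol")],
    )
    toy_examples = [("train", {"foo": i, "bytescol": b"0"}) for i in range(100)]

    # shuffle=True randomizes the write order before the dataset is chunked into
    # partitions, so that sorted neighbors land in different column-bytes files;
    # the default (shuffle=False) keeps the previous sorted-write behavior.
    persist_wicker_dataset("toy_dataset", "0.0.1", toy_schema, toy_examples, shuffle=True)

Note that after the wicker/core/datasets.py change, `schema` and `arrow_table` are properties, so call sites use `ds.schema` / `ds.arrow_table` rather than `ds.schema()` / `ds.arrow_table()`.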