From 88d805231141820341bbb4684f18b9cd2bf7e000 Mon Sep 17 00:00:00 2001 From: AndieHuang Date: Sat, 2 Apr 2022 13:09:17 +0800 Subject: [PATCH] merge pa table --- python/pyjava/data/datasource.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/python/pyjava/data/datasource.py b/python/pyjava/data/datasource.py index da05502..f94cdff 100644 --- a/python/pyjava/data/datasource.py +++ b/python/pyjava/data/datasource.py @@ -2,6 +2,7 @@ from typing import Any, Generic, List, Callable, Union, Tuple, Iterable import ray +import pyarrow as pa from ray.data.datasource.datasource import WriteResult from ray.types import ObjectRef @@ -16,17 +17,36 @@ def __init__(self, data_refs: List[str], num_tables_per_block: int = 1): self.data_refs = data_refs self.block_refs = [] + def merged_pa_generator(pa_generator): + merged_tables = [] + while True: + try: + for i in range(0, 7): + patable = next(pa_generator) + merged_tables.append(patable) + yield pa.concat_tables(merged_tables) + merged_tables.clear() + except StopIteration as e: + if len(merged_tables) > 0: + yield pa.concat_tables(merged_tables) + print("Reading from pa table iterator is done!") + break + @ray.remote def make_block(_data_ref): block_refs = [] data_iter = RayContext.fetch_data_from_single_data_server_as_arrow(_data_ref) temp_box = [] - for arrow_table in data_iter: + + merged_data_iters = merged_pa_generator(data_iter) + + for arrow_table in merged_data_iters: temp_box.append(arrow_table) if len(temp_box) == num_tables_per_block: for t in temp_box: block_refs.append(ray.put(t)) temp_box.clear() + if len(temp_box) != 0: for t in temp_box: block_refs.append(ray.put(t))