[FEAT] Streaming CSV reads #1479

Merged (1 commit) on Oct 9, 2023
6 changes: 6 additions & 0 deletions daft/table/table.py
@@ -111,6 +111,12 @@ def from_arrow(arrow_table: pa.Table) -> Table:
         pyt = _PyTable.from_arrow_record_batches(arrow_table.to_batches(), schema._schema)
         return Table._from_pytable(pyt)
 
+    @staticmethod
+    def from_arrow_record_batches(rbs: list[pa.RecordBatch], arrow_schema: pa.Schema) -> Table:
+        schema = Schema._from_field_name_and_types([(f.name, DataType.from_arrow_type(f.type)) for f in arrow_schema])
+        pyt = _PyTable.from_arrow_record_batches(rbs, schema._schema)
+        return Table._from_pytable(pyt)
+
     @staticmethod
     def from_pandas(pd_df: pd.DataFrame) -> Table:
         if not _PANDAS_AVAILABLE:
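For context, the new Table.from_arrow_record_batches constructor above can be exercised roughly as follows. This is a minimal sketch, not part of the PR: the import path and the toy data are assumptions.

import pyarrow as pa

from daft.table import Table  # assumed import path for the Table class patched above

# Build a small record batch (toy data) and construct a Daft Table from it.
batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3], "name": ["a", "b", "c"]})
tbl = Table.from_arrow_record_batches([batch, batch], batch.schema)
assert len(tbl) == 6  # both batches are concatenated into a single Table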
46 changes: 40 additions & 6 deletions daft/table/table_io.py
@@ -175,6 +175,17 @@ def read_parquet(
     return _cast_table_to_schema(Table.from_arrow(table), read_options=read_options, schema=schema)
 
 
+class PACSVStreamHelper:
+    def __init__(self, stream: pa.CSVStreamReader) -> None:
+        self.stream = stream
+
+    def __next__(self) -> pa.RecordBatch:
+        return self.stream.read_next_batch()
+
+    def __iter__(self) -> PACSVStreamHelper:
+        return self
+
+
 def read_csv(
     file: FileInput,
     schema: Schema,
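Why the helper works: pyarrow's streaming CSV reader exposes read_next_batch(), which raises StopIteration once the stream is exhausted, so a thin wrapper implementing __iter__/__next__ is enough to drive it with a plain for loop. A rough illustration, assuming a hypothetical local file path:

import pyarrow.csv as pacsv

stream = pacsv.open_csv("data.csv")  # hypothetical file; any file-like object works too
for record_batch in PACSVStreamHelper(stream):
    # Each iteration yields one pyarrow.RecordBatch; the loop ends when
    # read_next_batch() raises StopIteration at end of stream.
    print(record_batch.num_rows)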
@@ -203,7 +214,7 @@ def read_csv(
         fs = None
 
     with _open_stream(file, fs) as f:
-        table = pacsv.read_csv(
+        pacsv_stream = pacsv.open_csv(
             f,
             parse_options=pacsv.ParseOptions(
                 delimiter=csv_options.delimiter,
@@ -222,11 +233,34 @@
             ),
         )
 
-    # TODO(jay): Can't limit number of rows with current PyArrow filesystem so we have to shave it off after the read
-    if read_options.num_rows is not None:
-        table = table[: read_options.num_rows]
-
-    return _cast_table_to_schema(Table.from_arrow(table), read_options=read_options, schema=schema)
+        if read_options.num_rows is not None:
+            rows_left = read_options.num_rows
+            pa_batches = []
+            pa_schema = None
+            for record_batch in PACSVStreamHelper(pacsv_stream):
+                if pa_schema is None:
+                    pa_schema = record_batch.schema
+                if record_batch.num_rows > rows_left:
+                    record_batch = record_batch.slice(0, rows_left)
+                pa_batches.append(record_batch)
+                rows_left -= record_batch.num_rows
+
+                # Break needs to be here; always need to process at least one record batch
+                if rows_left <= 0:
+                    break
+
+            # If source schema isn't determined, then the file was truly empty; set an empty source schema
+            if pa_schema is None:
+                pa_schema = pa.schema([])
+
+            daft_table = Table.from_arrow_record_batches(pa_batches, pa_schema)
+            assert len(daft_table) <= read_options.num_rows
+
+        else:
+            pa_table = pacsv_stream.read_all()
+            daft_table = Table.from_arrow(pa_table)
+
+    return _cast_table_to_schema(daft_table, read_options=read_options, schema=schema)
 
 
 def write_csv(
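To see the row-limiting logic above in isolation, here is a self-contained sketch of the same batch-slicing approach using only pyarrow. The function name, file path, and row count are made up for illustration; this is not the code Daft ships.

import pyarrow as pa
import pyarrow.csv as pacsv


def read_csv_limited(path: str, num_rows: int) -> pa.Table:
    # Stream the CSV in record batches instead of materializing the whole file.
    stream = pacsv.open_csv(path)

    rows_left = num_rows
    batches: list[pa.RecordBatch] = []
    schema = None
    while True:
        try:
            batch = stream.read_next_batch()
        except StopIteration:
            break  # file had fewer rows than requested

        if schema is None:
            schema = batch.schema
        if batch.num_rows > rows_left:
            # Trim the final batch so the result never exceeds the requested row count.
            batch = batch.slice(0, rows_left)
        batches.append(batch)
        rows_left -= batch.num_rows

        # Always consume at least one batch (even when num_rows == 0) so the
        # schema can be taken from the stream; only then stop early.
        if rows_left <= 0:
            break

    # A truly empty file never yields a batch, so fall back to an empty schema.
    if schema is None:
        schema = pa.schema([])
    return pa.Table.from_batches(batches, schema=schema)


# Example: read at most 1000 rows from a (hypothetical) large CSV.
# limited = read_csv_limited("large_dataset.csv", num_rows=1000)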