diff --git a/daft/table/table.py b/daft/table/table.py
index 3a464a3058..44029777e4 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -111,6 +111,12 @@ def from_arrow(arrow_table: pa.Table) -> Table:
         pyt = _PyTable.from_arrow_record_batches(arrow_table.to_batches(), schema._schema)
         return Table._from_pytable(pyt)
 
+    @staticmethod
+    def from_arrow_record_batches(rbs: list[pa.RecordBatch], arrow_schema: pa.Schema) -> Table:
+        schema = Schema._from_field_name_and_types([(f.name, DataType.from_arrow_type(f.type)) for f in arrow_schema])
+        pyt = _PyTable.from_arrow_record_batches(rbs, schema._schema)
+        return Table._from_pytable(pyt)
+
     @staticmethod
     def from_pandas(pd_df: pd.DataFrame) -> Table:
         if not _PANDAS_AVAILABLE:
diff --git a/daft/table/table_io.py b/daft/table/table_io.py
index fe94dc88a4..10cd50caf2 100644
--- a/daft/table/table_io.py
+++ b/daft/table/table_io.py
@@ -175,6 +175,17 @@ def read_parquet(
     return _cast_table_to_schema(Table.from_arrow(table), read_options=read_options, schema=schema)
 
 
+class PACSVStreamHelper:
+    def __init__(self, stream: pa.CSVStreamReader) -> None:
+        self.stream = stream
+
+    def __next__(self) -> pa.RecordBatch:
+        return self.stream.read_next_batch()
+
+    def __iter__(self) -> PACSVStreamHelper:
+        return self
+
+
 def read_csv(
     file: FileInput,
     schema: Schema,
@@ -203,7 +214,7 @@
         fs = None
 
     with _open_stream(file, fs) as f:
-        table = pacsv.read_csv(
+        pacsv_stream = pacsv.open_csv(
             f,
             parse_options=pacsv.ParseOptions(
                 delimiter=csv_options.delimiter,
@@ -222,11 +233,34 @@
             ),
         )
 
-    # TODO(jay): Can't limit number of rows with current PyArrow filesystem so we have to shave it off after the read
-    if read_options.num_rows is not None:
-        table = table[: read_options.num_rows]
-
-    return _cast_table_to_schema(Table.from_arrow(table), read_options=read_options, schema=schema)
+        if read_options.num_rows is not None:
+            rows_left = read_options.num_rows
+            pa_batches = []
+            pa_schema = None
+            for record_batch in PACSVStreamHelper(pacsv_stream):
+                if pa_schema is None:
+                    pa_schema = record_batch.schema
+                if record_batch.num_rows > rows_left:
+                    record_batch = record_batch.slice(0, rows_left)
+                pa_batches.append(record_batch)
+                rows_left -= record_batch.num_rows
+
+                # Break needs to be here; always need to process at least one record batch
+                if rows_left <= 0:
+                    break
+
+            # If source schema isn't determined, then the file was truly empty; set an empty source schema
+            if pa_schema is None:
+                pa_schema = pa.schema([])
+
+            daft_table = Table.from_arrow_record_batches(pa_batches, pa_schema)
+            assert len(daft_table) <= read_options.num_rows
+
+        else:
+            pa_table = pacsv_stream.read_all()
+            daft_table = Table.from_arrow(pa_table)
+
+    return _cast_table_to_schema(daft_table, read_options=read_options, schema=schema)
 
 
 def write_csv(
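The row-limit branch above streams record batches and stops once the requested number of rows has been collected, slicing the final batch so the result never exceeds the limit; this avoids materializing the whole CSV the way the old pacsv.read_csv path did. Below is a minimal standalone sketch of the same pattern against PyArrow's streaming CSV reader; the function name read_csv_head and its arguments are hypothetical, and Daft's Table wrapping and schema casting are omitted.

import pyarrow as pa
import pyarrow.csv as pacsv


def read_csv_head(path: str, num_rows: int) -> pa.Table:
    # Stream record batches and stop as soon as num_rows rows are collected.
    stream = pacsv.open_csv(path)
    batches: list[pa.RecordBatch] = []
    rows_left = num_rows
    while rows_left > 0:
        try:
            batch = stream.read_next_batch()
        except StopIteration:
            # End of file reached before the limit.
            break
        if batch.num_rows > rows_left:
            # Trim the final batch so the result never exceeds the limit.
            batch = batch.slice(0, rows_left)
        batches.append(batch)
        rows_left -= batch.num_rows
    # stream.schema is still available when no batches were read.
    return pa.Table.from_batches(batches, schema=stream.schema)

Slicing the last batch as it is read, rather than slicing a fully assembled table afterwards, keeps peak memory bounded by roughly one record batch beyond the requested rows, which is the point of switching from pacsv.read_csv to pacsv.open_csv in the diff.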