Skip to content

Commit

Permalink
feat: add register_avro
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo committed Aug 21, 2023
1 parent 217ede8 commit 2cf52bf
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
26 changes: 26 additions & 0 deletions datafusion/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,32 @@ def test_register_json(ctx, tmp_path):
ctx.register_json("json4", gzip_path, file_compression_type="rar")


def test_register_avro(ctx):
path = "testing/data/avro/alltypes_plain.avro"
ctx.register_avro("alltypes_plain", path)
result = ctx.sql(
"SELECT SUM(tinyint_col) as tinyint_sum FROM alltypes_plain"
).collect()
result = pa.Table.from_batches(result).to_pydict()
assert result["tinyint_sum"][0] > 0

alternative_schema = pa.schema(
[
pa.field("id", pa.int64()),
]
)

ctx.register_avro(
"alltypes_plain_schema",
path,
schema=alternative_schema,
infinite=False,
)
result = ctx.sql("SELECT * FROM alltypes_plain_schema").collect()
result = pa.Table.from_batches(result)
assert result.schema == alternative_schema


def test_execute(ctx, tmp_path):
data = [1, 1, 2, 2, 3, 11, 12]

Expand Down
33 changes: 33 additions & 0 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,39 @@ impl PySessionContext {
Ok(())
}

#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (name,
path,
schema=None,
file_extension=".avro",
table_partition_cols=vec![],
infinite=false))]
fn register_avro(
&mut self,
name: &str,
path: PathBuf,
schema: Option<PyArrowType<Schema>>,
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
infinite: bool,
py: Python,
) -> PyResult<()> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;

let mut options = AvroReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.mark_infinite(infinite);
options.file_extension = file_extension;
options.schema = schema.as_ref().map(|x| &x.0);

let result = self.ctx.register_avro(name, path, options);
wait_for_future(py, result).map_err(DataFusionError::from)?;

Ok(())
}

// Registers a PyArrow.Dataset
fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
Expand Down

0 comments on commit 2cf52bf

Please sign in to comment.