Skip to content

Commit

Permalink
fix: avro_to_arrow: Handle avro nested nullable struct (union)
Browse files Browse the repository at this point in the history
Corrects handling of a nullable struct union.

Signed-off-by: 🐼 Samrose Ahmed 🐼 <[email protected]>
  • Loading branch information
Samrose-Ahmed committed Oct 2, 2023
1 parent f959127 commit 268598d
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 18 deletions.
98 changes: 88 additions & 10 deletions datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use crate::arrow::error::ArrowError;
use crate::arrow::record_batch::RecordBatch;
use crate::arrow::util::bit_util;
use crate::error::{DataFusionError, Result};
use apache_avro::schema::RecordSchema;
use apache_avro::schema::{RecordSchema, UnionSchema};
use apache_avro::{
schema::{Schema as AvroSchema, SchemaKind},
types::Value,
Expand Down Expand Up @@ -97,15 +97,15 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
schema_lookup: &'b mut BTreeMap<String, usize>,
) -> Result<&'b BTreeMap<String, usize>> {
match schema {
AvroSchema::Record(RecordSchema {
name,
fields,
lookup,
..
}) => {
AvroSchema::Union(union_schema) => {
if union_schema.is_nullable() {
let rec_schema = &union_schema.variants()[1];
Self::child_schema_lookup(rec_schema, schema_lookup)?;
}
}
AvroSchema::Record(RecordSchema { fields, lookup, .. }) => {
lookup.iter().for_each(|(field_name, pos)| {
schema_lookup
.insert(format!("{}.{}", name.fullname(None), field_name), *pos);
schema_lookup.insert(field_name.clone(), *pos);
});

for field in fields {
Expand Down Expand Up @@ -562,8 +562,9 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
});
values
.iter()
.map(|v| match v {
.map(|v| match maybe_resolve_union(v) {
Value::Record(record) => record,
Value::Null => &null_struct_array,
other => panic!("expected Record, got {other:?}"),
})
.collect::<Vec<&Vec<(String, Value)>>>()
Expand Down Expand Up @@ -775,14 +776,20 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> {
let len = rows.len();
let num_bytes = bit_util::ceil(len, 8);
let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes);
let empty_vec = vec![];
let struct_rows = rows
.iter()
.enumerate()
.map(|(i, row)| (i, self.field_lookup(field.name(), row)))
.map(|(i, v)| {
let v = v.map(maybe_resolve_union);
if let Some(Value::Record(value)) = v {
bit_util::set_bit(&mut null_buffer, i);
value
} else if v.is_none() {
&empty_vec
} else if let Some(Value::Null) = v {
&empty_vec
} else {
panic!("expected struct got {v:?}");
}
Expand Down Expand Up @@ -1018,7 +1025,10 @@ mod test {
use crate::arrow::array::Array;
use crate::arrow::datatypes::{Field, TimeUnit};
use crate::datasource::avro_to_arrow::{Reader, ReaderBuilder};
use apache_avro::types::Value as AvroValue;
use arrow::datatypes::DataType;
use arrow_array::cast::AsArray;
use datafusion_common::assert_batches_eq;
use datafusion_common::cast::{
as_int32_array, as_int64_array, as_list_array, as_timestamp_microsecond_array,
};
Expand Down Expand Up @@ -1101,6 +1111,74 @@ mod test {
assert_eq!(batch.num_rows(), 3);
}

#[test]
fn test_avro_nullable_struct() {
let schema = apache_avro::Schema::parse_str(
r#"
{
"type": "record",
"name": "r1",
"fields": [
{
"name": "col1",
"type": [
"null",
{
"type": "record",
"name": "r2",
"fields": [
{
"name": "col2",
"type": "string"
}
]
}
],
"default": null
}
]
}"#,
)
.unwrap();
let mut r1 = apache_avro::types::Record::new(&schema).unwrap();
r1.put("col1", AvroValue::Union(0, Box::new(AvroValue::Null)));
let mut r2 = apache_avro::types::Record::new(&schema).unwrap();
r2.put(
"col1",
AvroValue::Union(
1,
Box::new(AvroValue::Record(vec![(
"col2".to_string(),
AvroValue::String("hello".to_string()),
)])),
),
);

let mut w = apache_avro::Writer::new(&schema, vec![]);
w.append(r1).unwrap();
w.append(r2).unwrap();
let bytes = w.into_inner().unwrap();

let mut reader = ReaderBuilder::new()
.read_schema()
.with_batch_size(2)
.build(std::io::Cursor::new(bytes))
.unwrap();
let batch = reader.next().unwrap().unwrap();
assert_eq!(batch.num_rows(), 2);
assert_eq!(batch.num_columns(), 1);

let expected = [
"+---------------+",
"| col1 |",
"+---------------+",
"| |",
"| {col2: hello} |",
"+---------------+",
];
assert_batches_eq!(expected, &[batch]);
}

#[test]
fn test_avro_iterator() {
let reader = build_reader("alltypes_plain.avro", 5);
Expand Down
49 changes: 46 additions & 3 deletions datafusion/core/src/datasource/avro_to_arrow/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ fn schema_to_field_with_props(
DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense)
}
}
AvroSchema::Record(RecordSchema { name, fields, .. }) => {
AvroSchema::Record(RecordSchema { fields, .. }) => {
let fields: Result<_> = fields
.iter()
.map(|field| {
Expand All @@ -129,8 +129,8 @@ fn schema_to_field_with_props(
}*/
schema_to_field_with_props(
&field.schema,
Some(&format!("{}.{}", name.fullname(None), field.name)),
false,
Some(&field.name),
nullable,
Some(props),
)
})
Expand Down Expand Up @@ -442,6 +442,49 @@ mod test {
assert_eq!(arrow_schema.unwrap(), expected);
}

#[test]
fn test_nested_schema() {
let avro_schema = apache_avro::Schema::parse_str(
r#"
{
"type": "record",
"name": "r1",
"fields": [
{
"name": "col1",
"type": [
"null",
{
"type": "record",
"name": "r2",
"fields": [
{
"name": "col2",
"type": "string"
}
]
}
],
"default": null
}
]
}"#,
)
.unwrap();
// should not use Avro Record names.
let expected_arrow_schema = Schema::new(vec![Field::new(
"col1",
arrow::datatypes::DataType::Struct(
vec![Field::new("col2", Utf8, true)].into(),
),
true,
)]);
assert_eq!(
to_arrow_schema(&avro_schema).unwrap(),
expected_arrow_schema
);
}

#[test]
fn test_non_record_schema() {
let arrow_schema = to_arrow_schema(&AvroSchema::String);
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/avro.slt
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_multi_files
1 1

# test avro nested records
query ??
SELECT f1, f2 FROM nested_records
query ????
SELECT f1, f2, f3, f4 FROM nested_records
----
{ns2.record2.f1_1: aaa, ns2.record2.f1_2: 10, ns2.record2.f1_3: {ns3.record3.f1_3_1: 3.14}} [{ns4.record4.f2_1: true, ns4.record4.f2_2: 1.2}, {ns4.record4.f2_1: true, ns4.record4.f2_2: 2.2}]
{ns2.record2.f1_1: bbb, ns2.record2.f1_2: 20, ns2.record2.f1_3: {ns3.record3.f1_3_1: 3.14}} [{ns4.record4.f2_1: false, ns4.record4.f2_2: 10.2}]
{f1_1: aaa, f1_2: 10, f1_3: {f1_3_1: 3.14}} [{f2_1: true, f2_2: 1.2}, {f2_1: true, f2_2: 2.2}] {f3_1: xyz} [{f4_1: 200}, {f4_1: }]
{f1_1: bbb, f1_2: 20, f1_3: {f1_3_1: 3.14}} [{f2_1: false, f2_2: 10.2}] NULL [{f4_1: }, {f4_1: 300}]

# test avro enum
query TTT
Expand Down
2 changes: 1 addition & 1 deletion testing

0 comments on commit 268598d

Please sign in to comment.