diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index f983e26d48a4..1fcda6fcab99 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -35,7 +35,7 @@ use crate::arrow::error::ArrowError; use crate::arrow::record_batch::RecordBatch; use crate::arrow::util::bit_util; use crate::error::{DataFusionError, Result}; -use apache_avro::schema::RecordSchema; +use apache_avro::schema::{RecordSchema, UnionSchema}; use apache_avro::{ schema::{Schema as AvroSchema, SchemaKind}, types::Value, @@ -97,15 +97,15 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { schema_lookup: &'b mut BTreeMap, ) -> Result<&'b BTreeMap> { match schema { - AvroSchema::Record(RecordSchema { - name, - fields, - lookup, - .. - }) => { + AvroSchema::Union(union_schema) => { + if union_schema.is_nullable() { + let rec_schema = &union_schema.variants()[1]; + Self::child_schema_lookup(rec_schema, schema_lookup)?; + } + } + AvroSchema::Record(RecordSchema { fields, lookup, .. }) => { lookup.iter().for_each(|(field_name, pos)| { - schema_lookup - .insert(format!("{}.{}", name.fullname(None), field_name), *pos); + schema_lookup.insert(field_name.clone(), *pos); }); for field in fields { @@ -562,8 +562,9 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { }); values .iter() - .map(|v| match v { + .map(|v| match maybe_resolve_union(v) { Value::Record(record) => record, + Value::Null => &null_struct_array, other => panic!("expected Record, got {other:?}"), }) .collect::>>() @@ -775,14 +776,20 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let len = rows.len(); let num_bytes = bit_util::ceil(len, 8); let mut null_buffer = MutableBuffer::from_len_zeroed(num_bytes); + let empty_vec = vec![]; let struct_rows = rows .iter() .enumerate() .map(|(i, row)| (i, self.field_lookup(field.name(), row))) .map(|(i, v)| { + let v = v.map(maybe_resolve_union); if let Some(Value::Record(value)) = v { bit_util::set_bit(&mut null_buffer, i); value + } else if v.is_none() { + &empty_vec + } else if let Some(Value::Null) = v { + &empty_vec } else { panic!("expected struct got {v:?}"); } @@ -1018,7 +1025,10 @@ mod test { use crate::arrow::array::Array; use crate::arrow::datatypes::{Field, TimeUnit}; use crate::datasource::avro_to_arrow::{Reader, ReaderBuilder}; + use apache_avro::types::Value as AvroValue; use arrow::datatypes::DataType; + use arrow_array::cast::AsArray; + use datafusion_common::assert_batches_eq; use datafusion_common::cast::{ as_int32_array, as_int64_array, as_list_array, as_timestamp_microsecond_array, }; @@ -1101,6 +1111,74 @@ mod test { assert_eq!(batch.num_rows(), 3); } + #[test] + fn test_avro_nullable_struct() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "record", + "name": "r2", + "fields": [ + { + "name": "col2", + "type": "string" + } + ] + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + let mut r1 = apache_avro::types::Record::new(&schema).unwrap(); + r1.put("col1", AvroValue::Union(0, Box::new(AvroValue::Null))); + let mut r2 = apache_avro::types::Record::new(&schema).unwrap(); + r2.put( + "col1", + AvroValue::Union( + 1, + Box::new(AvroValue::Record(vec![( + "col2".to_string(), + AvroValue::String("hello".to_string()), + )])), + ), + ); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + w.append(r1).unwrap(); + w.append(r2).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(2) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 1); + + let expected = [ + "+---------------+", + "| col1 |", + "+---------------+", + "| |", + "| {col2: hello} |", + "+---------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + #[test] fn test_avro_iterator() { let reader = build_reader("alltypes_plain.avro", 5); diff --git a/datafusion/core/src/datasource/avro_to_arrow/schema.rs b/datafusion/core/src/datasource/avro_to_arrow/schema.rs index f15e378cc699..48cdc23d919d 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/schema.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/schema.rs @@ -116,7 +116,7 @@ fn schema_to_field_with_props( DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) } } - AvroSchema::Record(RecordSchema { name, fields, .. }) => { + AvroSchema::Record(RecordSchema { fields, .. }) => { let fields: Result<_> = fields .iter() .map(|field| { @@ -129,8 +129,8 @@ fn schema_to_field_with_props( }*/ schema_to_field_with_props( &field.schema, - Some(&format!("{}.{}", name.fullname(None), field.name)), - false, + Some(&field.name), + nullable, Some(props), ) }) @@ -442,6 +442,49 @@ mod test { assert_eq!(arrow_schema.unwrap(), expected); } + #[test] + fn test_nested_schema() { + let avro_schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "record", + "name": "r2", + "fields": [ + { + "name": "col2", + "type": "string" + } + ] + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + // should not use Avro Record names. + let expected_arrow_schema = Schema::new(vec![Field::new( + "col1", + arrow::datatypes::DataType::Struct( + vec![Field::new("col2", Utf8, true)].into(), + ), + true, + )]); + assert_eq!( + to_arrow_schema(&avro_schema).unwrap(), + expected_arrow_schema + ); + } + #[test] fn test_non_record_schema() { let arrow_schema = to_arrow_schema(&AvroSchema::String); diff --git a/datafusion/sqllogictest/test_files/avro.slt b/datafusion/sqllogictest/test_files/avro.slt index ede11406e1a9..e744fb0558a1 100644 --- a/datafusion/sqllogictest/test_files/avro.slt +++ b/datafusion/sqllogictest/test_files/avro.slt @@ -101,11 +101,11 @@ SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_multi_files 1 1 # test avro nested records -query ?? -SELECT f1, f2 FROM nested_records +query ???? +SELECT f1, f2, f3, f4 FROM nested_records ---- -{ns2.record2.f1_1: aaa, ns2.record2.f1_2: 10, ns2.record2.f1_3: {ns3.record3.f1_3_1: 3.14}} [{ns4.record4.f2_1: true, ns4.record4.f2_2: 1.2}, {ns4.record4.f2_1: true, ns4.record4.f2_2: 2.2}] -{ns2.record2.f1_1: bbb, ns2.record2.f1_2: 20, ns2.record2.f1_3: {ns3.record3.f1_3_1: 3.14}} [{ns4.record4.f2_1: false, ns4.record4.f2_2: 10.2}] +{f1_1: aaa, f1_2: 10, f1_3: {f1_3_1: 3.14}} [{f2_1: true, f2_2: 1.2}, {f2_1: true, f2_2: 2.2}] {f3_1: xyz} [{f4_1: 200}, {f4_1: }] +{f1_1: bbb, f1_2: 20, f1_3: {f1_3_1: 3.14}} [{f2_1: false, f2_2: 10.2}] NULL [{f4_1: }, {f4_1: 300}] # test avro enum query TTT diff --git a/testing b/testing index 37f29510ce97..48672a3cbe7f 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 37f29510ce97cd491b8e6ed75866c6533a5ea2a1 +Subproject commit 48672a3cbe7f84cc2ddda0023c6e3698f7539a99