Add support for decoding StringArray in LargeUtf8 schema #143

Closed · wants to merge 2 commits
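This PR teaches `from_arrow` to decode a plain `StringArray` (i32 offsets) when the deserialization schema declares the field as `LargeUtf8`, which is what schema tracing produces for Rust `String` fields. Below is a minimal sketch of the combination this enables, adapted from the new `test_short_strings_into_large_strings` test; the external `arrow` import paths and the exact `from_arrow` signature are assumptions of the sketch (inside the crate, the tests go through serde_arrow's own re-exports instead).

```rust
use std::sync::Arc;

use arrow::array::{RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct Record {
    id: String,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The batch physically stores an i32-offset StringArray ...
    let column: StringArray = vec!["foo", "bar"].into();
    let batch = RecordBatch::try_new(
        Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)])),
        vec![Arc::new(column)],
    )?;

    // ... while the deserialization schema asks for LargeUtf8, as tracing
    // the Record type does. Before this change, the downcast in the
    // LargeUtf8 arm failed and deserialization errored out.
    let fields = vec![Field::new("id", DataType::LargeUtf8, false)];
    let records: Vec<Record> = serde_arrow::from_arrow(&fields, batch.columns())?;
    assert_eq!(records[0].id, "foo");
    Ok(())
}
```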
64 changes: 46 additions & 18 deletions serde_arrow/src/arrow_impl/deserialization.rs
@@ -34,7 +34,13 @@ impl BufferExtract for dyn Array {
let typed = self
.as_any()
.downcast_ref::<PrimitiveArray<$arrow_type>>()
.ok_or_else(|| error!("Cannot interpret array as typed array"))?;
.ok_or_else(|| {
error!(
"Cannot interpret {} array as {}",
self.data_type(),
stringify!($arrow_type)
)
})?;

let buffer = buffers.$push_func(typed.values())?;
let validity = get_validity(typed).map(|v| buffers.push_u1(v));
@@ -48,11 +54,8 @@ }
}

macro_rules! convert_utf8 {
($array_type:ty, $variant:ident, $push_func:ident) => {{
let typed = self
.as_any()
.downcast_ref::<$array_type>()
.ok_or_else(|| error!("cannot convert array into string"))?;
($typed:expr, $variant:ident, $push_func:ident) => {{
let typed = $typed;

let buffer = buffers.push_u8(typed.value_data());
let offsets = buffers.$push_func(typed.value_offsets())?;
@@ -115,7 +118,7 @@ impl BufferExtract for dyn Array {
let typed = self
.as_any()
.downcast_ref::<BooleanArray>()
.ok_or_else(|| error!("cannot convert array into bool"))?;
.ok_or_else(|| error!("cannot convert {} array into bool", self.data_type()))?;
let values = typed.values();

let buffer = buffers.push_u1(BitBuffer {
@@ -156,15 +159,41 @@ impl BufferExtract for dyn Array {
T::Timestamp(U::Nanosecond, _) => {
convert_primitive!(TimestampNanosecondType, Date64, push_u64_cast)
}
T::Utf8 => convert_utf8!(StringArray, Utf8, push_u32_cast),
T::LargeUtf8 => convert_utf8!(LargeStringArray, LargeUtf8, push_u64_cast),
T::Utf8 => {
let typed = self.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
error!("cannot convert {} array into string", self.data_type())
})?;

convert_utf8!(typed, Utf8, push_u32_cast)
}
T::LargeUtf8 => {
// Try decoding as large strings first
match self
.as_any()
.downcast_ref::<LargeStringArray>()
.ok_or_else(|| error!("cannot convert {} array into string", self.data_type()))
{
Ok(typed) => convert_utf8!(typed, LargeUtf8, push_u64_cast),
Err(_) => {
// Failed; try decoding as small strings
let typed =
self.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
error!("cannot convert {} array into string", self.data_type())
})?;

convert_utf8!(typed, Utf8, push_u32_cast)
}
}
}
T::List => convert_list!(i32, List, push_u32_cast),
T::LargeList => convert_list!(i64, LargeList, push_u64_cast),
T::Struct => {
let typed = self
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| error!("cannot convert array into struct array"))?;
let typed = self.as_any().downcast_ref::<StructArray>().ok_or_else(|| {
error!(
"cannot convert {} array into struct array",
self.data_type()
)
})?;
let validity = get_validity(self).map(|v| buffers.push_u1(v));
let mut fields = Vec::new();

@@ -229,7 +258,7 @@ impl BufferExtract for dyn Array {
let typed = self
.as_any()
.downcast_ref::<DictionaryArray<$key_type>>()
.ok_or_else(|| error!("cannot convert array into u32 dictionary"))?;
.ok_or_else(|| error!("cannot convert {} array into u32 dictionary", self.data_type()))?;

// NOTE: the array's validity is given by the key validity
if typed.values().null_count() != 0 {
@@ -277,10 +306,9 @@ impl BufferExtract for dyn Array {
use crate::_impl::arrow::array::UnionArray;

// TODO: test assumptions
let typed = self
.as_any()
.downcast_ref::<UnionArray>()
.ok_or_else(|| error!("cannot convert array to union array"))?;
let typed = self.as_any().downcast_ref::<UnionArray>().ok_or_else(|| {
error!("cannot convert {} array to union array", self.data_type())
})?;

let types = buffers.push_u8_cast(typed.type_ids())?;

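The heart of the change is the new `T::LargeUtf8` arm above: the array is first downcast to `LargeStringArray`, and only if that fails is it retried as an i32-offset `StringArray`. Here is the same two-step downcast in isolation, as a standalone sketch; the `first_value` helper is hypothetical, while `as_any`/`downcast_ref` are arrow's standard dynamic-typing escape hatch.

```rust
use arrow::array::{Array, LargeStringArray, StringArray};

// Hypothetical helper mirroring the fallback order used in the PR:
// prefer the i64-offset LargeStringArray, then retry with the
// i32-offset StringArray before giving up.
fn first_value(array: &dyn Array) -> Option<&str> {
    if let Some(large) = array.as_any().downcast_ref::<LargeStringArray>() {
        return (!large.is_empty()).then(|| large.value(0));
    }
    if let Some(small) = array.as_any().downcast_ref::<StringArray>() {
        return (!small.is_empty()).then(|| small.value(0));
    }
    None
}
```

Note that the PR's version builds and then discards an error value when the first downcast fails; a plain `if let` chain like this sidesteps that, but the observable behavior is the same.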
1 change: 1 addition & 0 deletions serde_arrow/src/test_end_to_end/mod.rs
@@ -4,3 +4,4 @@ mod issue_137_schema_like_from_arrow_schema;
mod issue_90;
mod test_docs_examples;
mod test_items;
mod test_strings;
134 changes: 134 additions & 0 deletions serde_arrow/src/test_end_to_end/test_strings.rs
@@ -0,0 +1,134 @@
use std::sync::Arc;

use serde::Deserialize;

use crate::{
self as serde_arrow,
_impl::arrow::_raw::{
array::{LargeStringArray, RecordBatch, StringArray},
schema::{DataType, Field, Schema},
},
internal::error::PanicOnError,
schema::{SchemaLike, TracingOptions},
};

#[derive(Deserialize, Debug, PartialEq, Eq)]
struct Record {
id: String,
}

#[test]
fn test_short_strings() -> PanicOnError<()> {
let column: StringArray = vec!["foo", "bar"].into();

let fields = vec![Field::new("id", DataType::Utf8, false)];

let batch = RecordBatch::try_new(
Arc::new(Schema::new(fields.clone())),
vec![Arc::new(column.clone())],
)?;

let records: Vec<Record> = serde_arrow::from_arrow(&fields, batch.columns())?;

assert_eq!(
records,
vec![
Record {
id: "foo".to_string()
},
Record {
id: "bar".to_string()
},
]
);

Ok(())
}

#[test]
fn test_large_strings_into_short_strings() -> PanicOnError<()> {
let column: LargeStringArray = vec!["foo", "bar"].into();

let deser_fields = vec![Field::new("id", DataType::Utf8, false)];
let batch_fields = vec![Field::new("id", DataType::LargeUtf8, false)];

assert_eq!(
Vec::<Field>::from_type::<Record>(TracingOptions::default())?,
batch_fields
);

let batch = RecordBatch::try_new(
Arc::new(Schema::new(batch_fields.clone())),
vec![Arc::new(column.clone())],
)?;

let result: Result<Vec<Record>, _> = serde_arrow::from_arrow(&deser_fields, batch.columns());

assert!(
result.is_err(),
"Reading large strings into a short strings array did not error"
);

Ok(())
}

#[test]
fn test_large_strings() -> PanicOnError<()> {
let column: LargeStringArray = vec!["foo", "bar"].into();

let fields = vec![Field::new("id", DataType::LargeUtf8, false)];
assert_eq!(
Vec::<Field>::from_type::<Record>(TracingOptions::default())?,
fields
);

let batch = RecordBatch::try_new(
Arc::new(Schema::new(fields.clone())),
vec![Arc::new(column.clone())],
)?;

let records: Vec<Record> = serde_arrow::from_arrow(&fields, batch.columns())?;

assert_eq!(
records,
vec![
Record {
id: "foo".to_string()
},
Record {
id: "bar".to_string()
},
]
);

Ok(())
}

#[test]
fn test_short_strings_into_large_strings() -> PanicOnError<()> {
let column: StringArray = vec!["foo", "bar"].into();

let deser_fields = Vec::<_>::from_type::<Record>(TracingOptions::default())?;
let batch_fields = vec![Field::new("id", DataType::Utf8, false)];

let batch = RecordBatch::try_new(
Arc::new(Schema::new(batch_fields.clone())),
vec![Arc::new(column.clone())],
)?;

let records: Vec<Record> = serde_arrow::from_arrow(&deser_fields, batch.columns())?;

assert_eq!(
records,
vec![
Record {
id: "foo".to_string()
},
Record {
id: "bar".to_string()
},
]
);

Ok(())
}