Skip to content

Commit

Permalink
Add SparseUnion support, Add interval datatype
Browse files Browse the repository at this point in the history
  • Loading branch information
chmp committed Oct 13, 2024
1 parent 508f3ae commit 73077ea
Show file tree
Hide file tree
Showing 7 changed files with 392 additions and 84 deletions.
36 changes: 32 additions & 4 deletions marrow/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
view::{
BitsWithOffset, BooleanView, BytesView, DecimalView, DenseUnionView, DictionaryView,
FixedSizeBinaryView, FixedSizeListView, ListView, MapView, NullView, PrimitiveView,
StructView, TimeView, TimestampView, View,
SparseUnionView, StructView, TimeView, TimestampView, View,
},
};

Expand Down Expand Up @@ -83,8 +83,10 @@ pub enum Array {
Dictionary(DictionaryArray),
/// An array of maps
Map(MapArray),
/// An array of unions
/// An array of unions with compact memory layout
DenseUnion(DenseUnionArray),
/// An array of unions
SparseUnion(SparseUnionArray),
}

impl Array {
Expand Down Expand Up @@ -123,6 +125,7 @@ impl Array {
Self::Map(array) => View::Map(array.as_view()),
Self::Dictionary(array) => View::Dictionary(array.as_view()),
Self::DenseUnion(array) => View::DenseUnion(array.as_view()),
Self::SparseUnion(array) => View::SparseUnion(array.as_view()),
}
}
}
Expand Down Expand Up @@ -498,7 +501,7 @@ impl DictionaryArray {
}
}

/// A union of different data types
/// A union of different data types with a compact representation
///
/// This corresponds roughly to Rust's enums. Each element has a type, which indicates the
/// underlying array to use. For fast lookups the offsets into the underlying arrays are stored as
Expand All @@ -515,7 +518,8 @@ pub struct DenseUnionArray {
}

impl DenseUnionArray {
fn as_view(&self) -> DenseUnionView<'_> {
/// Get the view for this array
pub fn as_view(&self) -> DenseUnionView<'_> {
DenseUnionView {
types: &self.types,
offsets: &self.offsets,
Expand All @@ -527,3 +531,27 @@ impl DenseUnionArray {
}
}
}

/// A union of different data types with a less compact representation
///
/// In contrast to [`DenseUnionArray`] no offsets are stored: the array consists only of the type
/// id of each element and the child arrays.
///
/// NOTE(review): Arrow's sparse union layout expects every child array to be as long as `types`,
/// with element `i` read from position `i` of the selected child. That invariant is not checked
/// by this type -- confirm where it is enforced.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseUnionArray {
/// The type id of each element; presumably it selects the entry of `fields` with the same id
pub types: Vec<i8>,
/// The child arrays as `(type id, field metadata, array)` triples
pub fields: Vec<(i8, FieldMeta, Array)>,
}

impl SparseUnionArray {
    /// Get the view for this array
    ///
    /// The type ids are borrowed from `self`. For every child the type id is copied, the field
    /// metadata is cloned and the child array is turned into its own view.
    pub fn as_view(&self) -> SparseUnionView<'_> {
        let mut child_views = Vec::with_capacity(self.fields.len());
        for (type_id, meta, child) in &self.fields {
            child_views.push((*type_id, meta.clone(), child.as_view()));
        }

        SparseUnionView {
            types: &self.types,
            fields: child_views,
        }
    }
}
58 changes: 57 additions & 1 deletion marrow/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ pub enum DataType {
Time64(TimeUnit),
/// Durations stored as `i64` with the given unit
Duration(TimeUnit),
/// Calendar intervals with different layouts depending on the given unit
Interval(IntervalUnit),
/// Fixed point values stored with the given precision and scale
Decimal128(u8, i8),
/// Structs
Expand Down Expand Up @@ -280,7 +282,7 @@ impl std::fmt::Display for UnionMode {
impl std::str::FromStr for UnionMode {
type Err = MarrowError;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
fn from_str(s: &str) -> Result<Self> {
match s {
"Sparse" => Ok(UnionMode::Sparse),
"Dense" => Ok(UnionMode::Dense),
Expand All @@ -306,3 +308,57 @@ fn union_mode_as_str() {
assert_variant!(Dense);
assert_variant!(Sparse);
}

/// The unit of calendar intervals
///
/// The unit determines the layout of the interval values. The variant names are also the textual
/// form produced by the `Display` implementation and accepted by the `FromStr` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum IntervalUnit {
/// An interval as the number of months (stored as `i32`)
YearMonth,
/// An interval as the number of days (stored as `i32`) and milliseconds (stored as `i32`)
DayTime,
/// An interval as the number of months (stored as `i32`), days (stored as `i32`) and
/// nanoseconds (stored as `i64`)
MonthDayNano,
}

impl std::fmt::Display for IntervalUnit {
    /// Write the name of the variant, e.g., `YearMonth`
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::YearMonth => "YearMonth",
            Self::DayTime => "DayTime",
            Self::MonthDayNano => "MonthDayNano",
        };
        // written verbatim, no padding or alignment is applied
        f.write_str(name)
    }
}

impl std::str::FromStr for IntervalUnit {
    type Err = MarrowError;

    /// Parse a unit from the exact name of its variant
    ///
    /// Any other input, including names with different casing, is rejected with a
    /// `ParseError`.
    fn from_str(s: &str) -> Result<Self> {
        let unit = match s {
            "YearMonth" => Self::YearMonth,
            "DayTime" => Self::DayTime,
            "MonthDayNano" => Self::MonthDayNano,
            _ => fail!(ErrorKind::ParseError, "Invalid IntervalUnit: {s}"),
        };
        Ok(unit)
    }
}

#[test]
fn interval_unit() {
    // every variant must format as its own name and parse back from that name
    let cases = [
        (IntervalUnit::YearMonth, "YearMonth"),
        (IntervalUnit::DayTime, "DayTime"),
        (IntervalUnit::MonthDayNano, "MonthDayNano"),
    ];

    for &(unit, name) in &cases {
        assert_eq!(unit.to_string(), name);
        assert_eq!(name.parse::<IntervalUnit>().unwrap(), unit);
    }
}
131 changes: 98 additions & 33 deletions marrow/src/impl_arrow/impl_api_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use half::f16;

use crate::{
array::Array,
datatypes::{meta_from_field, DataType, Field, FieldMeta, TimeUnit, UnionMode},
datatypes::{meta_from_field, DataType, Field, FieldMeta, IntervalUnit, TimeUnit, UnionMode},
error::{fail, ErrorKind, MarrowError, Result},
view::{
BitsWithOffset, BooleanView, BytesView, DecimalView, DenseUnionView, DictionaryView,
FixedSizeListView, ListView, MapView, NullView, PrimitiveView, StructView, TimeView,
TimestampView, View,
FixedSizeListView, ListView, MapView, NullView, PrimitiveView, SparseUnionView, StructView,
TimeView, TimestampView, View,
},
};

Expand Down Expand Up @@ -54,6 +54,7 @@ impl TryFrom<&arrow_schema::DataType> for DataType {
tz.as_ref().map(|s| s.to_string()),
)),
AT::Duration(unit) => Ok(T::Duration(unit.clone().try_into()?)),
AT::Interval(unit) => Ok(T::Interval(unit.clone().try_into()?)),
AT::Binary => Ok(T::Binary),
AT::LargeBinary => Ok(T::LargeBinary),
AT::FixedSizeBinary(n) => Ok(T::FixedSizeBinary(*n)),
Expand Down Expand Up @@ -136,6 +137,7 @@ impl TryFrom<&DataType> for arrow_schema::DataType {
tz.as_ref().map(|s| s.to_string().into()),
)),
T::Duration(unit) => Ok(AT::Duration((*unit).try_into()?)),
T::Interval(unit) => Ok(AT::Interval((*unit).try_into()?)),
T::Binary => Ok(AT::Binary),
T::LargeBinary => Ok(AT::LargeBinary),
T::FixedSizeBinary(n) => Ok(AT::FixedSizeBinary(*n)),
Expand Down Expand Up @@ -234,7 +236,7 @@ impl TryFrom<UnionMode> for arrow_schema::UnionMode {
}
}

/// Converison to `arrow` arrays (*requires one of the `arrow-{version}` features*)
/// Conversion to `arrow` arrays (*requires one of the `arrow-{version}` features*)
impl TryFrom<Array> for Arc<dyn arrow_array::Array> {
type Error = MarrowError;

Expand All @@ -243,6 +245,32 @@ impl TryFrom<Array> for Arc<dyn arrow_array::Array> {
}
}

/// Conversion from `arrow` interval units (*requires one of the `arrow-{version}` features*)
///
/// Every `arrow` unit has a direct counterpart, so this conversion always returns `Ok`. It is
/// implemented as `TryFrom` so that it can be used via `.try_into()?` like the other unit
/// conversions in this module.
impl TryFrom<arrow_schema::IntervalUnit> for IntervalUnit {
    type Error = MarrowError;

    fn try_from(value: arrow_schema::IntervalUnit) -> Result<Self> {
        Ok(match value {
            arrow_schema::IntervalUnit::YearMonth => Self::YearMonth,
            arrow_schema::IntervalUnit::DayTime => Self::DayTime,
            arrow_schema::IntervalUnit::MonthDayNano => Self::MonthDayNano,
        })
    }
}

/// Conversion to `arrow` interval units (*requires one of the `arrow-{version}` features*)
///
/// Every unit has a direct `arrow` counterpart, so this conversion always returns `Ok`. It is
/// implemented as `TryFrom` so that it can be used via `.try_into()?` like the other unit
/// conversions in this module.
impl TryFrom<IntervalUnit> for arrow_schema::IntervalUnit {
    type Error = MarrowError;

    fn try_from(value: IntervalUnit) -> Result<Self> {
        Ok(match value {
            IntervalUnit::YearMonth => Self::YearMonth,
            IntervalUnit::DayTime => Self::DayTime,
            IntervalUnit::MonthDayNano => Self::MonthDayNano,
        })
    }
}

fn build_array_data(value: Array) -> Result<arrow_data::ArrayData> {
use Array as A;
type ArrowF16 =
Expand Down Expand Up @@ -466,17 +494,7 @@ fn build_array_data(value: Array) -> Result<arrow_data::ArrayData> {
)?)
}
A::DenseUnion(arr) => {
let mut fields = Vec::new();
let mut child_data = Vec::new();

for (type_id, meta, array) in arr.fields {
let child = build_array_data(array)?;
let field = field_from_data_and_meta(&child, meta);

fields.push((type_id, Arc::new(field)));
child_data.push(child);
}

let (fields, child_data) = union_fields_into_fields_and_data(arr.fields)?;
Ok(arrow_data::ArrayData::try_new(
arrow_schema::DataType::Union(
fields.into_iter().collect(),
Expand All @@ -492,9 +510,43 @@ fn build_array_data(value: Array) -> Result<arrow_data::ArrayData> {
child_data,
)?)
}
A::SparseUnion(arr) => {
let (fields, child_data) = union_fields_into_fields_and_data(arr.fields)?;
Ok(arrow_data::ArrayData::try_new(
arrow_schema::DataType::Union(
fields.into_iter().collect(),
arrow_schema::UnionMode::Sparse,
),
arr.types.len(),
None,
0,
vec![arrow_buffer::ScalarBuffer::from(arr.types).into_inner()],
child_data,
)?)
}
}
}

/// Convert the children of a union array into `arrow` fields and array data
///
/// Each child array is converted via `build_array_data`; the matching field is derived from
/// the converted data and the child's metadata. Both returned vectors keep the order of
/// `union_fields`. The first child that fails to convert aborts the conversion with its error.
fn union_fields_into_fields_and_data(
    union_fields: Vec<(i8, FieldMeta, Array)>,
) -> Result<(
    Vec<(i8, arrow_schema::FieldRef)>,
    Vec<arrow_data::ArrayData>,
)> {
    let mut typed_fields = Vec::with_capacity(union_fields.len());
    let mut children = Vec::with_capacity(union_fields.len());

    for (type_id, meta, array) in union_fields {
        let data = build_array_data(array)?;
        // the field must be built before `data` is moved into `children`
        typed_fields.push((type_id, Arc::new(field_from_data_and_meta(&data, meta))));
        children.push(data);
    }

    Ok((typed_fields, children))
}

/// Conversion from `arrow` arrays (*requires one of the `arrow-{version}` features*)
impl<'a> TryFrom<&'a dyn arrow_array::Array> for View<'a> {
type Error = MarrowError;
Expand Down Expand Up @@ -801,13 +853,8 @@ impl<'a> TryFrom<&'a dyn arrow_array::Array> for View<'a> {
} else if let Some(array) = any.downcast_ref::<arrow_array::UnionArray>() {
use arrow_array::Array;

let arrow_schema::DataType::Union(union_fields, arrow_schema::UnionMode::Dense) =
array.data_type()
else {
fail!(
ErrorKind::Unsupported,
"Invalid data type: only dense unions are supported"
);
let arrow_schema::DataType::Union(union_fields, mode) = array.data_type() else {
fail!(ErrorKind::Unsupported, "Invalid data type for UnionArray");
};

let mut fields = Vec::new();
Expand All @@ -816,18 +863,36 @@ impl<'a> TryFrom<&'a dyn arrow_array::Array> for View<'a> {
let view: View = array.child(type_id).as_ref().try_into()?;
fields.push((type_id, meta, view));
}
let Some(offsets) = array.offsets() else {
fail!(
ErrorKind::Unsupported,
"Dense unions must have an offset array"
);
};

Ok(View::DenseUnion(DenseUnionView {
types: array.type_ids(),
offsets,
fields,
}))
match mode {
arrow_schema::UnionMode::Dense => {
let Some(offsets) = array.offsets() else {
fail!(
ErrorKind::Unsupported,
"Dense unions must have an offset array"
);
};

Ok(View::DenseUnion(DenseUnionView {
types: array.type_ids(),
offsets,
fields,
}))
}
arrow_schema::UnionMode::Sparse => {
if array.offsets().is_some() {
fail!(
ErrorKind::Unsupported,
"Sparse unions must not have an offset array"
);
};

Ok(View::SparseUnion(SparseUnionView {
types: array.type_ids(),
fields,
}))
}
}
} else {
fail!(
ErrorKind::Unsupported,
Expand Down
Loading

0 comments on commit 73077ea

Please sign in to comment.