Skip to content

Commit

Permalink
Extend changes to Optional growable_validities to struct and FSL growable impls
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Sep 6, 2023
1 parent 88b2ab5 commit 289db08
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
31 changes: 22 additions & 9 deletions src/daft-core/src/array/growable/fixed_size_list_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub struct FixedSizeListGrowable<'a> {
dtype: DataType,
element_fixed_len: usize,
child_growable: Box<dyn Growable + 'a>,
growable_validity: ArrowBitmapGrowable<'a>,
growable_validity: Option<ArrowBitmapGrowable<'a>>,
}

impl<'a> FixedSizeListGrowable<'a> {
Expand All @@ -33,10 +33,15 @@ impl<'a> FixedSizeListGrowable<'a> {
use_validity,
capacity * element_fixed_len,
);
let growable_validity = ArrowBitmapGrowable::new(
arrays.iter().map(|a| a.validity()).collect(),
capacity,
);
let growable_validity =
if use_validity || arrays.iter().any(|arr| arr.validity().is_some()) {
Some(ArrowBitmapGrowable::new(
arrays.iter().map(|a| a.validity()).collect(),
capacity,
))
} else {
None
};
Self {
name: name.to_string(),
dtype: dtype.clone(),
Expand All @@ -57,24 +62,32 @@ impl<'a> Growable for FixedSizeListGrowable<'a> {
start * self.element_fixed_len,
len * self.element_fixed_len,
);
self.growable_validity.extend(index, start, len);

match &mut self.growable_validity {
Some(growable_validity) => growable_validity.extend(index, start, len),
None => (),
}
}

/// Appends `additional` null entries to this growable.
///
/// The child growable receives `additional * element_fixed_len` nulls, and the
/// validity growable receives `additional` nulls.
fn add_nulls(&mut self, additional: usize) {
self.child_growable
.add_nulls(additional * self.element_fixed_len);
// NOTE(review): this text is a rendered diff without +/- markers. The direct call on the
// next line appears to be the removed (pre-change) line and the `match` below its
// replacement — confirm against the actual file.
self.growable_validity.add_nulls(additional);

// The validity growable is optional; when it is `None` there is nothing to extend.
match &mut self.growable_validity {
Some(growable_validity) => growable_validity.add_nulls(additional),
None => (),
}
}

/// Builds the accumulated data into a `Series` backed by a `FixedSizeListArray`.
///
/// # Errors
/// Propagates any error returned by `child_growable.build()`.
fn build(&mut self) -> DaftResult<Series> {
// Move the validity growable out of `self`, leaving the field's `Default` value behind.
let grown_validity = std::mem::take(&mut self.growable_validity);

let built_child = self.child_growable.build()?;
// NOTE(review): this text is a rendered diff without +/- markers. Each pair of
// near-duplicate lines below (the two `built_validity` bindings, and
// `Some(built_validity)` / `built_validity`) appears to be the removed line followed
// by its replacement — confirm against the actual file.
let built_validity = grown_validity.build();
let built_validity = grown_validity.map(|v| v.build());
Ok(FixedSizeListArray::new(
Field::new(self.name.clone(), self.dtype.clone()),
built_child,
Some(built_validity),
built_validity,
)
.into_series())
}
Expand Down
31 changes: 22 additions & 9 deletions src/daft-core/src/array/growable/struct_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct StructGrowable<'a> {
name: String,
dtype: DataType,
children_growables: Vec<Box<dyn Growable + 'a>>,
growable_validity: ArrowBitmapGrowable<'a>,
growable_validity: Option<ArrowBitmapGrowable<'a>>,
}

impl<'a> StructGrowable<'a> {
Expand Down Expand Up @@ -41,10 +41,15 @@ impl<'a> StructGrowable<'a> {
)
})
.collect::<Vec<_>>();
let growable_validity = ArrowBitmapGrowable::new(
arrays.iter().map(|a| a.validity()).collect(),
capacity,
);
let growable_validity =
if use_validity || arrays.iter().any(|arr| arr.validity().is_some()) {
Some(ArrowBitmapGrowable::new(
arrays.iter().map(|a| a.validity()).collect(),
capacity,
))
} else {
None
};
Self {
name: name.to_string(),
dtype: dtype.clone(),
Expand All @@ -62,14 +67,22 @@ impl<'a> Growable for StructGrowable<'a> {
for child_growable in &mut self.children_growables {
child_growable.extend(index, start, len)
}
self.growable_validity.extend(index, start, len);

match &mut self.growable_validity {
Some(growable_validity) => growable_validity.extend(index, start, len),
None => (),
}
}

/// Appends `additional` null entries to this growable.
///
/// Every child growable receives `additional` nulls, as does the validity growable.
fn add_nulls(&mut self, additional: usize) {
for child_growable in &mut self.children_growables {
child_growable.add_nulls(additional);
}
// NOTE(review): this text is a rendered diff without +/- markers. The direct call on the
// next line appears to be the removed (pre-change) line and the `match` below its
// replacement — confirm against the actual file.
self.growable_validity.add_nulls(additional);

// The validity growable is optional; when it is `None` there is nothing to extend.
match &mut self.growable_validity {
Some(growable_validity) => growable_validity.add_nulls(additional),
None => (),
}
}

fn build(&mut self) -> DaftResult<Series> {
Expand All @@ -80,11 +93,11 @@ impl<'a> Growable for StructGrowable<'a> {
.iter_mut()
.map(|cg| cg.build())
.collect::<DaftResult<Vec<_>>>()?;
let built_validity = grown_validity.build();
let built_validity = grown_validity.map(|v| v.build());
Ok(StructArray::new(
Field::new(self.name.clone(), self.dtype.clone()),
built_children,
Some(built_validity),
built_validity,
)
.into_series())
}
Expand Down

0 comments on commit 289db08

Please sign in to comment.