Skip to content

Commit

Permalink
add unit tests for batch and file writers
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Oct 31, 2024
1 parent 072ae6e commit b67499e
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 0 deletions.
75 changes: 75 additions & 0 deletions src/daft-writers/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,78 @@ impl WriterFactory for TargetBatchWriterFactory {
)))
}
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::test::{make_dummy_mp, DummyWriterFactory};

    /// Pulls the `write_count` recorded by the `DummyWriter` out of the table
    /// returned by `close()`, panicking if the result is absent or malformed.
    /// Extracted because every test below needs the same unwrap chain.
    fn recorded_write_count(res: Option<daft_table::Table>) -> u64 {
        res.expect("close() should yield a result table")
            .get_column("write_count")
            .unwrap()
            .u64()
            .unwrap()
            .get(0)
            .unwrap()
    }

    #[test]
    fn test_target_batch_writer_exact_batch() {
        // A single row with a target batch size of 1 flushes exactly once.
        let dummy_writer_factory = DummyWriterFactory;
        let mut writer =
            TargetBatchWriter::new(1, dummy_writer_factory.create_writer(0, None).unwrap());

        writer.write(&make_dummy_mp(1)).unwrap();
        let res = writer.close().unwrap();

        assert_eq!(recorded_write_count(res), 1);
    }

    #[test]
    fn test_target_batch_writer_small_batches() {
        // Eight single-row writes against a target of 3 coalesce into
        // ceil(8 / 3) = 3 flushes to the inner writer.
        let dummy_writer_factory = DummyWriterFactory;
        let mut writer =
            TargetBatchWriter::new(3, dummy_writer_factory.create_writer(0, None).unwrap());

        for _ in 0..8 {
            writer.write(&make_dummy_mp(1)).unwrap();
        }
        let res = writer.close().unwrap();

        assert_eq!(recorded_write_count(res), 3);
    }

    #[test]
    fn test_target_batch_writer_big_batch() {
        // One ten-row write against a target of 3 is split into
        // ceil(10 / 3) = 4 flushes to the inner writer.
        let dummy_writer_factory = DummyWriterFactory;
        let mut writer =
            TargetBatchWriter::new(3, dummy_writer_factory.create_writer(0, None).unwrap());

        writer.write(&make_dummy_mp(10)).unwrap();
        let res = writer.close().unwrap();

        assert_eq!(recorded_write_count(res), 4);
    }
}
69 changes: 69 additions & 0 deletions src/daft-writers/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,72 @@ impl WriterFactory for TargetFileSizeWriterFactory {
>)
}
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::test::{make_dummy_mp, DummyWriterFactory};

    #[test]
    fn test_target_file_writer_exact_file() {
        // One row with a one-row-per-file target yields exactly one file.
        let factory = DummyWriterFactory;
        let mut writer = TargetFileSizeWriter::new(1, Arc::new(factory), None).unwrap();

        writer.write(&make_dummy_mp(1)).unwrap();
        let results = writer.close().unwrap();

        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_target_file_writer_less_rows_for_one_file() {
        // Fewer rows than the target still produce a single (underfull) file.
        let factory = DummyWriterFactory;
        let mut writer = TargetFileSizeWriter::new(3, Arc::new(factory), None).unwrap();

        writer.write(&make_dummy_mp(2)).unwrap();
        let results = writer.close().unwrap();

        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_target_file_writer_more_rows_for_one_file() {
        // Four rows against a target of 3 spill into a second file.
        let factory = DummyWriterFactory;
        let mut writer = TargetFileSizeWriter::new(3, Arc::new(factory), None).unwrap();

        writer.write(&make_dummy_mp(4)).unwrap();
        let results = writer.close().unwrap();

        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_target_file_writer_multiple_files() {
        // Ten rows against a target of 3 produce ceil(10 / 3) = 4 files.
        let factory = DummyWriterFactory;
        let mut writer = TargetFileSizeWriter::new(3, Arc::new(factory), None).unwrap();

        writer.write(&make_dummy_mp(10)).unwrap();
        let results = writer.close().unwrap();

        assert_eq!(results.len(), 4);
    }

    #[test]
    fn test_target_file_writer_many_writes_many_files() {
        // Ten single-row writes behave the same as one ten-row write: 4 files.
        let factory = DummyWriterFactory;
        let mut writer = TargetFileSizeWriter::new(3, Arc::new(factory), None).unwrap();

        for _ in 0..10 {
            writer.write(&make_dummy_mp(1)).unwrap();
        }
        let results = writer.close().unwrap();

        assert_eq!(results.len(), 4);
    }
}
3 changes: 3 additions & 0 deletions src/daft-writers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ mod file;
mod partition;
mod physical;

#[cfg(test)]
mod test;

#[cfg(feature = "python")]
mod python;

Expand Down
1 change: 1 addition & 0 deletions src/daft-writers/src/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{FileWriter, WriterFactory};
/// PartitionedWriter is a writer that partitions the input data by a set of columns, and writes each partition
/// to a separate file. It uses a map to keep track of the writers for each partition.
struct PartitionedWriter {
// TODO: Figure out a way to NOT use the IndexHash + RawEntryMut pattern here. Ideally we want to store ScalarValues, aka. single Rows of the partition values as keys for the hashmap.
per_partition_writers:
HashMap<IndexHash, Box<dyn FileWriter<Input = Arc<MicroPartition>, Result = Vec<Table>>>>,
saved_partition_values: Vec<Table>,
Expand Down
84 changes: 84 additions & 0 deletions src/daft-writers/src/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use std::sync::Arc;

use common_error::DaftResult;
use daft_core::{
prelude::{Int64Array, Schema, UInt64Array, Utf8Array},
series::IntoSeries,
};
use daft_micropartition::MicroPartition;
use daft_table::Table;

use crate::{FileWriter, WriterFactory};

/// Test-only `WriterFactory` stub: every writer it creates is a `DummyWriter`
/// that records its file index, any partition values, and a count of `write` calls.
pub(crate) struct DummyWriterFactory;

impl WriterFactory for DummyWriterFactory {
    type Input = Arc<MicroPartition>;
    type Result = Option<Table>;

    /// Builds a fresh `DummyWriter` that remembers the given file index and a
    /// clone of the partition values (when supplied), with a zeroed write counter.
    fn create_writer(
        &self,
        file_idx: usize,
        partition_values: Option<&Table>,
    ) -> DaftResult<Box<dyn FileWriter<Input = Self::Input, Result = Self::Result>>> {
        // Bind with an explicit trait-object type so the Box<DummyWriter>
        // coerces to Box<dyn FileWriter<...>> without an `as` cast.
        let writer: Box<dyn FileWriter<Input = Self::Input, Result = Self::Result>> =
            Box::new(DummyWriter {
                file_idx: file_idx.to_string(),
                partition_values: partition_values.cloned(),
                write_count: 0,
            });
        Ok(writer)
    }
}

/// Test-only `FileWriter` that performs no I/O: it only counts `write` calls
/// and reports them as a one-row table from `close()`.
pub(crate) struct DummyWriter {
    // File index given at creation, stored as a string for the "path" column.
    file_idx: String,
    // Partition values captured at creation; unioned into the close() result if present.
    partition_values: Option<Table>,
    // Number of times write() has been called on this writer.
    write_count: usize,
}

impl FileWriter for DummyWriter {
    type Input = Arc<MicroPartition>;
    type Result = Option<Table>;

    /// Records one write; the input partition itself is ignored.
    fn write(&mut self, _input: &Self::Input) -> DaftResult<()> {
        self.write_count += 1;
        Ok(())
    }

    /// Builds a single-row table with the file index under `"path"` and the
    /// number of `write` calls under `"write_count"`, unioned with the stored
    /// partition values when they were provided.
    fn close(&mut self) -> DaftResult<Self::Result> {
        let path =
            Utf8Array::from_values("path", std::iter::once(self.file_idx.clone())).into_series();
        let count =
            UInt64Array::from_values("write_count", std::iter::once(self.write_count as u64))
                .into_series();
        let schema =
            Schema::new(vec![path.field().clone(), count.field().clone()]).unwrap();
        let table = Table::new_unchecked(schema, vec![path.into(), count.into()], 1);
        // take() moves the partition values out so they are only attached once.
        match self.partition_values.take() {
            Some(partition_values) => Ok(Some(table.union(&partition_values)?)),
            None => Ok(Some(table)),
        }
    }
}

/// Builds a single-column, fully-loaded `MicroPartition` with `num_rows` rows,
/// every value in its `"ints"` (Int64) column being the constant 42.
pub(crate) fn make_dummy_mp(num_rows: usize) -> Arc<MicroPartition> {
    let ints = Int64Array::from_values("ints", (0..num_rows).map(|_| 42)).into_series();
    let schema = Arc::new(Schema::new(vec![ints.field().clone()]).unwrap());
    let table = Table::new_unchecked(schema.clone(), vec![ints.into()], num_rows);
    Arc::new(MicroPartition::new_loaded(
        schema.into(),
        vec![table].into(),
        None,
    ))
}

0 comments on commit b67499e

Please sign in to comment.