[BUG] Add an allowlist of DataTypes that ColumnRangeStatistics supports and validation of TableStatistics (#1632)

1. Disallow creation of ColumnRangeStatistics from non-comparable DataTypes, to
avoid comparison errors at runtime (a rough sketch of such an allowlist follows
below)
2. Add validation when creating MicroPartitions:
* The column names in a MicroPartition's schema must all be present in its
ScanTask's schema
* When creating Statistics for a MicroPartition, we cast those Statistics to
the MicroPartition's schema to ensure type compatibility
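
For context, here is a minimal, hypothetical sketch of what a DataType allowlist
for range statistics can look like. The `DataType` enum and `supports_dtype`
function below are simplified stand-ins for illustration only; the committed
`ColumnRangeStatistics::supports_dtype` operates on Daft's real `DataType` enum
and may cover a different set of types.

// Simplified stand-in for Daft's DataType enum; illustration only.
#[derive(Debug)]
enum DataType {
    Boolean,
    Int32,
    Int64,
    Float64,
    Utf8,
    Binary,
    // Nested types have no total ordering, so min/max range stats are unsupported.
    List(Box<DataType>),
}

// Allowlist check: only types with a total ordering can back min/max statistics.
fn supports_dtype(dtype: &DataType) -> bool {
    matches!(
        dtype,
        DataType::Boolean
            | DataType::Int32
            | DataType::Int64
            | DataType::Float64
            | DataType::Utf8
            | DataType::Binary
    )
}

fn main() {
    assert!(supports_dtype(&DataType::Int64));
    assert!(!supports_dtype(&DataType::List(Box::new(DataType::Int32))));
    println!("allowlist checks passed");
}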

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia authored Nov 27, 2023
1 parent 66a8269 commit 12bd499
Showing 5 changed files with 250 additions and 101 deletions.
92 changes: 59 additions & 33 deletions src/daft-micropartition/src/micropartition.rs
@@ -275,38 +275,63 @@ fn materialize_scan_task(
 }
 
 impl MicroPartition {
+    /// Create a new "unloaded" MicroPartition using an associated [`ScanTask`]
+    ///
+    /// Schema invariants:
+    /// 1. All columns in `schema` must exist in the `scan_task` schema
+    /// 2. Each Loaded column statistic in `statistics` must be castable to the corresponding column in the MicroPartition's schema
     pub fn new_unloaded(
         schema: SchemaRef,
         scan_task: Arc<ScanTask>,
         metadata: TableMetadata,
         statistics: TableStatistics,
     ) -> Self {
-        if statistics.columns.len() != schema.fields.len() {
-            panic!("MicroPartition: TableStatistics and Schema have differing lengths")
-        }
-        if !statistics
-            .columns
-            .keys()
-            .zip(schema.fields.keys())
-            .all(|(l, r)| l == r)
-        {
-            panic!("MicroPartition: TableStatistics and Schema have different column names\nTableStats:\n{},\nSchema\n{}", statistics, schema);
-        }
+        assert!(
+            schema
+                .fields
+                .keys()
+                .collect::<HashSet<_>>()
+                .is_subset(&scan_task.schema.fields.keys().collect::<HashSet<_>>()),
+            "Unloaded MicroPartition's schema names must be a subset of its ScanTask's schema"
+        );
 
         MicroPartition {
-            schema,
+            schema: schema.clone(),
             state: Mutex::new(TableState::Unloaded(scan_task)),
             metadata,
-            statistics: Some(statistics),
+            statistics: Some(
+                statistics
+                    .cast_to_schema(schema)
+                    .expect("Statistics cannot be casted to schema"),
+            ),
         }
     }
 
+    /// Create a new "loaded" MicroPartition using the materialized tables
+    ///
+    /// Schema invariants:
+    /// 1. `schema` must match each Table's schema exactly
+    /// 2. If `statistics` is provided, each Loaded column statistic must be castable to the corresponding column in the MicroPartition's schema
     pub fn new_loaded(
         schema: SchemaRef,
         tables: Arc<Vec<Table>>,
         statistics: Option<TableStatistics>,
     ) -> Self {
+        // Check and validate invariants with asserts
+        for table in tables.iter() {
+            assert!(
+                table.schema == schema,
+                "Loaded MicroPartition's tables' schema must match its own schema exactly"
+            );
+        }
+
+        let statistics = statistics.map(|stats| {
+            stats
+                .cast_to_schema(schema.clone())
+                .expect("Statistics cannot be casted to schema")
+        });
         let tables_len_sum = tables.iter().map(|t| t.len()).sum();
 
         MicroPartition {
             schema,
             state: Mutex::new(TableState::Loaded(tables)),
@@ -356,6 +381,7 @@ impl MicroPartition {
             .map(|cols| cols.iter().map(|s| s.as_str()).collect::<Vec<&str>>());
 
         let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice());
+
         read_parquet_into_micropartition(
             uris.as_slice(),
             columns.as_deref(),
@@ -612,15 +638,24 @@ pub(crate) fn read_parquet_into_micropartition(
         return Err(common_error::DaftError::ValueError("Micropartition Parquet Reader does not support non-zero start offsets".to_string()));
     }
 
+    // Run the required I/O to retrieve all the Parquet FileMetaData
     let runtime_handle = daft_io::get_runtime(multithreaded_io)?;
     let io_client = daft_io::get_io_client(multithreaded_io, io_config.clone())?;
-
     let meta_io_client = io_client.clone();
     let meta_io_stats = io_stats.clone();
-
     let metadata = runtime_handle.block_on(async move {
         read_parquet_metadata_bulk(uris, meta_io_client, meta_io_stats).await
     })?;
 
+    // Deserialize and collect relevant TableStatistics
+    let schemas = metadata
+        .iter()
+        .map(|m| {
+            let schema = infer_schema_with_options(m, &Some((*schema_infer_options).into()))?;
+            let daft_schema = daft_core::schema::Schema::try_from(&schema)?;
+            DaftResult::Ok(daft_schema)
+        })
+        .collect::<DaftResult<Vec<_>>>()?;
     let any_stats_avail = metadata
         .iter()
         .flat_map(|m| m.row_groups.iter())
@@ -629,31 +664,22 @@ pub(crate) fn read_parquet_into_micropartition(
     let stats = if any_stats_avail {
         let stat_per_table = metadata
             .iter()
-            .flat_map(|fm| {
+            .zip(schemas.iter())
+            .flat_map(|(fm, schema)| {
                 fm.row_groups
                     .iter()
-                    .map(daft_parquet::row_group_metadata_to_table_stats)
+                    .map(|rgm| daft_parquet::row_group_metadata_to_table_stats(rgm, schema))
             })
             .collect::<DaftResult<Vec<TableStatistics>>>()?;
         stat_per_table.into_iter().try_reduce(|a, b| a.union(&b))?
     } else {
         None
     };
 
-    let schemas = metadata
-        .iter()
-        .map(|m| {
-            let schema = infer_schema_with_options(m, &Some((*schema_infer_options).into()))?;
-            let daft_schema = daft_core::schema::Schema::try_from(&schema)?;
-            DaftResult::Ok(daft_schema)
-        })
-        .collect::<DaftResult<Vec<_>>>()?;
-
+    // Union and prune the schema using the specified `columns`
     let unioned_schema = schemas.into_iter().try_reduce(|l, r| l.union(&r))?;
-
-    let daft_schema = unioned_schema.expect("we need at least 1 schema");
-
-    let daft_schema = prune_fields_from_schema(daft_schema, columns)?;
+    let full_daft_schema = unioned_schema.expect("we need at least 1 schema");
+    let pruned_daft_schema = prune_fields_from_schema(full_daft_schema, columns)?;
 
     // Get total number of rows, accounting for selected `row_groups` and the indicated `num_rows`
     let total_rows_no_limit = match &row_groups {
@@ -677,7 +703,7 @@ pub(crate) fn read_parquet_into_micropartition(
     if let Some(stats) = stats {
         let owned_urls = uris.iter().map(|s| s.to_string()).collect::<Vec<_>>();
 
-        let daft_schema = Arc::new(daft_schema);
+        let daft_schema = Arc::new(pruned_daft_schema);
         let size_bytes = metadata
             .iter()
             .map(|m| -> u64 {
@@ -750,10 +776,10 @@ pub(crate) fn read_parquet_into_micropartition(
     )?;
     let all_tables = all_tables
         .into_iter()
-        .map(|t| t.cast_to_schema(&daft_schema))
+        .map(|t| t.cast_to_schema(&pruned_daft_schema))
        .collect::<DaftResult<Vec<_>>>()?;
     Ok(MicroPartition::new_loaded(
-        Arc::new(daft_schema),
+        Arc::new(pruned_daft_schema),
         all_tables.into(),
         None,
     ))
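
As an aside, the subset invariant asserted in `new_unloaded` above reduces to a
standard `HashSet::is_subset` check over field names. Below is a self-contained
sketch with schemas simplified to plain string slices; `assert_schema_subset` is
a hypothetical helper for illustration, not part of the commit:

use std::collections::HashSet;

// Panics unless every MicroPartition column name also appears in the ScanTask schema.
fn assert_schema_subset(partition_fields: &[&str], scan_task_fields: &[&str]) {
    let partition: HashSet<_> = partition_fields.iter().collect();
    let scan_task: HashSet<_> = scan_task_fields.iter().collect();
    assert!(
        partition.is_subset(&scan_task),
        "Unloaded MicroPartition's schema names must be a subset of its ScanTask's schema"
    );
}

fn main() {
    // Passes: every partition column exists in the ScanTask's schema.
    assert_schema_subset(&["a", "b"], &["a", "b", "c"]);
    // Would panic: "d" is not in the ScanTask's schema.
    // assert_schema_subset(&["a", "d"], &["a", "b", "c"]);
    println!("subset check passed");
}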
93 changes: 53 additions & 40 deletions src/daft-parquet/src/statistics/column_range.rs
@@ -4,7 +4,7 @@ use daft_core::{
         logical::{DateArray, Decimal128Array, TimestampArray},
         BinaryArray, BooleanArray, Int128Array, Int32Array, Int64Array, Utf8Array,
     },
-    IntoSeries, Series,
+    DataType, IntoSeries, Series,
 };
 use daft_stats::ColumnRangeStatistics;
 use parquet2::{
@@ -15,7 +15,7 @@ use parquet2::{
 };
 use snafu::{OptionExt, ResultExt};
 
-use super::{MissingParquetColumnStatisticsSnafu, Wrap};
+use super::{DaftStatsSnafu, MissingParquetColumnStatisticsSnafu, Wrap};
 
 use super::utils::*;
 use super::UnableToParseUtf8FromBinarySnafu;
@@ -392,43 +392,56 @@ fn convert_int96_column_range_statistics(
     Ok(ColumnRangeStatistics::Missing)
 }
 
-impl TryFrom<&dyn Statistics> for Wrap<ColumnRangeStatistics> {
-    type Error = super::Error;
-
-    fn try_from(value: &dyn Statistics) -> Result<Self, Self::Error> {
-        let ptype = value.physical_type();
-        let stats = value.as_any();
-        match ptype {
-            PhysicalType::Boolean => stats
-                .downcast_ref::<BooleanStatistics>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::Int32 => stats
-                .downcast_ref::<PrimitiveStatistics<i32>>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::Int64 => stats
-                .downcast_ref::<PrimitiveStatistics<i64>>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::Int96 => Ok(Wrap(convert_int96_column_range_statistics(
-                stats
-                    .downcast_ref::<PrimitiveStatistics<[u32; 3]>>()
-                    .unwrap(),
-            )?)),
-            PhysicalType::Float => stats
-                .downcast_ref::<PrimitiveStatistics<f32>>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::Double => stats
-                .downcast_ref::<PrimitiveStatistics<f64>>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::ByteArray => stats.downcast_ref::<BinaryStatistics>().unwrap().try_into(),
-            PhysicalType::FixedLenByteArray(_) => stats
-                .downcast_ref::<FixedLenStatistics>()
-                .unwrap()
-                .try_into(),
-        }
-    }
+pub(crate) fn parquet_statistics_to_column_range_statistics(
+    pq_stats: &dyn Statistics,
+    daft_dtype: &DataType,
+) -> Result<ColumnRangeStatistics, super::Error> {
+    // Create ColumnRangeStatistics containing Series objects that are the **physical** types parsed from Parquet
+    let ptype = pq_stats.physical_type();
+    let stats = pq_stats.as_any();
+    let daft_stats = match ptype {
+        PhysicalType::Boolean => stats
+            .downcast_ref::<BooleanStatistics>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::Int32 => stats
+            .downcast_ref::<PrimitiveStatistics<i32>>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::Int64 => stats
+            .downcast_ref::<PrimitiveStatistics<i64>>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::Int96 => Ok(convert_int96_column_range_statistics(
+            stats
+                .downcast_ref::<PrimitiveStatistics<[u32; 3]>>()
+                .unwrap(),
+        )?),
+        PhysicalType::Float => stats
+            .downcast_ref::<PrimitiveStatistics<f32>>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::Double => stats
+            .downcast_ref::<PrimitiveStatistics<f64>>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::ByteArray => stats
+            .downcast_ref::<BinaryStatistics>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::FixedLenByteArray(_) => stats
+            .downcast_ref::<FixedLenStatistics>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+    };
+
+    // Cast to ensure that the ColumnRangeStatistics now contain the targeted Daft **logical** type
+    daft_stats.and_then(|s| s.cast(daft_dtype).context(DaftStatsSnafu))
 }
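
The dispatch on `physical_type()` above relies on the usual `Any` downcast idiom
to recover the concrete statistics type behind the `&dyn Statistics` trait
object. Here is a minimal standalone sketch of that pattern; the `Statistics`
trait and `BooleanStatistics` struct are simplified stand-ins, not the parquet2
API:

use std::any::Any;

// Stand-in trait object, loosely mirroring parquet2's `Statistics`.
trait Statistics {
    fn as_any(&self) -> &dyn Any;
}

struct BooleanStatistics {
    min: Option<bool>,
    max: Option<bool>,
}

impl Statistics for BooleanStatistics {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

fn describe(stats: &dyn Statistics) -> String {
    // Downcast the trait object back to its concrete type.
    match stats.as_any().downcast_ref::<BooleanStatistics>() {
        Some(b) => format!("bool stats: min={:?} max={:?}", b.min, b.max),
        None => "unsupported statistics type".to_string(),
    }
}

fn main() {
    let stats = BooleanStatistics { min: Some(false), max: Some(true) };
    println!("{}", describe(&stats));
}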
66 changes: 40 additions & 26 deletions src/daft-parquet/src/statistics/table_stats.rs
@@ -1,37 +1,51 @@
 use common_error::DaftResult;
+use daft_core::schema::Schema;
 use daft_stats::{ColumnRangeStatistics, TableStatistics};
 use snafu::ResultExt;
 
-use super::Wrap;
+use super::column_range::parquet_statistics_to_column_range_statistics;
 
 use indexmap::IndexMap;
 
-impl TryFrom<&crate::metadata::RowGroupMetaData> for Wrap<TableStatistics> {
-    type Error = super::Error;
-    fn try_from(value: &crate::metadata::RowGroupMetaData) -> Result<Self, Self::Error> {
-        let _num_rows = value.num_rows();
-        let mut columns = IndexMap::new();
-        for col in value.columns() {
-            let stats = col
-                .statistics()
-                .transpose()
-                .context(super::UnableToParseParquetColumnStatisticsSnafu)?;
-            let col_stats: Option<Wrap<ColumnRangeStatistics>> =
-                stats.and_then(|v| v.as_ref().try_into().ok());
-            let col_stats = col_stats.unwrap_or(ColumnRangeStatistics::Missing.into());
-            columns.insert(
-                col.descriptor().path_in_schema.get(0).unwrap().clone(),
-                col_stats.0,
-            );
-        }
-
-        Ok(TableStatistics { columns }.into())
-    }
-}
-
 pub fn row_group_metadata_to_table_stats(
     metadata: &crate::metadata::RowGroupMetaData,
+    schema: &Schema,
 ) -> DaftResult<TableStatistics> {
-    let result = Wrap::<TableStatistics>::try_from(metadata)?;
-    Ok(result.0)
+    // Create a map from {field_name: statistics} from the RowGroupMetaData for easy access
+    let mut parquet_column_metadata: IndexMap<_, _> = metadata
+        .columns()
+        .iter()
+        .map(|col| {
+            let top_level_column_name = col
+                .descriptor()
+                .path_in_schema
+                .first()
+                .expect("Parquet schema should have at least one entry in path_in_schema");
+            (top_level_column_name, col.statistics())
+        })
+        .collect();
+
+    // Iterate through the schema and construct ColumnRangeStatistics per field
+    let columns = schema
+        .fields
+        .iter()
+        .map(|(field_name, field)| {
+            if ColumnRangeStatistics::supports_dtype(&field.dtype) {
+                let stats: ColumnRangeStatistics = parquet_column_metadata
+                    .remove(field_name)
+                    .expect("Cannot find parsed Daft field in Parquet rowgroup metadata")
+                    .transpose()
+                    .context(super::UnableToParseParquetColumnStatisticsSnafu)?
+                    .and_then(|v| {
+                        parquet_statistics_to_column_range_statistics(v.as_ref(), &field.dtype).ok()
+                    })
+                    .unwrap_or(ColumnRangeStatistics::Missing);
+                Ok((field_name.clone(), stats))
+            } else {
+                Ok((field_name.clone(), ColumnRangeStatistics::Missing))
+            }
+        })
+        .collect::<DaftResult<IndexMap<_, _>>>()?;
+
+    Ok(TableStatistics { columns })
 }
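
The rewritten function walks the Daft schema rather than the Parquet column
list, so every schema field ends up with an entry, defaulting to
`ColumnRangeStatistics::Missing` for unsupported dtypes or unparsable stats. A
simplified sketch of that lookup-with-fallback shape follows; it uses a plain
`HashMap` and string payloads as stand-ins (the committed code uses `IndexMap`,
real statistics types, and asserts that every schema field appears in the
Parquet metadata rather than silently falling back):

use std::collections::HashMap;

#[derive(Debug)]
enum ColumnRangeStats {
    Loaded(String), // stand-in payload for parsed min/max statistics
    Missing,
}

fn build_table_stats(
    schema_fields: &[(&str, bool)], // (field name, dtype supported?)
    mut parquet_stats: HashMap<&str, Option<String>>,
) -> HashMap<String, ColumnRangeStats> {
    schema_fields
        .iter()
        .map(|(name, supported)| {
            let stats = if *supported {
                // Look up this column's Parquet stats; fall back to Missing if absent.
                parquet_stats
                    .remove(name)
                    .flatten()
                    .map(ColumnRangeStats::Loaded)
                    .unwrap_or(ColumnRangeStats::Missing)
            } else {
                // Unsupported dtypes never get range statistics.
                ColumnRangeStats::Missing
            };
            (name.to_string(), stats)
        })
        .collect()
}

fn main() {
    let parquet_stats = HashMap::from([
        ("a", Some("min=0,max=9".to_string())),
        ("b", None), // stats entry present in metadata but unparsable
    ]);
    let stats = build_table_stats(&[("a", true), ("b", true), ("c", false)], parquet_stats);
    println!("{:?}", stats);
}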