diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs
index b53af50908..73a88ffeba 100644
--- a/src/daft-micropartition/src/micropartition.rs
+++ b/src/daft-micropartition/src/micropartition.rs
@@ -275,38 +275,63 @@ fn materialize_scan_task(
 }
 
 impl MicroPartition {
+    /// Create a new "unloaded" MicroPartition using an associated [`ScanTask`]
+    ///
+    /// Schema invariants:
+    /// 1. All columns in `schema` must exist in the `scan_task` schema
+    /// 2. Each `Loaded` column statistic in `statistics` must be castable to the corresponding column in the MicroPartition's schema
    pub fn new_unloaded(
        schema: SchemaRef,
        scan_task: Arc<ScanTask>,
        metadata: TableMetadata,
        statistics: TableStatistics,
    ) -> Self {
-        if statistics.columns.len() != schema.fields.len() {
-            panic!("MicroPartition: TableStatistics and Schema have differing lengths")
-        }
-        if !statistics
-            .columns
-            .keys()
-            .zip(schema.fields.keys())
-            .all(|(l, r)| l == r)
-        {
-            panic!("MicroPartition: TableStatistics and Schema have different column names\nTableStats:\n{},\nSchema\n{}", statistics, schema);
-        }
+        assert!(
+            schema
+                .fields
+                .keys()
+                .collect::<HashSet<_>>()
+                .is_subset(&scan_task.schema.fields.keys().collect::<HashSet<_>>()),
+            "Unloaded MicroPartition's schema names must be a subset of its ScanTask's schema"
+        );
 
        MicroPartition {
-            schema,
+            schema: schema.clone(),
            state: Mutex::new(TableState::Unloaded(scan_task)),
            metadata,
-            statistics: Some(statistics),
+            statistics: Some(
+                statistics
+                    .cast_to_schema(schema)
+                    .expect("Statistics cannot be cast to schema"),
+            ),
        }
    }
 
+    /// Create a new "loaded" MicroPartition using the materialized tables
+    ///
+    /// Schema invariants:
+    /// 1. `schema` must match each Table's schema exactly
+    /// 2. If `statistics` is provided, each `Loaded` column statistic must be castable to the corresponding column in the MicroPartition's schema
    pub fn new_loaded(
        schema: SchemaRef,
        tables: Arc<Vec<Table>>,
        statistics: Option<TableStatistics>,
    ) -> Self {
+        // Check and validate invariants with asserts
+        for table in tables.iter() {
+            assert!(
+                table.schema == schema,
+                "Loaded MicroPartition's tables' schema must match its own schema exactly"
+            );
+        }
+
+        let statistics = statistics.map(|stats| {
+            stats
+                .cast_to_schema(schema.clone())
+                .expect("Statistics cannot be cast to schema")
+        });
        let tables_len_sum = tables.iter().map(|t| t.len()).sum();
+
        MicroPartition {
            schema,
            state: Mutex::new(TableState::Loaded(tables)),
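
The new constructor invariants above are easiest to see in isolation. Below is a minimal, self-contained sketch of the subset check that `new_unloaded` now performs instead of the old exact-match comparison; `assert_schema_subset` and its `&str` column names are hypothetical stand-ins for Daft's `SchemaRef`, not the library's API:

```rust
use std::collections::HashSet;

// Hypothetical stand-in for the schema-subset assertion in `new_unloaded`:
// every column in the MicroPartition's schema must also exist in the
// ScanTask's schema, but the ScanTask may carry extra columns that were
// pruned away by a projection.
fn assert_schema_subset(mp_columns: &[&str], scan_task_columns: &[&str]) {
    let mp: HashSet<_> = mp_columns.iter().collect();
    let st: HashSet<_> = scan_task_columns.iter().collect();
    assert!(
        mp.is_subset(&st),
        "Unloaded MicroPartition's schema names must be a subset of its ScanTask's schema"
    );
}

fn main() {
    assert_schema_subset(&["a", "c"], &["a", "b", "c"]); // pruned projection: ok
    // assert_schema_subset(&["d"], &["a", "b", "c"]);   // would panic: "d" unknown
}
```

Where the old code required `TableStatistics` and the schema to match column-for-column, the subset form permits a pruned MicroPartition schema, and `cast_to_schema` (added in `daft-stats` below) fills in `Missing` statistics for any column the stats do not cover.
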
@@ -356,6 +381,7 @@ impl MicroPartition {
            .map(|cols| cols.iter().map(|s| s.as_str()).collect::<Vec<&str>>());
        let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice());
+
        read_parquet_into_micropartition(
            uris.as_slice(),
            columns.as_deref(),
@@ -612,15 +638,24 @@ pub(crate) fn read_parquet_into_micropartition(
        return Err(common_error::DaftError::ValueError("Micropartition Parquet Reader does not support non-zero start offsets".to_string()));
    }
 
+    // Run the required I/O to retrieve all the Parquet FileMetaData
    let runtime_handle = daft_io::get_runtime(multithreaded_io)?;
    let io_client = daft_io::get_io_client(multithreaded_io, io_config.clone())?;
-
    let meta_io_client = io_client.clone();
    let meta_io_stats = io_stats.clone();
-
    let metadata = runtime_handle.block_on(async move { read_parquet_metadata_bulk(uris, meta_io_client, meta_io_stats).await })?;
+
+    // Deserialize and collect relevant TableStatistics
+    let schemas = metadata
+        .iter()
+        .map(|m| {
+            let schema = infer_schema_with_options(m, &Some((*schema_infer_options).into()))?;
+            let daft_schema = daft_core::schema::Schema::try_from(&schema)?;
+            DaftResult::Ok(daft_schema)
+        })
+        .collect::<DaftResult<Vec<_>>>()?;
 
    let any_stats_avail = metadata
        .iter()
        .flat_map(|m| m.row_groups.iter())
@@ -629,10 +664,11 @@ pub(crate) fn read_parquet_into_micropartition(
    let stats = if any_stats_avail {
        let stat_per_table = metadata
            .iter()
-            .flat_map(|fm| {
+            .zip(schemas.iter())
+            .flat_map(|(fm, schema)| {
                fm.row_groups
                    .iter()
-                    .map(daft_parquet::row_group_metadata_to_table_stats)
+                    .map(|rgm| daft_parquet::row_group_metadata_to_table_stats(rgm, schema))
            })
            .collect::<DaftResult<Vec<TableStatistics>>>()?;
        stat_per_table.into_iter().try_reduce(|a, b| a.union(&b))?
@@ -640,20 +676,10 @@ pub(crate) fn read_parquet_into_micropartition(
        None
    };
 
-    let schemas = metadata
-        .iter()
-        .map(|m| {
-            let schema = infer_schema_with_options(m, &Some((*schema_infer_options).into()))?;
-            let daft_schema = daft_core::schema::Schema::try_from(&schema)?;
-            DaftResult::Ok(daft_schema)
-        })
-        .collect::<DaftResult<Vec<_>>>()?;
-
+    // Union and prune the schema using the specified `columns`
    let unioned_schema = schemas.into_iter().try_reduce(|l, r| l.union(&r))?;
-
-    let daft_schema = unioned_schema.expect("we need at least 1 schema");
-
-    let daft_schema = prune_fields_from_schema(daft_schema, columns)?;
+    let full_daft_schema = unioned_schema.expect("we need at least 1 schema");
+    let pruned_daft_schema = prune_fields_from_schema(full_daft_schema, columns)?;
 
    // Get total number of rows, accounting for selected `row_groups` and the indicated `num_rows`
    let total_rows_no_limit = match &row_groups {
@@ -677,7 +703,7 @@ pub(crate) fn read_parquet_into_micropartition(
 
    if let Some(stats) = stats {
        let owned_urls = uris.iter().map(|s| s.to_string()).collect::<Vec<_>>();
-        let daft_schema = Arc::new(daft_schema);
+        let daft_schema = Arc::new(pruned_daft_schema);
        let size_bytes = metadata
            .iter()
            .map(|m| -> u64 {
@@ -750,10 +776,10 @@ pub(crate) fn read_parquet_into_micropartition(
        )?;
        let all_tables = all_tables
            .into_iter()
-            .map(|t| t.cast_to_schema(&daft_schema))
+            .map(|t| t.cast_to_schema(&pruned_daft_schema))
            .collect::<DaftResult<Vec<_>>>()?;
        Ok(MicroPartition::new_loaded(
-            Arc::new(daft_schema),
+            Arc::new(pruned_daft_schema),
            all_tables.into(),
            None,
        ))
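
To make the reordering in `read_parquet_into_micropartition` concrete: per-file schemas are now inferred *before* statistics (so row-group stats can be parsed against an inferred Daft schema), then unioned in order and pruned to the requested projection. The toy below models just the union-and-prune step with plain string column names; `union_and_prune` is a hypothetical stand-in, and the ordering behavior shown is this toy's, not necessarily Daft's:

```rust
// Toy model (plain Vec<&str>, not Daft's Schema) of the union-then-prune step:
// per-file column lists are unioned in first-seen order, then pruned to the
// requested projection.
fn union_and_prune<'a>(per_file: &[Vec<&'a str>], projection: Option<&[&str]>) -> Vec<&'a str> {
    let mut unioned: Vec<&'a str> = Vec::new();
    for cols in per_file {
        for &c in cols {
            if !unioned.contains(&c) {
                unioned.push(c);
            }
        }
    }
    match projection {
        Some(keep) => unioned.into_iter().filter(|c| keep.contains(c)).collect(),
        None => unioned,
    }
}

fn main() {
    let files = vec![vec!["a", "b"], vec!["b", "c"]];
    // In this toy, pruning keeps union order ("a" before "c").
    assert_eq!(union_and_prune(&files, Some(&["c", "a"])), vec!["a", "c"]);
    assert_eq!(union_and_prune(&files, None), vec!["a", "b", "c"]);
}
```
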
diff --git a/src/daft-parquet/src/statistics/column_range.rs b/src/daft-parquet/src/statistics/column_range.rs
index a58daa725c..4b386c9cef 100644
--- a/src/daft-parquet/src/statistics/column_range.rs
+++ b/src/daft-parquet/src/statistics/column_range.rs
@@ -4,7 +4,7 @@ use daft_core::{
        logical::{DateArray, Decimal128Array, TimestampArray},
        BinaryArray, BooleanArray, Int128Array, Int32Array, Int64Array, Utf8Array,
    },
-    IntoSeries, Series,
+    DataType, IntoSeries, Series,
 };
 use daft_stats::ColumnRangeStatistics;
 use parquet2::{
@@ -15,7 +15,7 @@ use parquet2::{
 };
 use snafu::{OptionExt, ResultExt};
 
-use super::{MissingParquetColumnStatisticsSnafu, Wrap};
+use super::{DaftStatsSnafu, MissingParquetColumnStatisticsSnafu, Wrap};
 
 use super::utils::*;
 use super::UnableToParseUtf8FromBinarySnafu;
@@ -392,43 +392,56 @@ fn convert_int96_column_range_statistics(
    Ok(ColumnRangeStatistics::Missing)
 }
 
-impl TryFrom<&dyn Statistics> for Wrap<ColumnRangeStatistics> {
-    type Error = super::Error;
+pub(crate) fn parquet_statistics_to_column_range_statistics(
+    pq_stats: &dyn Statistics,
+    daft_dtype: &DataType,
+) -> Result<ColumnRangeStatistics, super::Error> {
+    // Create ColumnRangeStatistics containing Series objects that are the **physical** types parsed from Parquet
+    let ptype = pq_stats.physical_type();
+    let stats = pq_stats.as_any();
+    let daft_stats = match ptype {
+        PhysicalType::Boolean => stats
+            .downcast_ref::<BooleanStatistics>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::Int32 => stats
+            .downcast_ref::<PrimitiveStatistics<i32>>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::Int64 => stats
+            .downcast_ref::<PrimitiveStatistics<i64>>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::Int96 => Ok(convert_int96_column_range_statistics(
+            stats
+                .downcast_ref::<PrimitiveStatistics<[u32; 3]>>()
+                .unwrap(),
+        )?),
+        PhysicalType::Float => stats
+            .downcast_ref::<PrimitiveStatistics<f32>>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::Double => stats
+            .downcast_ref::<PrimitiveStatistics<f64>>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::ByteArray => stats
+            .downcast_ref::<BinaryStatistics>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+        PhysicalType::FixedLenByteArray(_) => stats
+            .downcast_ref::<FixedLenStatistics>()
+            .unwrap()
+            .try_into()
+            .map(|wrap: Wrap<ColumnRangeStatistics>| wrap.0),
+    };
 
-    fn try_from(value: &dyn Statistics) -> Result<Self, Self::Error> {
-        let ptype = value.physical_type();
-        let stats = value.as_any();
-        match ptype {
-            PhysicalType::Boolean => stats
-                .downcast_ref::<BooleanStatistics>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::Int32 => stats
-                .downcast_ref::<PrimitiveStatistics<i32>>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::Int64 => stats
-                .downcast_ref::<PrimitiveStatistics<i64>>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::Int96 => Ok(Wrap(convert_int96_column_range_statistics(
-                stats
-                    .downcast_ref::<PrimitiveStatistics<[u32; 3]>>()
-                    .unwrap(),
-            )?)),
-            PhysicalType::Float => stats
-                .downcast_ref::<PrimitiveStatistics<f32>>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::Double => stats
-                .downcast_ref::<PrimitiveStatistics<f64>>()
-                .unwrap()
-                .try_into(),
-            PhysicalType::ByteArray => stats.downcast_ref::<BinaryStatistics>().unwrap().try_into(),
-            PhysicalType::FixedLenByteArray(_) => stats
-                .downcast_ref::<FixedLenStatistics>()
-                .unwrap()
-                .try_into(),
-        }
-    }
+    // Cast to ensure that the ColumnRangeStatistics now contain the targeted Daft **logical** type
+    daft_stats.and_then(|s| s.cast(daft_dtype).context(DaftStatsSnafu))
 }
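
The function above replaces the old `TryFrom<&dyn Statistics>` impl and makes the conversion explicitly two-phase: statistics are first materialized at the Parquet *physical* type, then cast to the Daft *logical* dtype. A toy model of that second phase, using an illustrative enum rather than Daft's `ColumnRangeStatistics` (an `i32` min/max reinterpreted as a `Date` range in days since the Unix epoch):

```rust
// Toy model of the physical-to-logical statistics cast. Types here are
// illustrative, not Daft's.
#[derive(Debug, PartialEq)]
enum Stat {
    Missing,
    I32 { min: i32, max: i32 },
    Date { min_days: i32, max_days: i32 },
}

fn cast_i32_to_date(stat: Stat) -> Stat {
    match stat {
        // The cast is a pure reinterpretation of the same integers, so
        // ordering (and hence the min/max bracketing) is preserved.
        Stat::I32 { min, max } => Stat::Date { min_days: min, max_days: max },
        other => other, // Missing stays Missing
    }
}

fn main() {
    let physical = Stat::I32 { min: 19_000, max: 19_100 };
    assert_eq!(
        cast_i32_to_date(physical),
        Stat::Date { min_days: 19_000, max_days: 19_100 }
    );
}
```
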
diff --git a/src/daft-parquet/src/statistics/table_stats.rs b/src/daft-parquet/src/statistics/table_stats.rs
index b138343b63..25200fdfed 100644
--- a/src/daft-parquet/src/statistics/table_stats.rs
+++ b/src/daft-parquet/src/statistics/table_stats.rs
@@ -1,37 +1,51 @@
 use common_error::DaftResult;
+use daft_core::schema::Schema;
 use daft_stats::{ColumnRangeStatistics, TableStatistics};
 use snafu::ResultExt;
 
-use super::Wrap;
+use super::column_range::parquet_statistics_to_column_range_statistics;
 
 use indexmap::IndexMap;
 
-impl TryFrom<&crate::metadata::RowGroupMetaData> for Wrap<TableStatistics> {
-    type Error = super::Error;
-    fn try_from(value: &crate::metadata::RowGroupMetaData) -> Result<Self, Self::Error> {
-        let _num_rows = value.num_rows();
-        let mut columns = IndexMap::new();
-        for col in value.columns() {
-            let stats = col
-                .statistics()
-                .transpose()
-                .context(super::UnableToParseParquetColumnStatisticsSnafu)?;
-            let col_stats: Option<Wrap<ColumnRangeStatistics>> =
-                stats.and_then(|v| v.as_ref().try_into().ok());
-            let col_stats = col_stats.unwrap_or(ColumnRangeStatistics::Missing.into());
-            columns.insert(
-                col.descriptor().path_in_schema.get(0).unwrap().clone(),
-                col_stats.0,
-            );
-        }
-
-        Ok(TableStatistics { columns }.into())
-    }
-}
-
 pub fn row_group_metadata_to_table_stats(
    metadata: &crate::metadata::RowGroupMetaData,
+    schema: &Schema,
 ) -> DaftResult<TableStatistics> {
-    let result = Wrap::<TableStatistics>::try_from(metadata)?;
-    Ok(result.0)
+    // Create a map from {field_name: statistics} from the RowGroupMetaData for easy access
+    let mut parquet_column_metadata: IndexMap<_, _> = metadata
+        .columns()
+        .iter()
+        .map(|col| {
+            let top_level_column_name = col
+                .descriptor()
+                .path_in_schema
+                .first()
+                .expect("Parquet schema should have at least one entry in path_in_schema");
+            (top_level_column_name, col.statistics())
+        })
+        .collect();
+
+    // Iterate through the schema and construct ColumnRangeStatistics per field
+    let columns = schema
+        .fields
+        .iter()
+        .map(|(field_name, field)| {
+            if ColumnRangeStatistics::supports_dtype(&field.dtype) {
+                let stats: ColumnRangeStatistics = parquet_column_metadata
+                    .remove(field_name)
+                    .expect("Cannot find parsed Daft field in Parquet rowgroup metadata")
+                    .transpose()
+                    .context(super::UnableToParseParquetColumnStatisticsSnafu)?
+                    .and_then(|v| {
+                        parquet_statistics_to_column_range_statistics(v.as_ref(), &field.dtype).ok()
+                    })
+                    .unwrap_or(ColumnRangeStatistics::Missing);
+                Ok((field_name.clone(), stats))
+            } else {
+                Ok((field_name.clone(), ColumnRangeStatistics::Missing))
+            }
+        })
+        .collect::<DaftResult<IndexMap<_, _>>>()?;
+
+    Ok(TableStatistics { columns })
 }
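
The per-field rule in `row_group_metadata_to_table_stats` reduces to a small decision table: a field gets `Loaded` statistics only if its dtype supports range statistics *and* the row group carries parseable statistics for that column; everything else degrades to `Missing`. A self-contained sketch with a hypothetical `resolve` helper over plain `std` types:

```rust
use std::collections::HashMap;

// Sketch of the per-field rule above (hypothetical `resolve`, not Daft's API):
// Loaded stats require BOTH a supported dtype AND parseable row-group stats.
fn resolve(field: &str, dtype_supported: bool, rg_stats: &HashMap<&str, (i64, i64)>) -> String {
    match (dtype_supported, rg_stats.get(field)) {
        (true, Some((lo, hi))) => format!("{field}: Loaded({lo}..={hi})"),
        _ => format!("{field}: Missing"),
    }
}

fn main() {
    let mut rg = HashMap::new();
    rg.insert("ints", (1_i64, 99));
    assert_eq!(resolve("ints", true, &rg), "ints: Loaded(1..=99)");
    assert_eq!(resolve("lists", false, &rg), "lists: Missing"); // unsupported dtype
    assert_eq!(resolve("absent", true, &rg), "absent: Missing"); // no stats in row group
}
```
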
diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs
index c6eb64bcce..f76496a47c 100644
--- a/src/daft-stats/src/column_stats/mod.rs
+++ b/src/daft-stats/src/column_stats/mod.rs
@@ -7,7 +7,7 @@ use std::string::FromUtf8Error;
 use daft_core::{
    array::ops::full::FullNull,
    datatypes::{BooleanArray, NullArray},
-    IntoSeries, Series,
+    DataType, IntoSeries, Series,
 };
 use snafu::{ResultExt, Snafu};
 
@@ -44,12 +44,38 @@ impl ColumnRangeStatistics {
                assert_eq!(l.len(), 1);
                assert_eq!(u.len(), 1);
                assert_eq!(l.data_type(), u.data_type());
+                assert!(ColumnRangeStatistics::supports_dtype(l.data_type()));
                Ok(ColumnRangeStatistics::Loaded(l, u))
            }
            _ => Ok(ColumnRangeStatistics::Missing),
        }
    }
 
+    pub fn supports_dtype(dtype: &DataType) -> bool {
+        match dtype {
+            // SUPPORTED TYPES:
+            // Null
+            DataType::Null |
+
+            // Numeric types
+            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Int128 |
+            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 |
+            DataType::Float32 | DataType::Float64 | DataType::Decimal128(..) | DataType::Boolean |
+
+            // String types
+            DataType::Utf8 | DataType::Binary |
+
+            // Temporal types
+            DataType::Date | DataType::Time(..) | DataType::Timestamp(..) | DataType::Duration(..) => true,
+
+            // UNSUPPORTED TYPES:
+            // Types that don't support comparisons and can't be used as ColumnRangeStatistics
+            DataType::List(..) | DataType::FixedSizeList(..) | DataType::Image(..) | DataType::FixedShapeImage(..) | DataType::Tensor(..) | DataType::FixedShapeTensor(..) | DataType::Struct(..) | DataType::Extension(..) | DataType::Embedding(..) | DataType::Unknown => false,
+            #[cfg(feature = "python")]
+            DataType::Python => false,
+        }
+    }
+
    pub fn to_truth_value(&self) -> TruthValue {
        match self {
            Self::Missing => TruthValue::Maybe,
@@ -116,6 +142,59 @@ impl ColumnRangeStatistics {
        let _num_bytes = series.size_bytes().unwrap();
        Self::Loaded(lower, upper)
    }
+
+    /// Casts the internal [`Series`] objects to the specified DataType
+    pub fn cast(&self, dtype: &DataType) -> crate::Result<Self> {
+        match self {
+            // `Missing` is cast to `Missing`
+            ColumnRangeStatistics::Missing => Ok(ColumnRangeStatistics::Missing),
+
+            // If the type to cast to matches the current type exactly, short-circuit the logic here. This should be the
+            // most common case (e.g. parsing a Parquet file with the same types as the inferred types)
+            ColumnRangeStatistics::Loaded(l, r) if l.data_type() == dtype => {
+                Ok(ColumnRangeStatistics::Loaded(l.clone(), r.clone()))
+            }
+
+            // Only certain types are allowed to be cast in the context of ColumnRangeStatistics,
+            // as casting may not correctly preserve the ordering of elements. We allow-list some type combinations,
+            // but for most combinations we default to `ColumnRangeStatistics::Missing`.
+            ColumnRangeStatistics::Loaded(l, r) => {
+                match (l.data_type(), dtype) {
+                    // Int casting to higher bitwidths
+                    (DataType::Int8, DataType::Int16) |
+                    (DataType::Int8, DataType::Int32) |
+                    (DataType::Int8, DataType::Int64) |
+                    (DataType::Int16, DataType::Int32) |
+                    (DataType::Int16, DataType::Int64) |
+                    (DataType::Int32, DataType::Int64) |
+                    // UInt casting to higher bitwidths
+                    (DataType::UInt8, DataType::UInt16) |
+                    (DataType::UInt8, DataType::UInt32) |
+                    (DataType::UInt8, DataType::UInt64) |
+                    (DataType::UInt16, DataType::UInt32) |
+                    (DataType::UInt16, DataType::UInt64) |
+                    (DataType::UInt32, DataType::UInt64) |
+                    // Float casting to higher bitwidths
+                    (DataType::Float32, DataType::Float64) |
+                    // Numeric to temporal casting from smaller-than-eq bitwidths
+                    (DataType::Int8, DataType::Date) |
+                    (DataType::Int16, DataType::Date) |
+                    (DataType::Int32, DataType::Date) |
+                    (DataType::Int8, DataType::Timestamp(..)) |
+                    (DataType::Int16, DataType::Timestamp(..)) |
+                    (DataType::Int32, DataType::Timestamp(..)) |
+                    (DataType::Int64, DataType::Timestamp(..)) |
+                    // Binary to Utf8
+                    (DataType::Binary, DataType::Utf8)
+                    => Ok(ColumnRangeStatistics::Loaded(
+                        l.cast(dtype).context(DaftCoreComputeSnafu)?,
+                        r.cast(dtype).context(DaftCoreComputeSnafu)?,
+                    )),
+                    _ => Ok(ColumnRangeStatistics::Missing)
+                }
+            }
+        }
+    }
 }
 
 impl std::fmt::Display for ColumnRangeStatistics {
diff --git a/src/daft-stats/src/table_stats.rs b/src/daft-stats/src/table_stats.rs
index db3acbec4f..35154ddf72 100644
--- a/src/daft-stats/src/table_stats.rs
+++ b/src/daft-stats/src/table_stats.rs
@@ -6,7 +6,10 @@ use indexmap::{IndexMap, IndexSet};
 
 use crate::column_stats::ColumnRangeStatistics;
 
-use daft_core::{array::ops::DaftCompare, schema::Schema};
+use daft_core::{
+    array::ops::DaftCompare,
+    schema::{Schema, SchemaRef},
+};
 
 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
 pub struct TableStatistics {
@@ -104,6 +107,20 @@ impl TableStatistics {
            _ => Ok(ColumnRangeStatistics::Missing),
        }
    }
+
+    pub fn cast_to_schema(&self, schema: SchemaRef) -> crate::Result<Self> {
+        let mut columns = IndexMap::new();
+        for (field_name, field) in schema.fields.iter() {
+            let crs = match self.columns.get(field_name) {
+                Some(column_stat) => column_stat
+                    .cast(&field.dtype)
+                    .unwrap_or(ColumnRangeStatistics::Missing),
+                None => ColumnRangeStatistics::Missing,
+            };
+            columns.insert(field_name.clone(), crs);
+        }
+        Ok(TableStatistics { columns })
+    }
 }
 
 impl Display for TableStatistics {
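
Taken together, `supports_dtype` and the `cast` allow-list ensure a range statistic is only ever kept when the conversion provably preserves ordering (integer/float widenings, small-int-to-temporal, `Binary` to `Utf8`); anything else falls back to `Missing`, so a mis-cast range can never be used to prune live data. A toy encoding of that rule, with a hypothetical `DType` in place of Daft's `DataType`:

```rust
// Toy encoding of the cast allow-list (hypothetical DType, not Daft's):
// only order-preserving widenings keep a Loaded range; everything else
// must degrade to Missing.
#[derive(Clone, Copy, Debug, PartialEq)]
enum DType { Int8, Int16, Int32, Int64, Binary, Utf8 }

fn cast_keeps_range(from: DType, to: DType) -> bool {
    use DType::*;
    from == to
        || matches!(
            (from, to),
            (Int8, Int16) | (Int8, Int32) | (Int8, Int64)
                | (Int16, Int32) | (Int16, Int64)
                | (Int32, Int64)
                | (Binary, Utf8)
        )
}

fn main() {
    assert!(cast_keeps_range(DType::Int8, DType::Int64));   // widening: range survives
    assert!(!cast_keeps_range(DType::Int64, DType::Int8));  // narrowing: -> Missing
    assert!(!cast_keeps_range(DType::Utf8, DType::Binary)); // not allow-listed: -> Missing
}
```
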