diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs
index 3917347888..83db86e8e2 100644
--- a/crates/deltalake-core/src/delta_datafusion/mod.rs
+++ b/crates/deltalake-core/src/delta_datafusion/mod.rs
@@ -32,7 +32,9 @@ use arrow::datatypes::{DataType as ArrowDataType, Schema as ArrowSchema, SchemaR
 use arrow::error::ArrowError;
 use arrow::record_batch::RecordBatch;
 use arrow_array::types::UInt16Type;
-use arrow_array::{DictionaryArray, StringArray};
+use arrow_array::{Array, DictionaryArray, StringArray};
+use arrow_cast::display::array_value_to_string;
+
 use arrow_schema::Field;
 use async_trait::async_trait;
 use chrono::{NaiveDateTime, TimeZone, Utc};
@@ -66,17 +68,20 @@ use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
 use datafusion_proto::logical_plan::LogicalExtensionCodec;
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use datafusion_sql::planner::ParserOptions;
+
+use itertools::Itertools;
 use log::error;
 use object_store::ObjectMeta;
 use serde::{Deserialize, Serialize};
 use url::Url;
 
 use crate::errors::{DeltaResult, DeltaTableError};
-use crate::kernel::{Add, DataType as DeltaDataType, Invariant, PrimitiveType};
+use crate::kernel::{Add, DataCheck, DataType as DeltaDataType, Invariant, PrimitiveType};
 use crate::logstore::LogStoreRef;
 use crate::protocol::{ColumnCountStat, ColumnValueStat};
 use crate::table::builder::ensure_table_uri;
 use crate::table::state::DeltaTableState;
+use crate::table::Constraint;
 use crate::{open_table, open_table_with_storage_options, DeltaTable};
 
 const PATH_COLUMN: &str = "__delta_rs_path";
@@ -1017,15 +1022,43 @@ pub(crate) fn logical_expr_to_physical_expr(
 /// Responsible for checking batches of data conform to table's invariants.
 #[derive(Clone)]
 pub struct DeltaDataChecker {
+    constraints: Vec<Constraint>,
     invariants: Vec<Invariant>,
     ctx: SessionContext,
 }
 
 impl DeltaDataChecker {
+    /// Create a new DeltaDataChecker with a specified set of invariants
+    pub fn new_with_invariants(invariants: Vec<Invariant>) -> Self {
+        Self {
+            invariants,
+            constraints: vec![],
+            ctx: SessionContext::new(),
+        }
+    }
+
+    /// Create a new DeltaDataChecker with a specified set of constraints
+    pub fn new_with_constraints(constraints: Vec<Constraint>) -> Self {
+        Self {
+            constraints,
+            invariants: vec![],
+            ctx: SessionContext::new(),
+        }
+    }
+
     /// Create a new DeltaDataChecker
-    pub fn new(invariants: Vec<Invariant>) -> Self {
+    pub fn new(snapshot: &DeltaTableState) -> Self {
+        let metadata = snapshot.metadata();
+
+        let invariants = metadata
+            .and_then(|meta| meta.schema.get_invariants().ok())
+            .unwrap_or_default();
+        let constraints = metadata
+            .map(|meta| meta.get_constraints())
+            .unwrap_or_default();
         Self {
             invariants,
+            constraints,
             ctx: SessionContext::new(),
         }
     }
@@ -1035,45 +1068,54 @@ impl DeltaDataChecker {
     /// If it does not, it will return [DeltaTableError::InvalidData] with a list
     /// of values that violated each invariant.
     pub async fn check_batch(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> {
-        self.enforce_invariants(record_batch).await
-        // TODO: for support for Protocol V3, check constraints
+        self.enforce_checks(record_batch, &self.invariants).await?;
+        self.enforce_checks(record_batch, &self.constraints).await
     }
 
-    async fn enforce_invariants(&self, record_batch: &RecordBatch) -> Result<(), DeltaTableError> {
-        // Invariants are deprecated, so let's not pay the overhead for any of this
-        // if we can avoid it.
-        if self.invariants.is_empty() {
+    async fn enforce_checks<C: DataCheck>(
+        &self,
+        record_batch: &RecordBatch,
+        checks: &[C],
+    ) -> Result<(), DeltaTableError> {
+        if checks.is_empty() {
             return Ok(());
         }
-
         let table = MemTable::try_new(record_batch.schema(), vec![vec![record_batch.clone()]])?;
         self.ctx.register_table("data", Arc::new(table))?;
 
         let mut violations: Vec<String> = Vec::new();
 
-        for invariant in self.invariants.iter() {
-            if invariant.field_name.contains('.') {
+        for check in checks {
+            if check.get_name().contains('.') {
                 return Err(DeltaTableError::Generic(
-                    "Support for column invariants on nested columns is not supported.".to_string(),
+                    "Checks on nested columns are not supported.".to_string(),
                 ));
             }
 
             let sql = format!(
-                "SELECT {} FROM data WHERE not ({}) LIMIT 1",
-                invariant.field_name, invariant.invariant_sql
+                "SELECT {} FROM data WHERE NOT ({}) LIMIT 1",
+                check.get_name(),
+                check.get_expression()
             );
 
             let dfs: Vec<RecordBatch> = self.ctx.sql(&sql).await?.collect().await?;
             if !dfs.is_empty() && dfs[0].num_rows() > 0 {
-                let value = format!("{:?}", dfs[0].column(0));
+                let value: String = dfs[0]
+                    .columns()
+                    .iter()
+                    .map(|c| array_value_to_string(c, 0).unwrap_or(String::from("null")))
+                    .join(", ");
+
                 let msg = format!(
-                    "Invariant ({}) violated by value {}",
-                    invariant.invariant_sql, value
+                    "Check or Invariant ({}) violated by value in row: [{}]",
+                    check.get_expression(),
+                    value
                 );
                 violations.push(msg);
             }
         }
 
+        self.ctx.deregister_table("data")?;
         if !violations.is_empty() {
             Err(DeltaTableError::InvalidData { violations })
         } else {
@@ -1747,7 +1789,7 @@ mod tests {
             .unwrap();
         // Empty invariants is okay
         let invariants: Vec<Invariant> = vec![];
-        assert!(DeltaDataChecker::new(invariants)
+        assert!(DeltaDataChecker::new_with_invariants(invariants)
             .check_batch(&batch)
             .await
             .is_ok());
@@ -1757,7 +1799,7 @@ mod tests {
             Invariant::new("a", "a is not null"),
             Invariant::new("b", "b < 1000"),
         ];
-        assert!(DeltaDataChecker::new(invariants)
+        assert!(DeltaDataChecker::new_with_invariants(invariants)
             .check_batch(&batch)
             .await
             .is_ok());
@@ -1767,7 +1809,9 @@ mod tests {
             Invariant::new("a", "a is null"),
             Invariant::new("b", "b < 100"),
         ];
-        let result = DeltaDataChecker::new(invariants).check_batch(&batch).await;
+        let result = DeltaDataChecker::new_with_invariants(invariants)
+            .check_batch(&batch)
+            .await;
         assert!(result.is_err());
         assert!(matches!(result, Err(DeltaTableError::InvalidData { .. })));
         if let Err(DeltaTableError::InvalidData { violations }) = result {
@@ -1776,7 +1820,9 @@ mod tests {
 
         // Irrelevant invariants return a different error
        let invariants = vec![Invariant::new("c", "c > 2000")];
-        let result = DeltaDataChecker::new(invariants).check_batch(&batch).await;
+        let result = DeltaDataChecker::new_with_invariants(invariants)
+            .check_batch(&batch)
+            .await;
         assert!(result.is_err());
 
         // Nested invariants are unsupported
@@ -1790,7 +1836,9 @@ mod tests {
         let batch = RecordBatch::try_new(schema, vec![inner]).unwrap();
 
         let invariants = vec![Invariant::new("x.b", "x.b < 1000")];
-        let result = DeltaDataChecker::new(invariants).check_batch(&batch).await;
+        let result = DeltaDataChecker::new_with_invariants(invariants)
+            .check_batch(&batch)
+            .await;
         assert!(result.is_err());
         assert!(matches!(result, Err(DeltaTableError::Generic { .. })));
     }
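
The rewritten `enforce_checks` turns each check into a negated SQL probe against the in-memory `data` table. As a sketch (column name and expressions hypothetical), an invariant on column `b` with expression `b < 1000` produces

    SELECT b FROM data WHERE NOT (b < 1000) LIMIT 1

while a check constraint, which this change registers under the wildcard name `*`, produces

    SELECT * FROM data WHERE NOT (value > 5) LIMIT 1

Any row that comes back is a violation; for constraints, every column of that row is rendered into the error message via `array_value_to_string`.
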
diff --git a/crates/deltalake-core/src/kernel/actions/types.rs b/crates/deltalake-core/src/kernel/actions/types.rs
index f38cbd51b2..f64a5caa08 100644
--- a/crates/deltalake-core/src/kernel/actions/types.rs
+++ b/crates/deltalake-core/src/kernel/actions/types.rs
@@ -130,9 +130,11 @@ pub struct Protocol {
     pub min_writer_version: i32,
     /// A collection of features that a client must implement in order to correctly
     /// read this table (exist only when minReaderVersion is set to 3)
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub reader_features: Option<HashSet<ReaderFeatures>>,
     /// A collection of features that a client must implement in order to correctly
     /// write this table (exist only when minWriterVersion is set to 7)
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub writer_features: Option<HashSet<WriterFeatures>>,
 }
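
The `skip_serializing_if` attributes above change how protocol actions with no table features serialize. A minimal sketch, assuming the camelCase serde rename that `Protocol` already carries:

    let protocol = Protocol {
        min_reader_version: 1,
        min_writer_version: 3,
        reader_features: None,
        writer_features: None,
    };
    // Now prints {"minReaderVersion":1,"minWriterVersion":3};
    // previously readerFeatures/writerFeatures appeared as explicit nulls.
    println!("{}", serde_json::to_string(&protocol).unwrap());
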
diff --git a/crates/deltalake-core/src/kernel/mod.rs b/crates/deltalake-core/src/kernel/mod.rs
index c8d01c138d..9fb9dba6b4 100644
--- a/crates/deltalake-core/src/kernel/mod.rs
+++ b/crates/deltalake-core/src/kernel/mod.rs
@@ -13,3 +13,11 @@ pub use actions::*;
 pub use error::*;
 pub use expressions::*;
 pub use schema::*;
+
+/// A trait for all kernel types that are used as part of data checking
+pub trait DataCheck {
+    /// The name of the specific check
+    fn get_name(&self) -> &str;
+    /// The SQL expression to use for the check
+    fn get_expression(&self) -> &str;
+}
diff --git a/crates/deltalake-core/src/kernel/schema.rs b/crates/deltalake-core/src/kernel/schema.rs
index bc83c05070..08cf991dd5 100644
--- a/crates/deltalake-core/src/kernel/schema.rs
+++ b/crates/deltalake-core/src/kernel/schema.rs
@@ -6,6 +6,7 @@ use std::hash::{Hash, Hasher};
 use std::sync::Arc;
 use std::{collections::HashMap, fmt::Display};
 
+use crate::kernel::DataCheck;
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 
@@ -97,6 +98,16 @@ impl Invariant {
     }
 }
 
+impl DataCheck for Invariant {
+    fn get_name(&self) -> &str {
+        &self.field_name
+    }
+
+    fn get_expression(&self) -> &str {
+        &self.invariant_sql
+    }
+}
+
 /// Represents a struct field defined in the Delta table schema.
 // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#Schema-Serialization-Format
 #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
diff --git a/crates/deltalake-core/src/operations/constraints.rs b/crates/deltalake-core/src/operations/constraints.rs
new file mode 100644
index 0000000000..889e668b1a
--- /dev/null
+++ b/crates/deltalake-core/src/operations/constraints.rs
@@ -0,0 +1,315 @@
+//! Add a check constraint to a table
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use chrono::Utc;
+use datafusion::execution::context::SessionState;
+use datafusion::execution::{SendableRecordBatchStream, TaskContext};
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::prelude::SessionContext;
+use futures::future::BoxFuture;
+use futures::StreamExt;
+use serde_json::json;
+
+use crate::delta_datafusion::{register_store, DeltaDataChecker, DeltaScanBuilder};
+use crate::kernel::{Action, CommitInfo, IsolationLevel, Metadata, Protocol};
+use crate::logstore::LogStoreRef;
+use crate::operations::datafusion_utils::Expression;
+use crate::operations::transaction::commit;
+use crate::protocol::DeltaOperation;
+use crate::table::state::DeltaTableState;
+use crate::table::Constraint;
+use crate::DeltaTable;
+use crate::{DeltaResult, DeltaTableError};
+
+/// Build a constraint to add to a table
+pub struct ConstraintBuilder {
+    snapshot: DeltaTableState,
+    name: Option<String>,
+    expr: Option<Expression>,
+    log_store: LogStoreRef,
+    state: Option<SessionState>,
+}
+
+impl ConstraintBuilder {
+    /// Create a new builder
+    pub fn new(log_store: LogStoreRef, snapshot: DeltaTableState) -> Self {
+        Self {
+            name: None,
+            expr: None,
+            snapshot,
+            log_store,
+            state: None,
+        }
+    }
+
+    /// Specify the constraint to be added
+    pub fn with_constraint<S: Into<String>, E: Into<Expression>>(
+        mut self,
+        name: S,
+        expression: E,
+    ) -> Self {
+        self.name = Some(name.into());
+        self.expr = Some(expression.into());
+        self
+    }
+
+    /// Specify the datafusion session context
+    pub fn with_session_state(mut self, state: SessionState) -> Self {
+        self.state = Some(state);
+        self
+    }
+}
+
+impl std::future::IntoFuture for ConstraintBuilder {
+    type Output = DeltaResult<DeltaTable>;
+
+    type IntoFuture = BoxFuture<'static, Self::Output>;
+
+    fn into_future(self) -> Self::IntoFuture {
+        let mut this = self;
+
+        Box::pin(async move {
+            let name = match this.name {
+                Some(v) => v,
+                None => return Err(DeltaTableError::Generic("No name provided".to_string())),
+            };
+            let expr = match this.expr {
+                Some(Expression::String(s)) => s,
+                Some(Expression::DataFusion(e)) => e.to_string(),
+                None => {
+                    return Err(DeltaTableError::Generic(
+                        "No expression provided".to_string(),
+                    ))
+                }
+            };
+
+            let mut metadata = this
+                .snapshot
+                .metadata()
+                .ok_or(DeltaTableError::NoMetadata)?
+                .clone();
+            let configuration_key = format!("delta.constraints.{}", name);
+
+            if metadata.configuration.contains_key(&configuration_key) {
+                return Err(DeltaTableError::Generic(format!(
+                    "Constraint with name: {} already exists, expr: {}",
+                    name, expr
+                )));
+            }
+
+            let state = this.state.unwrap_or_else(|| {
+                let session = SessionContext::new();
+                register_store(this.log_store.clone(), session.runtime_env());
+                session.state()
+            });
+
+            // Build a checker that validates only the new constraint against the existing data.
+            let checker = DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", &expr)]);
+            let scan = DeltaScanBuilder::new(&this.snapshot, this.log_store.clone(), &state)
+                .build()
+                .await?;
+
+            let plan: Arc<dyn ExecutionPlan> = Arc::new(scan);
+            let mut tasks = vec![];
+            for p in 0..plan.output_partitioning().partition_count() {
+                let inner_plan = plan.clone();
+                let inner_checker = checker.clone();
+                let task_ctx = Arc::new(TaskContext::from(&state));
+                let mut record_stream: SendableRecordBatchStream =
+                    inner_plan.execute(p, task_ctx)?;
+                let handle: tokio::task::JoinHandle<DeltaResult<()>> =
+                    tokio::task::spawn(async move {
+                        while let Some(maybe_batch) = record_stream.next().await {
+                            let batch = maybe_batch?;
+                            inner_checker.check_batch(&batch).await?;
+                        }
+                        Ok(())
+                    });
+                tasks.push(handle);
+            }
+            futures::future::join_all(tasks)
+                .await
+                .into_iter()
+                .collect::<Result<Vec<_>, _>>()
+                .map_err(|err| DeltaTableError::Generic(err.to_string()))?
+                .into_iter()
+                .collect::<Result<Vec<_>, _>>()?;
+
+            // We have validated the table passes its constraints, now to add the constraint to
+            // the table.
+
+            metadata
+                .configuration
+                .insert(format!("delta.constraints.{}", name), Some(expr.clone()));
+
+            let old_protocol = this.snapshot.protocol();
+            let protocol = Protocol {
+                min_reader_version: if old_protocol.min_reader_version > 1 {
+                    old_protocol.min_reader_version
+                } else {
+                    1
+                },
+                min_writer_version: if old_protocol.min_writer_version > 3 {
+                    old_protocol.min_writer_version
+                } else {
+                    3
+                },
+                reader_features: old_protocol.reader_features.clone(),
+                writer_features: old_protocol.writer_features.clone(),
+            };
+
+            let operational_parameters = HashMap::from_iter([
+                ("name".to_string(), json!(&name)),
+                ("expr".to_string(), json!(&expr)),
+            ]);
+
+            let operations = DeltaOperation::AddConstraint {
+                name: name.clone(),
+                expr: expr.clone(),
+            };
+
+            let commit_info = CommitInfo {
+                timestamp: Some(Utc::now().timestamp_millis()),
+                operation: Some(operations.name().to_string()),
+                operation_parameters: Some(operational_parameters),
+                read_version: Some(this.snapshot.version()),
+                isolation_level: Some(IsolationLevel::Serializable),
+                is_blind_append: Some(false),
+                ..Default::default()
+            };
+
+            let actions = vec![
+                Action::CommitInfo(commit_info),
+                Action::Metadata(Metadata::try_from(metadata)?),
+                Action::Protocol(protocol),
+            ];
+
+            let version = commit(
+                this.log_store.as_ref(),
+                &actions,
+                operations,
+                &this.snapshot,
+                None,
+            )
+            .await?;
+
+            this.snapshot
+                .merge(DeltaTableState::from_actions(actions, version)?, true, true);
+            Ok(DeltaTable::new_with_state(this.log_store, this.snapshot))
+        })
+    }
+}
+
+#[cfg(feature = "datafusion")]
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow_array::{Array, Int32Array, RecordBatch, StringArray};
+
+    use crate::writer::test_utils::{create_bare_table, get_arrow_schema, get_record_batch};
+    use crate::{DeltaOps, DeltaResult};
+
+    #[cfg(feature = "datafusion")]
+    #[tokio::test]
+    async fn add_constraint_with_invalid_data() -> DeltaResult<()> {
+        let batch = get_record_batch(None, false);
+        let write = DeltaOps(create_bare_table())
+            .write(vec![batch.clone()])
+            .await?;
+        let table = DeltaOps(write);
+
+        let constraint = table
+            .add_constraint()
+            .with_constraint("id", "value > 5")
+            .await;
+        dbg!(&constraint);
+        assert!(constraint.is_err());
+        Ok(())
+    }
+
+    #[cfg(feature = "datafusion")]
+    #[tokio::test]
+    async fn add_valid_constraint() -> DeltaResult<()> {
+        let batch = get_record_batch(None, false);
+        let write = DeltaOps(create_bare_table())
+            .write(vec![batch.clone()])
+            .await?;
+        let table = DeltaOps(write);
+
+        let constraint = table
+            .add_constraint()
+            .with_constraint("id", "value < 1000")
+            .await;
+        dbg!(&constraint);
+        assert!(constraint.is_ok());
+        let version = constraint?.version();
+        assert_eq!(version, 1);
+        Ok(())
+    }
+
+    #[cfg(feature = "datafusion")]
+    #[tokio::test]
+    async fn add_conflicting_named_constraint() -> DeltaResult<()> {
+        let batch = get_record_batch(None, false);
+        let write = DeltaOps(create_bare_table())
+            .write(vec![batch.clone()])
+            .await?;
+        let table = DeltaOps(write);
+
+        let new_table = table
+            .add_constraint()
+            .with_constraint("id", "value < 60")
+            .await?;
+
+        let new_table = DeltaOps(new_table);
+        let second_constraint = new_table
+            .add_constraint()
+            .with_constraint("id", "value < 10")
+            .await;
+        dbg!(&second_constraint);
+        assert!(second_constraint.is_err());
+        Ok(())
+    }
+
+    #[cfg(feature = "datafusion")]
+    #[tokio::test]
+    async fn write_data_that_violates_constraint() -> DeltaResult<()> {
+        let batch = get_record_batch(None, false);
+        let write = DeltaOps(create_bare_table())
+            .write(vec![batch.clone()])
+            .await?;
+
+        let table = DeltaOps(write)
+            .add_constraint()
+            .with_constraint("id", "value > 0")
+            .await?;
+        let table = DeltaOps(table);
+        let invalid_values: Vec<Arc<dyn Array>> = vec![
+            Arc::new(StringArray::from(vec!["A"])),
+            Arc::new(Int32Array::from(vec![-10])),
+            Arc::new(StringArray::from(vec!["2021-02-02"])),
+        ];
+        let batch = RecordBatch::try_new(get_arrow_schema(&None), invalid_values)?;
+        let err = table.write(vec![batch]).await;
+        dbg!(&err);
+        assert!(err.is_err());
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn write_data_that_does_not_violate_constraint() -> DeltaResult<()> {
+        let batch = get_record_batch(None, false);
+        let write = DeltaOps(create_bare_table())
+            .write(vec![batch.clone()])
+            .await?;
+        let table = DeltaOps(write);
+
+        let err = table.write(vec![batch]).await;
+
+        assert!(err.is_ok());
+        Ok(())
+    }
+}
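
For orientation, a minimal usage sketch of the builder defined above (table path and constraint name are hypothetical):

    use deltalake::{open_table, DeltaOps, DeltaResult};

    async fn add_value_constraint() -> DeltaResult<()> {
        let table = open_table("./data/example_table").await?;
        // The builder scans existing data first and fails if any row
        // violates the expression, so the constraint holds from day one.
        let table = DeltaOps(table)
            .add_constraint()
            .with_constraint("value_non_negative", "value >= 0")
            .await?;
        // The commit stores the expression under the configuration key
        // `delta.constraints.value_non_negative` and raises
        // min_writer_version to at least 3.
        let _ = table;
        Ok(())
    }
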
diff --git a/crates/deltalake-core/src/operations/mod.rs b/crates/deltalake-core/src/operations/mod.rs
index 99a7303691..ee3fb45114 100644
--- a/crates/deltalake-core/src/operations/mod.rs
+++ b/crates/deltalake-core/src/operations/mod.rs
@@ -29,8 +29,8 @@ pub mod vacuum;
 
 #[cfg(feature = "datafusion")]
 use self::{
-    datafusion_utils::Expression, delete::DeleteBuilder, load::LoadBuilder, merge::MergeBuilder,
-    update::UpdateBuilder, write::WriteBuilder,
+    constraints::ConstraintBuilder, datafusion_utils::Expression, delete::DeleteBuilder,
+    load::LoadBuilder, merge::MergeBuilder, update::UpdateBuilder, write::WriteBuilder,
 };
 #[cfg(feature = "datafusion")]
 pub use ::datafusion::physical_plan::common::collect as collect_sendable_stream;
@@ -40,6 +40,8 @@ use arrow::record_batch::RecordBatch;
 use optimize::OptimizeBuilder;
 use restore::RestoreBuilder;
 
+#[cfg(feature = "datafusion")]
+pub mod constraints;
 #[cfg(feature = "datafusion")]
 pub mod delete;
 #[cfg(feature = "datafusion")]
@@ -189,6 +191,13 @@ impl DeltaOps {
     ) -> MergeBuilder {
         MergeBuilder::new(self.0.log_store, self.0.state, predicate.into(), source)
     }
+
+    /// Add a check constraint to a table
+    #[cfg(feature = "datafusion")]
+    #[must_use]
+    pub fn add_constraint(self) -> ConstraintBuilder {
+        ConstraintBuilder::new(self.0.log_store, self.0.state)
+    }
 }
 
 impl From<DeltaTable> for DeltaOps {
diff --git a/crates/deltalake-core/src/operations/transaction/protocol.rs b/crates/deltalake-core/src/operations/transaction/protocol.rs
index 38209fb4aa..9c20755935 100644
--- a/crates/deltalake-core/src/operations/transaction/protocol.rs
+++ b/crates/deltalake-core/src/operations/transaction/protocol.rs
@@ -170,7 +170,7 @@ pub static INSTANCE: Lazy<ProtocolChecker> = Lazy::new(|| {
     let mut writer_features = HashSet::new();
     writer_features.insert(WriterFeatures::AppendOnly);
     writer_features.insert(WriterFeatures::Invariants);
-    // writer_features.insert(WriterFeatures::CheckConstraints);
+    writer_features.insert(WriterFeatures::CheckConstraints);
     // writer_features.insert(WriterFeatures::ChangeDataFeed);
     // writer_features.insert(WriterFeatures::GeneratedColumns);
     // writer_features.insert(WriterFeatures::ColumnMapping);
diff --git a/crates/deltalake-core/src/operations/write.rs b/crates/deltalake-core/src/operations/write.rs
index f2ccddea74..6da3b18ecb 100644
--- a/crates/deltalake-core/src/operations/write.rs
+++ b/crates/deltalake-core/src/operations/write.rs
@@ -306,11 +306,6 @@ pub(crate) async fn write_execution_plan(
     safe_cast: bool,
     overwrite_schema: bool,
 ) -> DeltaResult<Vec<Add>> {
-    let invariants = snapshot
-        .metadata()
-        .and_then(|meta| meta.schema.get_invariants().ok())
-        .unwrap_or_default();
-
     // Use input schema to prevent wrapping partitions columns into a dictionary.
     let schema: ArrowSchemaRef = if overwrite_schema {
         plan.schema()
@@ -318,7 +313,7 @@ pub(crate) async fn write_execution_plan(
         snapshot.input_schema().unwrap_or(plan.schema())
     };
 
-    let checker = DeltaDataChecker::new(invariants);
+    let checker = DeltaDataChecker::new(snapshot);
 
     // Write data to disk
     let mut tasks = vec![];
diff --git a/crates/deltalake-core/src/protocol/mod.rs b/crates/deltalake-core/src/protocol/mod.rs
index 661d75b244..311f6dac7e 100644
--- a/crates/deltalake-core/src/protocol/mod.rs
+++ b/crates/deltalake-core/src/protocol/mod.rs
@@ -410,6 +410,13 @@ pub enum DeltaOperation {
         /// The update predicate
         predicate: Option<String>,
     },
+    /// Add constraints to a table
+    AddConstraint {
+        /// Constraint name
+        name: String,
+        /// Expression to check against
+        expr: String,
+    },
 
     /// Merge data with a source data with the following predicate
     #[serde(rename_all = "camelCase")]
@@ -497,6 +504,7 @@ impl DeltaOperation {
             DeltaOperation::Restore { .. } => "RESTORE",
             DeltaOperation::VacuumStart { .. } => "VACUUM START",
             DeltaOperation::VacuumEnd { .. } => "VACUUM END",
+            DeltaOperation::AddConstraint { .. } => "ADD CONSTRAINT",
         }
     }
@@ -532,7 +540,10 @@ impl DeltaOperation {
     /// Denotes if the operation changes the data contained in the table
     pub fn changes_data(&self) -> bool {
         match self {
-            Self::Optimize { .. } | Self::VacuumStart { .. } | Self::VacuumEnd { .. } => false,
+            Self::Optimize { .. }
+            | Self::VacuumStart { .. }
+            | Self::VacuumEnd { .. }
+            | Self::AddConstraint { .. } => false,
             Self::Create { .. }
             | Self::FileSystemCheck {}
             | Self::StreamingUpdate { .. }
diff --git a/crates/deltalake-core/src/table/mod.rs b/crates/deltalake-core/src/table/mod.rs
index de6a176e91..83374d1657 100644
--- a/crates/deltalake-core/src/table/mod.rs
+++ b/crates/deltalake-core/src/table/mod.rs
@@ -21,8 +21,8 @@ use self::builder::DeltaTableConfig;
 use self::state::DeltaTableState;
 use crate::errors::DeltaTableError;
 use crate::kernel::{
-    Action, Add, CommitInfo, DataType, Format, Metadata, Protocol, ReaderFeatures, Remove,
-    StructType, WriterFeatures,
+    Action, Add, CommitInfo, DataCheck, DataType, Format, Metadata, Protocol, ReaderFeatures,
+    Remove, StructType, WriterFeatures,
 };
 use crate::logstore::LogStoreRef;
 use crate::logstore::{self, LogStoreConfig};
@@ -136,6 +136,35 @@ impl PartialEq for CheckPoint {
 
 impl Eq for CheckPoint {}
 
+/// A check constraint on table data
+#[derive(Eq, PartialEq, Debug, Default, Clone)]
+pub struct Constraint {
+    /// The full path to the field.
+    pub name: String,
+    /// The SQL string that must always evaluate to true.
+    pub expr: String,
+}
+
+impl Constraint {
+    /// Create a new constraint
+    pub fn new(field_name: &str, invariant_sql: &str) -> Self {
+        Self {
+            name: field_name.to_string(),
+            expr: invariant_sql.to_string(),
+        }
+    }
+}
+
+impl DataCheck for Constraint {
+    fn get_name(&self) -> &str {
+        &self.name
+    }
+
+    fn get_expression(&self) -> &str {
+        &self.expr
+    }
+}
+
 /// Delta table metadata
 #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
 pub struct DeltaTableMetaData {
@@ -187,6 +216,20 @@ impl DeltaTableMetaData {
         &self.configuration
     }
 
+    /// Return the check constraints on the current table
+    pub fn get_constraints(&self) -> Vec<Constraint> {
+        self.configuration
+            .iter()
+            .filter_map(|(field, value)| {
+                if field.starts_with("delta.constraints") {
+                    value.as_ref().map(|f| Constraint::new("*", f))
+                } else {
+                    None
+                }
+            })
+            .collect()
+    }
+
     /// Return partition fields along with their data type from the current schema.
     pub fn get_partition_col_data_types(&self) -> Vec<(&String, &DataType)> {
         // JSON add actions contain a `partitionValues` field which is a map.
@@ -955,6 +998,27 @@ mod tests {
         }
     }
 
+    #[test]
+    fn get_table_constraints() {
+        let state = DeltaTableMetaData::new(
+            None,
+            None,
+            None,
+            StructType::new(vec![]),
+            vec![],
+            HashMap::from_iter(vec![
+                (
+                    "delta.constraints.id".to_string(),
+                    Some("id > 0".to_string()),
+                ),
+                ("delta.blahblah".to_string(), None),
+            ]),
+        );
+
+        let constraints = state.get_constraints();
+        assert_eq!(constraints.len(), 1)
+    }
+
     async fn create_test_table() -> (DeltaTable, TempDir) {
         let tmp_dir = TempDir::new("create_table_test").unwrap();
         let table_dir = tmp_dir.path().join("test_create");
diff --git a/crates/deltalake-sql/src/planner.rs b/crates/deltalake-sql/src/planner.rs
index 3eb4742308..099f97087d 100644
--- a/crates/deltalake-sql/src/planner.rs
+++ b/crates/deltalake-sql/src/planner.rs
@@ -48,7 +48,6 @@ impl<'a, S: ContextProvider> DeltaSqlToRel<'a, S> {
             }
             Statement::Describe(describe) => self.describe_to_plan(describe),
             Statement::Vacuum(vacuum) => self.vacuum_to_plan(vacuum),
-            _ => todo!(),
         }
     }
 
@@ -92,7 +91,6 @@ mod tests {
     use arrow_schema::{DataType, Field, Schema};
     use datafusion_common::config::ConfigOptions;
     use datafusion_common::DataFusionError;
-    use datafusion_expr::logical_plan::builder::LogicalTableSource;
     use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource};
     use datafusion_sql::TableReference;
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 95957dd32f..645a2f0b72 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -1319,7 +1319,7 @@ impl PyDeltaDataChecker {
             })
             .collect();
         Self {
-            inner: DeltaDataChecker::new(invariants),
+            inner: DeltaDataChecker::new_with_invariants(invariants),
             rt: tokio::runtime::Runtime::new().unwrap(),
         }
     }
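
Finally, a short sketch of the two explicit `DeltaDataChecker` constructors introduced in this change (re-export paths and the input batch are assumptions; constraints use the wildcard name `*` as elsewhere in this diff):

    use arrow_array::RecordBatch;
    use deltalake::delta_datafusion::DeltaDataChecker;
    use deltalake::kernel::Invariant;
    use deltalake::table::Constraint;
    use deltalake::DeltaTableError;

    async fn validate(batch: &RecordBatch) -> Result<(), DeltaTableError> {
        // Constraints are table-wide checks.
        DeltaDataChecker::new_with_constraints(vec![Constraint::new("*", "value > 0")])
            .check_batch(batch)
            .await?;
        // Invariants are tied to a single column.
        DeltaDataChecker::new_with_invariants(vec![Invariant::new("value", "value < 1000")])
            .check_batch(batch)
            .await // Err(DeltaTableError::InvalidData { .. }) on violation
    }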