fix: coalesce schema issues #12308

Merged 1 commit on Sep 27, 2024
37 changes: 37 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
@@ -2029,6 +2029,43 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_coalesce_schema() -> Result<()> {
let ctx = SessionContext::new();

let query = r#"SELECT COALESCE(null, 5)"#;

let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}

#[tokio::test]
async fn test_coalesce_from_values_schema() -> Result<()> {
let ctx = SessionContext::new();

let query = r#"SELECT COALESCE(column1, column2) FROM VALUES (null, 1.2)"#;

let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}

#[tokio::test]
async fn test_coalesce_from_values_schema_multiple_rows() -> Result<()> {
let ctx = SessionContext::new();

let query = r#"SELECT COALESCE(column1, column2)
FROM VALUES
(null, 1.2),
(1.1, null),
(2, 5);"#;

let result = ctx.sql(query).await?;
assert_logical_expr_schema_eq_physical_expr_schema(result).await?;
Ok(())
}

#[tokio::test]
async fn test_array_agg_schema() -> Result<()> {
let ctx = SessionContext::new();
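The three new coalesce tests added above delegate the actual check to a helper defined elsewhere in this test module. As a rough sketch of what that helper is assumed to do (illustrative, not copied from the PR), it compares the schema reported by the logical plan with the schema of the record batches the physical plan actually produces:

```rust
// Sketch only -- the real assert_logical_expr_schema_eq_physical_expr_schema lives
// elsewhere in this test module; types such as DataFrame, Schema and Result are
// assumed to come from the surrounding imports.
async fn assert_logical_expr_schema_eq_physical_expr_schema(df: DataFrame) -> Result<()> {
    // Schema promised by the logical plan behind the DataFrame.
    let logical_schema = Schema::from(df.schema().clone());
    // Schema observed after executing the physical plan.
    let batches = df.collect().await?;
    let physical_schema = batches[0].schema();
    assert_eq!(&logical_schema, physical_schema.as_ref());
    Ok(())
}
```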
23 changes: 12 additions & 11 deletions datafusion/expr/src/expr_schema.rs
@@ -151,21 +151,22 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;

// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_data_types,
let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
[Review thread on this change]

Contributor: I think this is the root cause of the issue, and other changes are necessary to solve it. Therefore, I think we should go with this change and maybe further optimize the coercion in another PR.

Contributor Author: Ok, so should I leave it as it is, or change it back to how it was:

data_types_with_scalar_udf(&arg_data_types, func)

Contributor: I'm ok with the current change; maybe wait for @findepi @alamb on how to move on with this PR.

Contributor: I defer to @jayzhan211 -- if he is good to merge this PR, let's get the conflicts resolved and merge it in. If there is additional work we know is needed / could be cleaned up, let's try to file it as tickets.

Contributor Author: Conflicts solved! 😄

.map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_data_types,
)
)
)
})?;
})?;

// perform additional function arguments validation (due to limited
// expressiveness of `TypeSignature`), then infer return type
Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
}
Expr::WindowFunction(window_function) => self
.data_type_and_nullable_with_window_function(schema, window_function)
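To make the root cause discussed in the thread above concrete, here is a small illustration (the values are assumed for the example and are not code from the PR). For `COALESCE(null, 5)`, the raw argument types differ from the coerced ones, and the return type has to be derived from the coerced types for the logical schema to match what execution produces:

```rust
use arrow::datatypes::DataType;

fn main() {
    // Raw argument types for `COALESCE(null, 5)`:
    let arg_data_types = vec![DataType::Null, DataType::Int64];
    // Coercion against COALESCE's signature resolves both arguments to Int64:
    let new_data_types = vec![DataType::Int64, DataType::Int64];

    // Before this change, `return_type_from_exprs` received `arg_data_types`, so a
    // function whose return type follows its first argument reported Null in the
    // logical schema while execution produced Int64. Passing `new_data_types`
    // makes the logical and physical schemas agree on Int64.
    assert_ne!(arg_data_types[0], new_data_types[0]);
}
```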
4 changes: 3 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
@@ -216,7 +216,9 @@ impl LogicalPlanBuilder {
common_type = Some(data_type);
}
}
field_types.push(common_type.unwrap_or(DataType::Utf8));
// if common_type was never set (and no error occurred), every value in this column
// was NULL, since the loop above skips NULLs -- so the column type should be Null
field_types.push(common_type.unwrap_or(DataType::Null));
}
// wrap cast if data type is not same as common type.
for row in &mut values {
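A short sketch of what this fallback change means in practice (the query and assertion are assumptions for illustration, not part of the PR): a VALUES column that contains only NULLs is now planned as `DataType::Null` instead of defaulting to `Utf8`, which lets downstream coercion, for example in `COALESCE`, settle on the right type:

```rust
use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // column1 contains only NULLs, so the builder loop never observes a non-null type.
    let df = ctx.sql("SELECT column1 FROM VALUES (null), (null)").await?;
    // With this change the column is typed Null rather than falling back to Utf8.
    assert_eq!(df.schema().field(0).data_type(), &DataType::Null);
    Ok(())
}
```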
30 changes: 26 additions & 4 deletions datafusion/functions/src/core/coalesce.rs
@@ -21,11 +21,11 @@ use arrow::array::{new_null_array, BooleanArray};
use arrow::compute::kernels::zip::zip;
use arrow::compute::{and, is_not_null, is_null};
use arrow::datatypes::DataType;

use datafusion_common::{exec_err, ExprSchema, Result};
use datafusion_expr::type_coercion::binary::type_union_resolution;
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use itertools::Itertools;

#[derive(Debug)]
pub struct CoalesceFunc {
@@ -60,12 +60,16 @@ impl ScalarUDFImpl for CoalesceFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[0].clone())
Ok(arg_types
.iter()
.find_or_first(|d| !d.is_null())
.unwrap()
.clone())
}

// If all the element in coalesce is non-null, the result is non-null
// If any of the arguments in coalesce is non-null, the result is non-null
fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool {
args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true))
args.iter().all(|e| e.nullable(schema).ok().unwrap_or(true))
}

/// coalesce evaluates to the first value which is not NULL
@@ -154,4 +158,22 @@ mod test {
.unwrap();
assert_eq!(return_type, DataType::Date32);
}

#[test]
fn test_coalesce_return_types_with_nulls_first() {
let coalesce = core::coalesce::CoalesceFunc::new();
let return_type = coalesce
.return_type(&[DataType::Null, DataType::Date32])
.unwrap();
assert_eq!(return_type, DataType::Date32);
}

#[test]
fn test_coalesce_return_types_with_nulls_last() {
let coalesce = core::coalesce::CoalesceFunc::new();
let return_type = coalesce
.return_type(&[DataType::Int64, DataType::Null])
.unwrap();
assert_eq!(return_type, DataType::Int64);
}
}
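One edge case worth calling out (the test below is a sketch and an assumption, not part of the diff): when every argument is Null, `find_or_first` falls back to the first element, so the return type stays Null, which matches the previous `arg_types[0].clone()` behaviour:

```rust
#[test]
fn test_coalesce_return_type_all_nulls() {
    // Sketch: with only Null arguments, find_or_first yields the first (Null) type.
    let coalesce = core::coalesce::CoalesceFunc::new();
    let return_type = coalesce
        .return_type(&[DataType::Null, DataType::Null])
        .unwrap();
    assert_eq!(return_type, DataType::Null);
}
```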
57 changes: 30 additions & 27 deletions datafusion/functions/src/datetime/to_local_time.rs
@@ -22,22 +22,16 @@ use std::sync::Arc;
use arrow::array::timezone::Tz;
use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
use arrow::datatypes::DataType::Timestamp;
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use arrow::datatypes::{
ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
use arrow::datatypes::{
TimeUnit,
TimeUnit::{Microsecond, Millisecond, Nanosecond, Second},
};

use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc};
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{
ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
};
use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

/// A UDF function that converts a timezone-aware timestamp to local time (with no offset or
/// timezone information). In other words, this function strips off the timezone from the timestamp,
@@ -55,20 +49,8 @@ impl Default for ToLocalTimeFunc {

impl ToLocalTimeFunc {
pub fn new() -> Self {
let base_sig = |array_type: TimeUnit| {
[
Exact(vec![Timestamp(array_type, None)]),
Exact(vec![Timestamp(array_type, Some(TIMEZONE_WILDCARD.into()))]),
]
};

let full_sig = [Nanosecond, Microsecond, Millisecond, Second]
.into_iter()
.flat_map(base_sig)
.collect::<Vec<_>>();

Self {
signature: Signature::one_of(full_sig, Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
}
}

@@ -328,13 +310,10 @@ impl ScalarUDFImpl for ToLocalTimeFunc {
}

match &arg_types[0] {
Timestamp(Nanosecond, _) => Ok(Timestamp(Nanosecond, None)),
Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)),
Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)),
Timestamp(Second, _) => Ok(Timestamp(Second, None)),
Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)),
_ => exec_err!(
"The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0]
),
)
}
}

@@ -348,6 +327,30 @@ impl ScalarUDFImpl for ToLocalTimeFunc {

self.to_local_time(args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return plan_err!(
"to_local_time function requires 1 argument, got {:?}",
arg_types.len()
);
}

let first_arg = arg_types[0].clone();
match &first_arg {
Timestamp(Nanosecond, timezone) => {
Ok(vec![Timestamp(Nanosecond, timezone.clone())])
}
Timestamp(Microsecond, timezone) => {
Ok(vec![Timestamp(Microsecond, timezone.clone())])
}
Timestamp(Millisecond, timezone) => {
Ok(vec![Timestamp(Millisecond, timezone.clone())])
}
Timestamp(Second, timezone) => Ok(vec![Timestamp(Second, timezone.clone())]),
_ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"),
}
}
}

#[cfg(test)]
103 changes: 58 additions & 45 deletions datafusion/functions/src/encoding/inner.rs
@@ -32,7 +32,6 @@ use datafusion_expr::ColumnarValue;
use std::sync::Arc;
use std::{fmt, str::FromStr};

use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;

@@ -49,17 +48,8 @@ impl Default for EncodeFunc {

impl EncodeFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
[Review thread on this change]

Contributor: It seems to me that by moving the signature away from a data-driven description (i.e. describing "what" is needed and letting other code compute whether the given arguments match that signature), this PR moves many of the functions towards a more function-specific approach, where each function has to implement its own custom coercion, likely resulting in significant duplication.

What do you think (perhaps as a follow-on PR) of adding DataType::Null support to the Signature calculations rather than inlining / duplicating the coercion logic? Maybe something like

Signature::allow_null(..)

that would support automatically coercing arguments from null? Or maybe we should always support coercing Null to any type.

Contributor: An alternative signature like Signature::String, similar to Signature::numeric, that includes converting null to string too?

Contributor: I am not sure -- I was just reacting that this "handle null" pattern seems common, and it seems like this approach will require custom coercion logic for all functions 🤔

Member: Null-to-T coercion needs to be handled elsewhere anyway (e.g. when computing the type of a UNION). We can free functions from having to bother about coercions at all and let the engine calculate coercions when building the logical plan.

This is actually super fundamental for the DataFusion vision as a composable query engine. Coercion rules are very implementation-specific; if we had functions spiced up with coercions inside them, that would make those functions non-reusable.

cc @wizardxz @sadboy

Contributor: 100%. It seems to me like Signature is supposed to communicate which types the function has a native implementation for, with coercion handling whatever the user provided that doesn't match one of the supported types.

Contributor: @findepi Are you suggesting a general coercion that is not function-specific? But what if we want different coercion rules for different functions? We might need to do coercion function-wise.

Member: Why would we want different coercion rules for different functions?

Contributor: My idea is that it is more flexible for the user, although, without a real use case, it might be a premature optimization 🤔.

Contributor Author: Sorry to chip in late; this PR addresses other issues, such as #12307. I wonder if I could split it and leave the changes regarding function coercion in this one (to keep the discussion in one place) and the others in a new PR. Would that be ok?
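For the `Signature::allow_null(..)` idea raised in this thread, here is a rough sketch of one shape a Null-aware pre-pass could take (a hypothetical helper; neither this function nor `Signature::allow_null` exists in DataFusion): map Null arguments onto the corresponding type of a candidate `Exact` signature before the usual matching runs, so individual functions would not need their own `coerce_types`:

```rust
use arrow::datatypes::DataType;

// Hypothetical helper, sketched for the discussion above only.
// Given the actual argument types and one Exact candidate from a Signature,
// coerce Null arguments to the candidate's types and reject any other mismatch.
fn coerce_nulls_to_candidate(
    arg_types: &[DataType],
    candidate: &[DataType],
) -> Option<Vec<DataType>> {
    if arg_types.len() != candidate.len() {
        return None;
    }
    arg_types
        .iter()
        .zip(candidate)
        .map(|(actual, wanted)| {
            if actual == wanted || actual.is_null() {
                // Null may always coerce to whatever the signature asks for.
                Some(wanted.clone())
            } else {
                None
            }
        })
        .collect()
}

fn main() {
    // e.g. [Null, Utf8] against the Exact candidate [Utf8, Utf8] coerces cleanly:
    let coerced = coerce_nulls_to_candidate(
        &[DataType::Null, DataType::Utf8],
        &[DataType::Utf8, DataType::Utf8],
    );
    assert_eq!(coerced, Some(vec![DataType::Utf8, DataType::Utf8]));
}
```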

vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
@@ -77,23 +67,39 @@ impl ScalarUDFImpl for EncodeFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

Ok(match arg_types[0] {
Utf8 => Utf8,
LargeUtf8 => LargeUtf8,
Binary => Utf8,
LargeBinary => LargeUtf8,
Null => Null,
_ => {
return plan_err!("The encode function can only accept utf8 or binary.");
}
})
Ok(arg_types[0].to_owned())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
encode(args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return plan_err!(
"{} expects to get 2 arguments, but got {}",
self.name(),
arg_types.len()
);
}

if arg_types[1] != DataType::Utf8 {
return Err(DataFusionError::Plan("2nd argument should be Utf8".into()));
}

match arg_types[0] {
DataType::Utf8 | DataType::Binary | DataType::Null => {
Ok(vec![DataType::Utf8; 2])
}
DataType::LargeUtf8 | DataType::LargeBinary => {
Ok(vec![DataType::LargeUtf8, DataType::Utf8])
}
_ => plan_err!(
"1st argument should be Utf8 or Binary or Null, got {:?}",
arg_types[0]
),
}
}
}

#[derive(Debug)]
@@ -109,17 +115,8 @@ impl Default for DecodeFunc {

impl DecodeFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
@@ -137,23 +134,39 @@ impl ScalarUDFImpl for DecodeFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

Ok(match arg_types[0] {
Utf8 => Binary,
LargeUtf8 => LargeBinary,
Binary => Binary,
LargeBinary => LargeBinary,
Null => Null,
_ => {
return plan_err!("The decode function can only accept utf8 or binary.");
}
})
Ok(arg_types[0].to_owned())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
decode(args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return plan_err!(
"{} expects to get 2 arguments, but got {}",
self.name(),
arg_types.len()
);
}

if arg_types[1] != DataType::Utf8 {
return plan_err!("2nd argument should be Utf8");
}

match arg_types[0] {
DataType::Utf8 | DataType::Binary | DataType::Null => {
Ok(vec![DataType::Binary, DataType::Utf8])
}
DataType::LargeUtf8 | DataType::LargeBinary => {
Ok(vec![DataType::LargeBinary, DataType::Utf8])
}
_ => plan_err!(
"1st argument should be Utf8 or Binary or Null, got {:?}",
arg_types[0]
),
}
}
}

#[derive(Debug, Copy, Clone)]