Skip to content

Commit

Permalink
fix: pass coerced data types
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo committed Sep 9, 2024
1 parent e878a6b commit 8322811
Show file tree
Hide file tree
Showing 13 changed files with 161 additions and 123 deletions.
23 changes: 12 additions & 11 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,22 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;

// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_data_types,
let new_data_types = data_types_with_scalar_udf(&arg_data_types, func)
.map_err(|err| {
plan_datafusion_err!(
"{} {}",
err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
&arg_data_types,
)
)
)
})?;
})?;

// perform additional function arguments validation (due to limited
// expressiveness of `TypeSignature`), then infer return type
Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
Ok(func.return_type_from_exprs(args, schema, &new_data_types)?)
}
Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
let data_types = args
Expand Down
50 changes: 24 additions & 26 deletions datafusion/functions/src/datetime/to_local_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,16 @@ use std::sync::Arc;
use arrow::array::timezone::Tz;
use arrow::array::{Array, ArrayRef, PrimitiveBuilder};
use arrow::datatypes::DataType::Timestamp;
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use arrow::datatypes::{
ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
use arrow::datatypes::{
TimeUnit,
TimeUnit::{Microsecond, Millisecond, Nanosecond, Second},
};

use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc};
use datafusion_common::cast::as_primitive_array;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{
ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD,
};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

/// A UDF function that converts a timezone-aware timestamp to local time (with no offset or
/// timezone information). In other words, this function strips off the timezone from the timestamp,
Expand All @@ -55,20 +49,8 @@ impl Default for ToLocalTimeFunc {

impl ToLocalTimeFunc {
pub fn new() -> Self {
let base_sig = |array_type: TimeUnit| {
[
Exact(vec![Timestamp(array_type, None)]),
Exact(vec![Timestamp(array_type, Some(TIMEZONE_WILDCARD.into()))]),
]
};

let full_sig = [Nanosecond, Microsecond, Millisecond, Second]
.into_iter()
.flat_map(base_sig)
.collect::<Vec<_>>();

Self {
signature: Signature::one_of(full_sig, Volatility::Immutable),
signature: Signature::user_defined(Volatility::Immutable),
}
}

Expand Down Expand Up @@ -328,13 +310,10 @@ impl ScalarUDFImpl for ToLocalTimeFunc {
}

match &arg_types[0] {
Timestamp(Nanosecond, _) => Ok(Timestamp(Nanosecond, None)),
Timestamp(Microsecond, _) => Ok(Timestamp(Microsecond, None)),
Timestamp(Millisecond, _) => Ok(Timestamp(Millisecond, None)),
Timestamp(Second, _) => Ok(Timestamp(Second, None)),
Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)),
_ => exec_err!(
"The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0]
),
)
}
}

Expand All @@ -348,6 +327,25 @@ impl ScalarUDFImpl for ToLocalTimeFunc {

self.to_local_time(args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 1 {
return Err(DataFusionError::Execution(format!(
"to_local_time function requires 1 argument, got {:?}",
arg_types.len()
)));
}

match &arg_types[0] {
Timestamp(Nanosecond, timezone) => Ok(vec![Timestamp(Nanosecond, timezone.clone())]),
Timestamp(Microsecond, timezone) => Ok(vec![Timestamp(Microsecond, timezone.clone())]),
Timestamp(Millisecond, timezone) => Ok(vec![Timestamp(Millisecond, timezone.clone())]),
Timestamp(Second, timezone) => Ok(vec![Timestamp(Second, timezone.clone())]),
_ => Err(DataFusionError::Execution(format!(
"The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0]
)))
}
}
}

#[cfg(test)]
Expand Down
103 changes: 58 additions & 45 deletions datafusion/functions/src/encoding/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use datafusion_expr::ColumnarValue;
use std::sync::Arc;
use std::{fmt, str::FromStr};

use datafusion_expr::TypeSignature::*;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;

Expand All @@ -49,17 +48,8 @@ impl Default for EncodeFunc {

impl EncodeFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
Expand All @@ -77,23 +67,39 @@ impl ScalarUDFImpl for EncodeFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

Ok(match arg_types[0] {
Utf8 => Utf8,
LargeUtf8 => LargeUtf8,
Binary => Utf8,
LargeBinary => LargeUtf8,
Null => Null,
_ => {
return plan_err!("The encode function can only accept utf8 or binary.");
}
})
Ok(arg_types[0].to_owned())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
encode(args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return Err(DataFusionError::Plan(format!(
"{} expects to get 2 arguments, but got {}",
self.name(),
arg_types.len()
)));
}

if arg_types[1] != DataType::Utf8 {
return Err(DataFusionError::Plan("2nd argument should be Utf8".into()));
}

match arg_types[0] {
DataType::Utf8 | DataType::Binary | DataType::Null => {
Ok(vec![DataType::Utf8; 2])
}
DataType::LargeUtf8 | DataType::LargeBinary => {
Ok(vec![DataType::LargeUtf8, DataType::Utf8])
}
_ => Err(DataFusionError::Plan(format!(
"1st argument should be Utf8 or Binary or Null, got {:?}",
arg_types[0]
))),
}
}
}

#[derive(Debug)]
Expand All @@ -109,17 +115,8 @@ impl Default for DecodeFunc {

impl DecodeFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Exact(vec![LargeBinary, Utf8]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
Expand All @@ -137,23 +134,39 @@ impl ScalarUDFImpl for DecodeFunc {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;

Ok(match arg_types[0] {
Utf8 => Binary,
LargeUtf8 => LargeBinary,
Binary => Binary,
LargeBinary => LargeBinary,
Null => Null,
_ => {
return plan_err!("The decode function can only accept utf8 or binary.");
}
})
Ok(arg_types[0].to_owned())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
decode(args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return Err(DataFusionError::Plan(format!(
"{} expects to get 2 arguments, but got {}",
self.name(),
arg_types.len()
)));
}

if arg_types[1] != DataType::Utf8 {
return Err(DataFusionError::Plan("2nd argument should be Utf8".into()));
}

match arg_types[0] {
DataType::Utf8 | DataType::Binary | DataType::Null => {
Ok(vec![DataType::Binary, DataType::Utf8])
}
DataType::LargeUtf8 | DataType::LargeBinary => {
Ok(vec![DataType::LargeBinary, DataType::Utf8])
}
_ => Err(DataFusionError::Plan(format!(
"1st argument should be Utf8 or Binary or Null, got {:?}",
arg_types[0]
))),
}
}
}

#[derive(Debug, Copy, Clone)]
Expand Down
36 changes: 21 additions & 15 deletions datafusion/functions/src/unicode/strpos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ use arrow::array::{
};
use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};

use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use crate::utils::{make_scalar_function, utf8_to_int_type};
Expand All @@ -43,20 +42,8 @@ impl Default for StrposFunc {

impl StrposFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8]),
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
Exact(vec![Utf8View, Utf8View]),
Exact(vec![Utf8View, Utf8]),
Exact(vec![Utf8View, LargeUtf8]),
],
Volatility::Immutable,
),
signature: Signature::user_defined(Volatility::Immutable),
aliases: vec![String::from("instr"), String::from("position")],
}
}
Expand Down Expand Up @@ -86,6 +73,25 @@ impl ScalarUDFImpl for StrposFunc {
fn aliases(&self) -> &[String] {
&self.aliases
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
match arg_types {
[first, second ] => {
match (first, second) {
(DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8) => Ok(arg_types.to_vec()),
(DataType::Null, DataType::Null) => Ok(vec![DataType::Utf8, DataType::Utf8]),
(DataType::Null, _) => Ok(vec![DataType::Utf8, second.to_owned()]),
(_, DataType::Null) => Ok(vec![first.to_owned(), DataType::Utf8]),
(DataType::Dictionary(_, value_type), DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8) => match **value_type {
DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 | DataType::Null | DataType::Binary => Ok(vec![*value_type.clone(), second.to_owned()]),
_ => Err(DataFusionError::Execution(format!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}.", **value_type))),
},
_ => Err(DataFusionError::Execution(format!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}.", arg_types)))
}
},
_ => Err(DataFusionError::Execution(format!("The STRPOS/INSTR/POSITION function can only accept strings, but got {:?}", arg_types)))
}
}
}

fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
Expand Down
Loading

0 comments on commit 8322811

Please sign in to comment.