Skip to content

Commit

Permalink
Introduce Signature::Coercible (#12275)
Browse files Browse the repository at this point in the history
* introduce signature float

Signed-off-by: jayzhan211 <[email protected]>

* fix test

Signed-off-by: jayzhan211 <[email protected]>

* change float to coercible

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* rm test

Signed-off-by: jayzhan211 <[email protected]>

* add comment

Signed-off-by: jayzhan211 <[email protected]>

* typo

Signed-off-by: jayzhan211 <[email protected]>

---------

Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 authored Sep 2, 2024
1 parent ac74cd3 commit 8db30e2
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 18 deletions.
6 changes: 4 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2428,7 +2428,8 @@ mod tests {
let df: Vec<RecordBatch> = df.select(aggr_expr)?.collect().await?;

assert_batches_sorted_eq!(
["+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
[
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
"| | | | | | | | 1 | -85 |",
Expand All @@ -2452,7 +2453,8 @@ mod tests {
"| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |",
"| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |",
"| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |",
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+"],
"+-------------+----------+-----------------+---------------+--------+-----+------+----+------+",
],
&df
);

Expand Down
15 changes: 14 additions & 1 deletion datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ pub enum TypeSignature {
Uniform(usize, Vec<DataType>),
/// Exact number of arguments of an exact type
Exact(Vec<DataType>),
/// The number of arguments that can be coerced to in order
/// For example, `Coercible(vec![DataType::Float64])` accepts
/// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]`
/// since i32 and f32 can be casted to f64
Coercible(Vec<DataType>),
/// Fixed number of arguments of arbitrary types
/// If a function takes 0 argument, its `TypeSignature` should be `Any(0)`
Any(usize),
Expand Down Expand Up @@ -188,7 +193,7 @@ impl TypeSignature {
TypeSignature::Numeric(num) => {
vec![format!("Numeric({})", num)]
}
TypeSignature::Exact(types) => {
TypeSignature::Exact(types) | TypeSignature::Coercible(types) => {
vec![Self::join_types(types, ", ")]
}
TypeSignature::Any(arg_count) => {
Expand Down Expand Up @@ -300,6 +305,14 @@ impl Signature {
volatility,
}
}
/// Target coerce types in order
pub fn coercible(target_types: Vec<DataType>, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Coercible(target_types),
volatility,
}
}

/// A specified number of arguments of any type
pub fn any(arg_count: usize, volatility: Volatility) -> Self {
Signature {
Expand Down
6 changes: 4 additions & 2 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ pub fn check_arg_count(
);
}
}
TypeSignature::UserDefined | TypeSignature::Numeric(_) => {
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::Coercible(_) => {
// User-defined signature is validated in `coerce_types`
// Numreic signature is validated in `get_valid_types`
// Numeric and Coercible signature is validated in `get_valid_types`
}
_ => {
return internal_err!(
Expand Down
33 changes: 32 additions & 1 deletion datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,14 @@ fn try_coerce_types(
let mut valid_types = valid_types;

// Well-supported signature that returns exact valid types.
if !valid_types.is_empty() && matches!(type_signature, TypeSignature::UserDefined) {
if !valid_types.is_empty()
&& matches!(
type_signature,
TypeSignature::UserDefined
| TypeSignature::Numeric(_)
| TypeSignature::Coercible(_)
)
{
// exact valid types
assert_eq!(valid_types.len(), 1);
let valid_types = valid_types.swap_remove(0);
Expand Down Expand Up @@ -397,6 +404,30 @@ fn get_valid_types(

vec![vec![valid_type; *number]]
}
TypeSignature::Coercible(target_types) => {
if target_types.is_empty() {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if target_types.len() != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
target_types.len(),
current_types.len()
);
}

for (data_type, target_type) in current_types.iter().zip(target_types.iter())
{
if !can_cast_types(data_type, target_type) {
return plan_err!("{data_type} is not coercible to {target_type}");
}
}

vec![target_types.to_owned()]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
Expand Down
11 changes: 5 additions & 6 deletions datafusion/functions-aggregate/src/stddev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ impl Stddev {
/// Create a new STDDEV aggregate function
pub fn new() -> Self {
Self {
signature: Signature::numeric(1, Volatility::Immutable),
signature: Signature::coercible(
vec![DataType::Float64],
Volatility::Immutable,
),
alias: vec!["stddev_samp".to_string()],
}
}
Expand All @@ -88,11 +91,7 @@ impl AggregateUDFImpl for Stddev {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Stddev requires numeric input types");
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

Expand Down
11 changes: 5 additions & 6 deletions datafusion/functions-aggregate/src/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ impl VarianceSample {
pub fn new() -> Self {
Self {
aliases: vec![String::from("var_sample"), String::from("var_samp")],
signature: Signature::numeric(1, Volatility::Immutable),
signature: Signature::coercible(
vec![DataType::Float64],
Volatility::Immutable,
),
}
}
}
Expand All @@ -97,11 +100,7 @@ impl AggregateUDFImpl for VarianceSample {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Variance requires numeric input types");
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

Expand Down

0 comments on commit 8db30e2

Please sign in to comment.