Skip to content

Commit

Permalink
Minor: Add ScalarValue::data_type() for consistency with other APIs (
Browse files Browse the repository at this point in the history
…#7492)

* Minor: Add `ScalarValue::data_type()` for consistency

* Use new API in a few places
  • Loading branch information
alamb authored Sep 8, 2023
1 parent 4d44512 commit 93c209f
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 16 deletions.
11 changes: 9 additions & 2 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1072,8 +1072,8 @@ impl ScalarValue {
})
}

/// Getter for the `DataType` of the value
pub fn get_datatype(&self) -> DataType {
/// return the [`DataType`] of this `ScalarValue`
pub fn data_type(&self) -> DataType {
match self {
ScalarValue::Boolean(_) => DataType::Boolean,
ScalarValue::UInt8(_) => DataType::UInt8,
Expand Down Expand Up @@ -1149,6 +1149,13 @@ impl ScalarValue {
}
}

/// Getter for the `DataType` of the value.
///
/// Suggest using [`Self::data_type`] as a more standard API
pub fn get_datatype(&self) -> DataType {
self.data_type()
}

/// Calculate arithmetic negation for a scalar value
pub fn arithmetic_negate(&self) -> Result<Self> {
match self {
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -939,11 +939,11 @@ impl LogicalPlan {
// Verify if the types of the params matches the types of the values
let iter = prepare_lp.data_types.iter().zip(param_values.iter());
for (i, (param_type, value)) in iter.enumerate() {
if *param_type != value.get_datatype() {
if *param_type != value.data_type() {
return plan_err!(
"Expected parameter of type {:?}, got {:?} at index {}",
param_type,
value.get_datatype(),
value.data_type(),
i
);
}
Expand Down Expand Up @@ -1183,11 +1183,11 @@ impl LogicalPlan {
))
})?;
// check if the data type of the value matches the data type of the placeholder
if Some(value.get_datatype()) != *data_type {
if Some(value.data_type()) != *data_type {
return internal_err!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.get_datatype()
value.data_type()
);
}
// Replace the placeholder with the value
Expand Down
12 changes: 6 additions & 6 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Result<Option<ScalarValue>> {
let lit_data_type = lit_value.get_datatype();
let lit_data_type = lit_value.data_type();
// the rule just support the signed numeric data type now
if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) {
return Ok(None);
Expand Down Expand Up @@ -817,7 +817,7 @@ mod tests {
for s2 in &scalars {
let expected_value = ExpectedCast::Value(s2.clone());

expect_cast(s1.clone(), s2.get_datatype(), expected_value);
expect_cast(s1.clone(), s2.data_type(), expected_value);
}
}
}
Expand All @@ -842,7 +842,7 @@ mod tests {
for s2 in &scalars {
let expected_value = ExpectedCast::Value(s2.clone());

expect_cast(s1.clone(), s2.get_datatype(), expected_value);
expect_cast(s1.clone(), s2.data_type(), expected_value);
}
}

Expand Down Expand Up @@ -976,10 +976,10 @@ mod tests {
assert_eq!(lit_tz_none, lit_tz_utc);

// e.g. DataType::Timestamp(_, None)
let dt_tz_none = lit_tz_none.get_datatype();
let dt_tz_none = lit_tz_none.data_type();

// e.g. DataType::Timestamp(_, Some(utc))
let dt_tz_utc = lit_tz_utc.get_datatype();
let dt_tz_utc = lit_tz_utc.data_type();

// None <--> None
expect_cast(
Expand Down Expand Up @@ -1102,7 +1102,7 @@ mod tests {
if let (
DataType::Timestamp(left_unit, left_tz),
DataType::Timestamp(right_unit, right_tz),
) = (actual_value.get_datatype(), expected_value.get_datatype())
) = (actual_value.data_type(), expected_value.data_type())
{
assert_eq!(left_unit, right_unit);
assert_eq!(left_tz, right_tz);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
ScalarValue::Float64(Some(q)) => *q,
got => return not_impl_err!(
"Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
got.get_datatype()
got.data_type()
)
};

Expand Down Expand Up @@ -182,7 +182,7 @@ fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize,
got => return not_impl_err!(
"Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
got.get_datatype()
got.data_type()
)
};
Ok(max_size)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

let data_types: HashSet<DataType> =
values.iter().map(|e| e.get_datatype()).collect();
values.iter().map(|e| e.data_type()).collect();

if data_types.is_empty() {
Ok(lit(ScalarValue::new_list(None, DataType::Utf8)))
} else if data_types.len() > 1 {
not_impl_err!("Arrays with different types are not supported: {data_types:?}")
} else {
let data_type = values[0].get_datatype();
let data_type = values[0].data_type();

Ok(lit(ScalarValue::new_list(Some(values), data_type)))
}
Expand Down

0 comments on commit 93c209f

Please sign in to comment.