Skip to content

Commit

Permalink
feat: support map lambda functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Dragonliu2018 committed Oct 24, 2024
1 parent b1538fa commit 8016144
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 16 deletions.
41 changes: 41 additions & 0 deletions src/query/expression/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ impl<'a> Evaluator<'a> {
{
continue;
}
log::info!("check_expr: {} {}", column.data_type, data_type);
assert_eq!(
&column.data_type,
data_type,
Expand Down Expand Up @@ -1318,6 +1319,46 @@ impl<'a> Evaluator<'a> {
};
builder.push(val.as_ref());
}
ScalarRef::Map(col) => {
let col_len = col.len();
let (key_col, value_col) = match col {
Column::Tuple(t) => (t[0].clone(), t[1].clone()),
_ => {
return Err(ErrorCode::Internal("Map is not Column::Tuple"));
}
};
let entry = match func_name {
"map_transform_keys" => BlockEntry::new(
key_col.data_type().clone(),
Value::Column(key_col.clone()),
),
"map_transform_values" => BlockEntry::new(
value_col.data_type().clone(),
Value::Column(value_col.clone()),
),
_ => {
return Err(ErrorCode::Internal(format!(
"lambda function `{func_name}` is not found"
)));
}
};
entries.push(entry);
let block = DataBlock::new(entries, col_len);

let evaluator = Evaluator::new(&block, self.func_ctx, self.fn_registry);
let result = evaluator.run(&expr)?;
let result_col = result.convert_to_full_column(expr.data_type(), col_len);
let val = match func_name {
"map_transform_keys" => Scalar::Map(Column::Tuple(vec![result_col, value_col])),
"map_transform_values" => Scalar::Map(Column::Tuple(vec![key_col, result_col])),
_ => {
return Err(ErrorCode::Internal(format!(
"lambda function `{func_name}` is not found"
)));
}
};
builder.push(val.as_ref());
}
ScalarRef::Null => {
builder.push_default();
}
Expand Down
96 changes: 80 additions & 16 deletions src/query/sql/src/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1866,21 +1866,7 @@ impl<'a> TypeChecker<'a> {
.map(|param| param.name.to_lowercase())
.collect::<Vec<_>>();

// TODO: support multiple params
// ARRAY_REDUCE has two params
if params.len() != 1 && func_name != "array_reduce" {
return Err(ErrorCode::SemanticError(format!(
"incorrect number of parameters in lambda function, {} expects 1 parameter, but got {}",
func_name, params.len()
))
.set_span(span));
} else if func_name == "array_reduce" && params.len() != 2 {
return Err(ErrorCode::SemanticError(format!(
"incorrect number of parameters in lambda function, {} expects 2 parameters, but got {}",
func_name, params.len()
))
.set_span(span));
}
self.check_lambda_param_count(func_name, params.len(), span)?;

if args.len() != 1 {
return Err(ErrorCode::SemanticError(format!(
Expand All @@ -1891,9 +1877,11 @@ impl<'a> TypeChecker<'a> {
.set_span(span));
}
let box (mut arg, arg_type) = self.resolve(args[0])?;
log::info!("{} -- arg_type: {}", func_name, arg_type);

let inner_ty = match arg_type.remove_nullable() {
DataType::Array(box inner_ty) => inner_ty.clone(),
DataType::Map(box inner_ty) => inner_ty.clone(),
DataType::Null | DataType::EmptyArray => DataType::Null,
_ => {
return Err(ErrorCode::SemanticError(
Expand All @@ -1907,9 +1895,26 @@ impl<'a> TypeChecker<'a> {
let inner_tys = if func_name == "array_reduce" {
let max_ty = self.transform_to_max_type(&inner_ty)?;
vec![max_ty.clone(), max_ty.clone()]
} else if func_name == "map_transform_keys" || func_name == "map_transform_values" {
match &inner_ty {
DataType::Tuple(t) => t.clone(),
_ => {
return Err(ErrorCode::Internal(
"map_transform_keys/map_transform_values inner_ty should be DataType::Tuple",
));
}
}
} else {
vec![inner_ty.clone()]
};
log::info!(
"{} -- inner_ty: {} -- inner_tys[0]: {} -- {}",
func_name,
inner_ty,
inner_tys[0],
inner_tys.len()
);
let tmp_ty = inner_tys[0].clone();

let columns = params
.iter()
Expand All @@ -1923,6 +1928,7 @@ impl<'a> TypeChecker<'a> {
&columns,
&lambda.expr,
)?;
log::info!("{} -- lambda_type: {}", func_name, lambda_type);

let return_type = if func_name == "array_filter" {
if lambda_type.remove_nullable() == DataType::Boolean {
Expand Down Expand Up @@ -1955,11 +1961,22 @@ impl<'a> TypeChecker<'a> {
});
}
max_ty.wrap_nullable()
} else if func_name == "map_transform_keys" {
DataType::Map(Box::new(DataType::Tuple(vec![
lambda_type.clone(),
inner_tys[1].clone(),
])))
} else if func_name == "map_transform_values" {
DataType::Map(Box::new(DataType::Tuple(vec![
inner_tys[0].clone(),
lambda_type.clone(),
])))
} else if arg_type.is_nullable() {
DataType::Nullable(Box::new(DataType::Array(Box::new(lambda_type.clone()))))
} else {
DataType::Array(Box::new(lambda_type.clone()))
};
log::info!("{} -- return_type: {}", func_name, return_type);

let (lambda_func, data_type) = match arg_type.remove_nullable() {
// Null and Empty array can convert to ConstantExpr
Expand Down Expand Up @@ -2049,8 +2066,55 @@ impl<'a> TypeChecker<'a> {
)
}
};
log::info!("{} -- data_type: {}", func_name, data_type);

Ok(Box::new((lambda_func, tmp_ty)))
}

Ok(Box::new((lambda_func, data_type)))
fn check_lambda_param_count(
&mut self,
func_name: &str,
param_count: usize,
span: Span,
) -> Result<()> {
// Note: when a new lambda function is added, it needs to be added to the HashMap.
// The key of HashMap is lambda function name, and the value is param count.
let func_to_param_count: HashMap<&str, usize> = [
("array_transform", 1),
("array_apply", 1),
("array_map", 1),
("array_filter", 1),
("array_reduce", 2),
("json_array_transform", 1),
("json_array_apply", 1),
("json_array_map", 1),
("json_array_filter", 1),
("json_array_reduce", 2),
("map_filter", 2),
("map_transform_keys", 2),
("map_transform_values", 2),
]
.iter()
.cloned()
.collect();

match func_to_param_count.get(func_name) {
Some(&expected_count) => {
if param_count != expected_count {
return Err(ErrorCode::SemanticError(format!(
"incorrect number of parameters in lambda function, {} expects {} parameter(s), but got {}",
func_name, expected_count, param_count
))
.set_span(span));
}
Ok(())
}
None => Err(ErrorCode::Internal(format!(
"not found lambda function '{}' in HashMap 'func_to_param_count'",
func_name
))
.set_span(span)),
}
}

fn resolve_score_search_function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,28 @@ SELECT map_contains_key({'k1': 'v1', 'k2': NULL}, 'k2')

statement ok
DROP DATABASE map_func_test

query
SELECT map_transform_keys({1:1,2:2,3:3}, (k, v) -> k + 1);
----
{2:1,3:2,4:3}

query
SELECT map_transform_keys({11:111,22:222,33:333}, (k, v) -> k + 1);
----
{12:111,23:222,34:333}

query
SELECT map_transform_keys({11:111,22:222,33:333}, (k, v) -> k * 2 + 1);
----
{23:111,45:222,67:333}

query
SELECT map_transform_keys({11:111,22:222,33:333}, (k, v) -> v + 1);
----
{112:111,223:222,334:333}

query
SELECT map_transform_keys({11:111,22:222,33:333}, (k, v) -> k + v + 1);
----
{123:111,245:222,367:333}

0 comments on commit 8016144

Please sign in to comment.