feat(functions): add new params format in map_insert
hanxuanliang committed May 24, 2024
1 parent fa0c45d commit ea0c1d4
Showing 4 changed files with 181 additions and 21 deletions.
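
Summary of the change: map_insert gains a four-argument form, map_insert(map, insert_key, insert_value, allow_update), and the shared insert logic is refactored into a build_new_map helper. The sketch below is a minimal, self-contained model of the intended semantics, written in plain Rust over Vec pairs; map_insert_sketch and its signature are illustrative names for this page only, not Databend's API.

// A minimal, standalone sketch of the map_insert semantics added here
// (plain Rust over Vec pairs; not Databend's column types).
fn map_insert_sketch(
    source: &[(String, String)],
    key: &str,
    value: &str,
    allow_update: bool,
) -> Vec<(String, String)> {
    let duplicate_key = source.iter().any(|(k, _)| k == key);
    // duplicate key with updates disallowed: return the original map unchanged
    if duplicate_key && !allow_update {
        return source.to_vec();
    }
    let mut new_map = Vec::with_capacity(source.len() + 1);
    for (k, v) in source {
        if k == key {
            // existing key: overwrite its value
            new_map.push((k.clone(), value.to_string()));
        } else {
            new_map.push((k.clone(), v.clone()));
        }
    }
    if !duplicate_key {
        new_map.push((key.to_string(), value.to_string()));
    }
    new_map
}

fn main() {
    let m = vec![("k1".to_string(), "v1".to_string())];
    // allow_update = false keeps the existing value for a duplicate key
    assert_eq!(map_insert_sketch(&m, "k1", "v10", false), m);
    // allow_update = true overwrites it
    assert_eq!(
        map_insert_sketch(&m, "k1", "v10", true),
        vec![("k1".to_string(), "v10".to_string())]
    );
    // a key that is not already present is simply appended
    assert_eq!(
        map_insert_sketch(&m, "k2", "v2", false),
        vec![
            ("k1".to_string(), "v1".to_string()),
            ("k2".to_string(), "v2".to_string())
        ]
    );
}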
90 changes: 73 additions & 17 deletions src/query/functions/src/scalars/map.rs
@@ -15,6 +15,9 @@
use std::collections::HashSet;
use std::hash::Hash;

use databend_common_expression::types::array::ArrayColumn;
use databend_common_expression::types::map::KvColumn;
use databend_common_expression::types::map::KvPair;
use databend_common_expression::types::nullable::NullableDomain;
use databend_common_expression::types::ArgType;
use databend_common_expression::types::ArrayType;
@@ -31,8 +34,11 @@ use databend_common_expression::types::ValueType;
use databend_common_expression::vectorize_1_arg;
use databend_common_expression::vectorize_with_builder_2_arg;
use databend_common_expression::vectorize_with_builder_3_arg;
use databend_common_expression::vectorize_with_builder_4_arg;
use databend_common_expression::EvalContext;
use databend_common_expression::FunctionDomain;
use databend_common_expression::FunctionRegistry;
use databend_common_expression::ScalarRef;
use databend_common_expression::Value;
use databend_common_hashtable::StackHashSet;
use siphasher::sip128::Hasher128;
@@ -284,29 +290,79 @@ pub fn register(registry: &mut FunctionRegistry) {
// insert operation only works on specific key type: boolean, string, numeric, decimal, date, datetime
let key_type = &ctx.generics[0];
if !key_type.is_boolean()
&& !key_type.is_string()
&& !key_type.is_numeric()
&& !key_type.is_decimal()
&& !key_type.is_date_or_date_time() {
ctx.set_error(output.len(), format!("map keys can not be {}", key_type));
}

// default behavior: insert the new key-value pair; if the key already exists, update its value.
output.append_column(&build_new_map(&source, key, value, ctx));
}),
);

// grammar: map_insert(map, insert_key, insert_value, allow_update)
registry.register_passthrough_nullable_4_arg(
"map_insert",
|_, domain1, key_domain, value_domain, _|
FunctionDomain::Domain(match (domain1, key_domain, value_domain) {
(Some((key_domain, val_domain)), insert_key_domain, insert_value_domain) => Some((
key_domain.merge(insert_key_domain),
val_domain.merge(insert_value_domain),
)),
(None, _, _) => None,
}),
vectorize_with_builder_4_arg::<
MapType<GenericType<0>, GenericType<1>>,
GenericType<0>,
GenericType<1>,
BooleanType,
MapType<GenericType<0>, GenericType<1>>,
>(|source, key: databend_common_expression::ScalarRef, value, allow_update, output, ctx| {
let key_type = &ctx.generics[0];
if !key_type.is_boolean()
&& !key_type.is_string()
&& !key_type.is_numeric()
&& !key_type.is_decimal()
&& !key_type.is_date_or_date_time() {
ctx.set_error(output.len(), format!("map keys can not be {}", key_type));
}

let duplicate_key = source.iter().any(|(k, _)| k == key);
// if duplicate_key is true and allow_update is false, return the original map
if duplicate_key && !allow_update {
let mut new_builder = ArrayType::create_builder(source.len(), ctx.generics);
source.iter().for_each(|(k, v)| new_builder.put_item((k.clone(), v.clone())));
new_builder.commit_row();
output.append_column(&new_builder.build());
return;
}

output.append_column(&build_new_map(&source, key, value, ctx));
}),
);

fn build_new_map(
source: &KvColumn<GenericType<0>, GenericType<1>>,
insert_key: ScalarRef,
insert_value: ScalarRef,
ctx: &EvalContext
) -> ArrayColumn<KvPair<GenericType<0>, GenericType<1>>> {
let duplicate_key = source.iter().any(|(k, _)| k == insert_key);
let mut new_map = ArrayType::create_builder(source.len() + 1, ctx.generics);
for (k, v) in source.iter() {
if k == insert_key {
new_map.put_item((k.clone(), insert_value.clone()));
continue;
}
new_map.put_item((k.clone(), v.clone()));
}
if !duplicate_key {
new_map.put_item((insert_key.clone(), insert_value.clone()));
}
new_map.commit_row();

new_map.build()
}
}
20 changes: 20 additions & 0 deletions src/query/functions/tests/it/scalars/map.rs
@@ -283,6 +283,16 @@ fn test_map_size(file: &mut impl Write) {
fn test_map_insert(file: &mut impl Write) {
run_ast(file, "map_insert({}, 'k1', 'v1')", &[]);
run_ast(file, "map_insert({'k1': 'v1'}, 'k2', 'v2')", &[]);
run_ast(
file,
"map_insert({'k1': 'v1', 'k2': 'v2'}, 'k1', 'v10', false)",
&[],
);
run_ast(
file,
"map_insert({'k1': 'v1', 'k2': 'v2'}, 'k1', 'v10', true)",
&[],
);

let columns = [
("a_col", StringType::from_data(vec!["a", "b", "c"])),
@@ -306,4 +316,14 @@ fn test_map_insert(file: &mut impl Write) {
"map_insert(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'k1', 'v10')",
&columns,
);
run_ast(
file,
"map_insert(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'a', 'v10', true)",
&columns,
);
run_ast(
file,
"map_insert(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'a', 'v10', false)",
&columns,
);
}
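
Read against the allow_update branch in map.rs above, the two new literal-map cases should behave as follows (this is a reading of the code in this commit, not the golden output recorded by run_ast): with allow_update = false the duplicate key 'k1' keeps its original value, so the result is still {'k1': 'v1', 'k2': 'v2'}; with allow_update = true the value is replaced, giving {'k1': 'v10', 'k2': 'v2'}. A key that is not already present is appended regardless of the flag.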