From 0492af6a5ac0076ced5ef53f17efad04f9f54c25 Mon Sep 17 00:00:00 2001 From: xufei Date: Wed, 14 Aug 2019 14:04:16 +0800 Subject: [PATCH] support udf in (#175) * fix cop test regression * address comments * format code * fix npe for dag execute * format code * address comment * add some comments * throw exception when meet error during cop request handling * address comments * add error code * throw exception when meet error during cop request handling * address comments * add DAGContext so InterpreterDAG can exchange information with DAGDriver * fix bug * 1. refine code, 2. address comments * update comments * columnref index is based on executor output schema * handle error in coprocessor request * refine code * use Clear to clear a protobuf message completely * refine code * code refine && several minor bug fix * address comments * address comments * support udf in * refine code * address comments * address comments --- .../Coprocessor/DAGExpressionAnalyzer.cpp | 60 +++++++++++++-- .../Flash/Coprocessor/DAGExpressionAnalyzer.h | 6 ++ .../Flash/Coprocessor/DAGStringConverter.cpp | 2 +- dbms/src/Flash/Coprocessor/DAGUtils.cpp | 61 +++++++++------- dbms/src/Flash/Coprocessor/DAGUtils.h | 3 +- dbms/src/Flash/Coprocessor/InterpreterDAG.cpp | 15 ++-- dbms/src/Flash/Coprocessor/InterpreterDAG.h | 3 + dbms/src/Flash/Coprocessor/tests/cop_test.cpp | 73 ++++++++++++++++++- dbms/src/Interpreters/Set.cpp | 39 ++++++++++ dbms/src/Interpreters/Set.h | 10 +++ 10 files changed, 224 insertions(+), 48 deletions(-) diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 5b8b5fa9165..d2dda6a5bb7 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -1,10 +1,13 @@ #include #include +#include +#include #include #include #include #include +#include #include #include #include @@ -251,6 +254,32 @@ String 
DAGExpressionAnalyzer::appendCastIfNeeded(const tipb::Expr & expr, Expres return expr_name; } +void DAGExpressionAnalyzer::makeExplicitSet( + const tipb::Expr & expr, const Block & sample_block, bool create_ordered_set, const String & left_arg_name) +{ + if (prepared_sets.count(&expr)) + { + return; + } + DataTypes set_element_types; + // todo support tuple in, i.e. (a,b) in ((1,2), (3,4)), currently TiDB convert tuple in into a series of or/and/eq exprs + // which means tuple in is never be pushed to coprocessor, but it is quite in-efficient + set_element_types.push_back(sample_block.getByName(left_arg_name).type); + + // todo if this is a single value in, then convert it to equal expr + SetPtr set = std::make_shared(SizeLimits(settings.max_rows_in_set, settings.max_bytes_in_set, settings.set_overflow_mode)); + set->createFromDAGExpr(set_element_types, expr, create_ordered_set); + prepared_sets[&expr] = std::move(set); +} + +static String getUniqueName(const Block & block, const String & prefix) +{ + int i = 1; + while (block.has(prefix + toString(i))) + ++i; + return prefix + toString(i); +} + String DAGExpressionAnalyzer::getActions(const tipb::Expr & expr, ExpressionActionsPtr & actions) { String expr_name = getName(expr, getCurrentInputColumns()); @@ -288,20 +317,35 @@ String DAGExpressionAnalyzer::getActions(const tipb::Expr & expr, ExpressionActi throw Exception("agg function is not supported yet", ErrorCodes::UNSUPPORTED_METHOD); } const String & func_name = getFunctionName(expr); - if (func_name == "in" || func_name == "notIn" || func_name == "globalIn" || func_name == "globalNotIn") - { - // todo support in - throw Exception(func_name + " is not supported yet", ErrorCodes::UNSUPPORTED_METHOD); - } - const FunctionBuilderPtr & function_builder = FunctionFactory::instance().get(func_name, context); Names argument_names; DataTypes argument_types; - for (auto & child : expr.children()) + + if (isInOrGlobalInOperator(func_name)) { - String name = 
getActions(child, actions); + String name = getActions(expr.children(0), actions); argument_names.push_back(name); argument_types.push_back(actions->getSampleBlock().getByName(name).type); + makeExplicitSet(expr, actions->getSampleBlock(), false, name); + ColumnWithTypeAndName column; + column.type = std::make_shared(); + + const SetPtr & set = prepared_sets[&expr]; + + column.name = getUniqueName(actions->getSampleBlock(), "___set"); + column.column = ColumnSet::create(1, set); + actions->add(ExpressionAction::addColumn(column)); + argument_names.push_back(column.name); + argument_types.push_back(column.type); + } + else + { + for (auto & child : expr.children()) + { + String name = getActions(child, actions); + argument_names.push_back(name); + argument_types.push_back(actions->getSampleBlock().getByName(name).type); + } } // re-construct expr_name, because expr_name generated previously is based on expr tree, diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h index cdc3acbac5b..959729886c7 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h @@ -13,6 +13,10 @@ namespace DB { +class Set; +using SetPtr = std::shared_ptr; +using DAGPreparedSets = std::unordered_map; + /** Transforms an expression from DAG expression into a sequence of actions to execute it. 
* */ @@ -24,6 +28,7 @@ class DAGExpressionAnalyzer : private boost::noncopyable NamesAndTypesList source_columns; // all columns after aggregation NamesAndTypesList aggregated_columns; + DAGPreparedSets prepared_sets; Settings settings; const Context & context; bool after_agg; @@ -47,6 +52,7 @@ class DAGExpressionAnalyzer : private boost::noncopyable void appendFinalProject(ExpressionActionsChain & chain, const NamesWithAliases & final_project); String getActions(const tipb::Expr & expr, ExpressionActionsPtr & actions); const NamesAndTypesList & getCurrentInputColumns(); + void makeExplicitSet(const tipb::Expr & expr, const Block & sample_block, bool create_ordered_set, const String & left_arg_name); }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGStringConverter.cpp b/dbms/src/Flash/Coprocessor/DAGStringConverter.cpp index 36ab11801b9..4a11d21f075 100644 --- a/dbms/src/Flash/Coprocessor/DAGStringConverter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGStringConverter.cpp @@ -58,7 +58,7 @@ void DAGStringConverter::buildTSString(const tipb::TableScan & ts, std::stringst String name = merge_tree->getTableInfo().columns[cid - 1].name; output_from_ts.push_back(std::move(name)); } - ss << "FROM " << merge_tree->getTableInfo().db_name << "." << merge_tree->getTableInfo().name << " "; + ss << "FROM " << storage->getDatabaseName() << "." 
<< storage->getTableName() << " "; } void DAGStringConverter::buildSelString(const tipb::Selection & sel, std::stringstream & ss) diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 0ed3db9dfc7..d46bf5acf5e 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -9,6 +9,12 @@ namespace DB { +namespace ErrorCodes +{ +extern const int COP_BAD_DAG_REQUEST; +extern const int UNSUPPORTED_METHOD; +} // namespace ErrorCodes + bool isFunctionExpr(const tipb::Expr & expr) { switch (expr.tp()) @@ -43,7 +49,7 @@ const String & getAggFunctionName(const tipb::Expr & expr) { if (!aggFunMap.count(expr.tp())) { - throw Exception(tipb::ExprType_Name(expr.tp()) + " is not supported."); + throw Exception(tipb::ExprType_Name(expr.tp()) + " is not supported.", ErrorCodes::UNSUPPORTED_METHOD); } return aggFunMap[expr.tp()]; } @@ -54,7 +60,7 @@ const String & getFunctionName(const tipb::Expr & expr) { if (!aggFunMap.count(expr.tp())) { - throw Exception(tipb::ExprType_Name(expr.tp()) + " is not supported."); + throw Exception(tipb::ExprType_Name(expr.tp()) + " is not supported.", ErrorCodes::UNSUPPORTED_METHOD); } return aggFunMap[expr.tp()]; } @@ -62,13 +68,13 @@ const String & getFunctionName(const tipb::Expr & expr) { if (!scalarFunMap.count(expr.sig())) { - throw Exception(tipb::ScalarFuncSig_Name(expr.sig()) + " is not supported."); + throw Exception(tipb::ScalarFuncSig_Name(expr.sig()) + " is not supported.", ErrorCodes::UNSUPPORTED_METHOD); } return scalarFunMap[expr.sig()]; } } -String exprToString(const tipb::Expr & expr, const NamesAndTypesList & input_col) +String exprToString(const tipb::Expr & expr, const NamesAndTypesList & input_col, bool for_parser) { std::stringstream ss; size_t cursor = 1; @@ -94,7 +100,7 @@ String exprToString(const tipb::Expr & expr, const NamesAndTypesList & input_col columnId = DecodeInt(cursor, expr.val()); if (columnId < 0 || columnId >= 
(ColumnID)input_col.size()) { - throw Exception("out of bound"); + throw Exception("Column id out of bound", ErrorCodes::COP_BAD_DAG_REQUEST); } return input_col.getNames()[columnId]; case tipb::ExprType::Count: @@ -105,53 +111,50 @@ String exprToString(const tipb::Expr & expr, const NamesAndTypesList & input_col case tipb::ExprType::First: if (!aggFunMap.count(expr.tp())) { - throw Exception("not supported"); + throw Exception(tipb::ExprType_Name(expr.tp()) + "not supported", ErrorCodes::UNSUPPORTED_METHOD); } func_name = aggFunMap.find(expr.tp())->second; break; case tipb::ExprType::ScalarFunc: if (!scalarFunMap.count(expr.sig())) { - throw Exception("not supported"); + throw Exception(tipb::ScalarFuncSig_Name(expr.sig()) + "not supported", ErrorCodes::UNSUPPORTED_METHOD); } func_name = scalarFunMap.find(expr.sig())->second; break; default: - throw Exception("not supported"); + throw Exception(tipb::ExprType_Name(expr.tp()) + "not supported", ErrorCodes::UNSUPPORTED_METHOD); } // build function expr - if (func_name == "in") + if (isInOrGlobalInOperator(func_name) && for_parser) { // for in, we could not represent the function expr using func_name(param1, param2, ...) 
- throw Exception("not supported"); + throw Exception("Function " + func_name + " not supported", ErrorCodes::UNSUPPORTED_METHOD); } - else + ss << func_name << "("; + bool first = true; + for (const tipb::Expr & child : expr.children()) { - ss << func_name << "("; - bool first = true; - for (const tipb::Expr & child : expr.children()) + String s = exprToString(child, input_col, for_parser); + if (first) { - String s = exprToString(child, input_col); - if (first) - { - first = false; - } - else - { - ss << ", "; - } - ss << s; + first = false; } - ss << ") "; - return ss.str(); + else + { + ss << ", "; + } + ss << s; } + ss << ") "; + return ss.str(); } const String & getTypeName(const tipb::Expr & expr) { return tipb::ExprType_Name(expr.tp()); } String getName(const tipb::Expr & expr, const NamesAndTypesList & current_input_columns) { - return exprToString(expr, current_input_columns); + return exprToString(expr, current_input_columns, false); } bool isAggFunctionExpr(const tipb::Expr & expr) @@ -225,7 +228,7 @@ Field decodeLiteral(const tipb::Expr & expr) case tipb::ExprType::MysqlTime: case tipb::ExprType::MysqlJson: case tipb::ExprType::ValueList: - throw Exception("mysql type literal is not supported yet"); + throw Exception(tipb::ExprType_Name(expr.tp()) + "is not supported yet", ErrorCodes::UNSUPPORTED_METHOD); default: return DecodeDatum(cursor, expr.val()); } @@ -237,6 +240,8 @@ ColumnID getColumnID(const tipb::Expr & expr) return DecodeInt(cursor, expr.val()); } +bool isInOrGlobalInOperator(const String & name) { return name == "in" || name == "notIn" || name == "globalIn" || name == "globalNotIn"; } + std::unordered_map aggFunMap({ {tipb::ExprType::Count, "count"}, {tipb::ExprType::Sum, "sum"}, {tipb::ExprType::Avg, "avg"}, {tipb::ExprType::Min, "min"}, {tipb::ExprType::Max, "max"}, {tipb::ExprType::First, "any"}, diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.h b/dbms/src/Flash/Coprocessor/DAGUtils.h index 71a52533ea3..ec6b96d2fbb 100644 --- 
a/dbms/src/Flash/Coprocessor/DAGUtils.h +++ b/dbms/src/Flash/Coprocessor/DAGUtils.h @@ -24,7 +24,8 @@ bool isColumnExpr(const tipb::Expr & expr); ColumnID getColumnID(const tipb::Expr & expr); String getName(const tipb::Expr & expr, const NamesAndTypesList & current_input_columns); const String & getTypeName(const tipb::Expr & expr); -String exprToString(const tipb::Expr & expr, const NamesAndTypesList & input_col); +String exprToString(const tipb::Expr & expr, const NamesAndTypesList & input_col, bool for_parser = true); +bool isInOrGlobalInOperator(const String & name); extern std::unordered_map aggFunMap; extern std::unordered_map scalarFunMap; diff --git a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp index be2f8700a04..cbc95e795e9 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp +++ b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp @@ -82,6 +82,8 @@ void InterpreterDAG::executeTS(const tipb::TableScan & ts, Pipeline & pipeline) throw Exception("No column is selected in table scan executor", ErrorCodes::COP_BAD_DAG_REQUEST); } + analyzer = std::make_unique(source_columns, context); + if (!dag.hasAggregation()) { // if the dag request does not contain agg, then the final output is @@ -175,10 +177,9 @@ InterpreterDAG::AnalysisResult InterpreterDAG::analyzeExpressions() { AnalysisResult res; ExpressionActionsChain chain; - DAGExpressionAnalyzer analyzer(source_columns, context); if (dag.hasSelection()) { - analyzer.appendWhere(chain, dag.getSelection(), res.filter_column_name); + analyzer->appendWhere(chain, dag.getSelection(), res.filter_column_name); res.has_where = true; res.before_where = chain.getLastActions(); chain.addStep(); @@ -186,7 +187,7 @@ InterpreterDAG::AnalysisResult InterpreterDAG::analyzeExpressions() // There will be either Agg... 
if (dag.hasAggregation()) { - analyzer.appendAggregation(chain, dag.getAggregation(), res.aggregation_keys, res.aggregate_descriptions); + analyzer->appendAggregation(chain, dag.getAggregation(), res.aggregation_keys, res.aggregate_descriptions); res.need_aggregate = true; res.before_aggregation = chain.getLastActions(); @@ -194,9 +195,9 @@ InterpreterDAG::AnalysisResult InterpreterDAG::analyzeExpressions() chain.clear(); // add cast if type is not match - analyzer.appendAggSelect(chain, dag.getAggregation()); + analyzer->appendAggSelect(chain, dag.getAggregation()); //todo use output_offset to reconstruct the final project columns - for (auto element : analyzer.getCurrentInputColumns()) + for (auto element : analyzer->getCurrentInputColumns()) { final_project.emplace_back(element.name, ""); } @@ -205,10 +206,10 @@ InterpreterDAG::AnalysisResult InterpreterDAG::analyzeExpressions() if (dag.hasTopN()) { res.has_order_by = true; - analyzer.appendOrderBy(chain, dag.getTopN(), res.order_column_names); + analyzer->appendOrderBy(chain, dag.getTopN(), res.order_column_names); } // Append final project results if needed. 
- analyzer.appendFinalProject(chain, final_project); + analyzer->appendFinalProject(chain, final_project); res.before_order_and_select = chain.getLastActions(); chain.finalize(); chain.clear(); diff --git a/dbms/src/Flash/Coprocessor/InterpreterDAG.h b/dbms/src/Flash/Coprocessor/InterpreterDAG.h index 099e1382e8d..22ba126df96 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterDAG.h +++ b/dbms/src/Flash/Coprocessor/InterpreterDAG.h @@ -7,6 +7,7 @@ #pragma GCC diagnostic pop #include +#include #include #include #include @@ -98,6 +99,8 @@ class InterpreterDAG : public IInterpreter TMTStoragePtr storage; TableStructureReadLockPtr table_lock; + std::unique_ptr analyzer; + Poco::Logger * log; }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/tests/cop_test.cpp b/dbms/src/Flash/Coprocessor/tests/cop_test.cpp index e18c3c4dd74..267056e018c 100644 --- a/dbms/src/Flash/Coprocessor/tests/cop_test.cpp +++ b/dbms/src/Flash/Coprocessor/tests/cop_test.cpp @@ -56,6 +56,7 @@ class FlashClient grpc::ClientContext clientContext; clientContext.AddMetadata("user_name", ""); clientContext.AddMetadata("dag_planner", "optree"); + clientContext.AddMetadata("dag_expr_field_type_strict_check", "0"); coprocessor::Response response; grpc::Status status = sp->Coprocessor(&clientContext, *rqst, &response); if (status.ok()) @@ -64,6 +65,12 @@ class FlashClient tipb::SelectResponse selectResponse; if (selectResponse.ParseFromString(response.data())) { + if (selectResponse.has_error()) + { + std::cout << "Coprocessor request failed, error code " << selectResponse.error().code() << " error msg " + << selectResponse.error().msg(); + return status; + } for (const tipb::Chunk & chunk : selectResponse.chunks()) { size_t cursor = 0; @@ -148,6 +155,66 @@ void appendSelection(tipb::DAGRequest & dag_request) type = expr->mutable_field_type(); type->set_tp(1); type->set_flag(1 << 5); + + // selection i in (5,10,11) + selection->clear_conditions(); + expr = selection->add_conditions(); + 
expr->set_tp(tipb::ExprType::ScalarFunc); + expr->set_sig(tipb::ScalarFuncSig::InInt); + col = expr->add_children(); + col->set_tp(tipb::ExprType::ColumnRef); + ss.str(""); + DB::EncodeNumber(1, ss); + col->set_val(ss.str()); + type = col->mutable_field_type(); + type->set_tp(8); + type->set_flag(0); + value = expr->add_children(); + value->set_tp(tipb::ExprType::Int64); + ss.str(""); + DB::EncodeNumber(10, ss); + value->set_val(std::string(ss.str())); + type = value->mutable_field_type(); + type->set_tp(8); + type->set_flag(1); + type = expr->mutable_field_type(); + type->set_tp(1); + type->set_flag(1 << 5); + value = expr->add_children(); + value->set_tp(tipb::ExprType::Int64); + ss.str(""); + DB::EncodeNumber(5, ss); + value->set_val(std::string(ss.str())); + type = value->mutable_field_type(); + type->set_tp(8); + type->set_flag(1); + type = expr->mutable_field_type(); + type->set_tp(1); + type->set_flag(1 << 5); + value = expr->add_children(); + value->set_tp(tipb::ExprType::Int64); + ss.str(""); + DB::EncodeNumber(11, ss); + value->set_val(std::string(ss.str())); + type = value->mutable_field_type(); + type->set_tp(8); + type->set_flag(1); + type = expr->mutable_field_type(); + type->set_tp(1); + type->set_flag(1 << 5); + + // selection i is null + /* + selection->clear_conditions(); + expr = selection->add_conditions(); + expr->set_tp(tipb::ExprType::ScalarFunc); + expr->set_sig(tipb::ScalarFuncSig::IntIsNull); + col = expr->add_children(); + col->set_tp(tipb::ExprType::ColumnRef); + ss.str(""); + DB::EncodeNumber(1, ss); + col->set_val(ss.str()); + */ } void appendAgg(tipb::DAGRequest & dag_request, size_t & result_field_num) @@ -208,9 +275,9 @@ grpc::Status rpcTest() ChannelPtr cp = grpc::CreateChannel("localhost:9093", grpc::InsecureChannelCredentials()); ClientPtr clientPtr = std::make_shared(cp); size_t result_field_num = 0; - bool has_selection = false; - bool has_agg = true; - bool has_topN = false; + bool has_selection = true; + bool has_agg = false; 
+ bool has_topN = true; bool has_limit = false; // construct a dag request tipb::DAGRequest dagRequest; diff --git a/dbms/src/Interpreters/Set.cpp b/dbms/src/Interpreters/Set.cpp index 925479e05e1..27e8757c658 100644 --- a/dbms/src/Interpreters/Set.cpp +++ b/dbms/src/Interpreters/Set.cpp @@ -12,6 +12,8 @@ #include #include +#include + #include #include #include @@ -22,6 +24,7 @@ #include #include +#include namespace DB @@ -34,6 +37,7 @@ namespace ErrorCodes extern const int TYPE_MISMATCH; extern const int INCORRECT_ELEMENT_OF_SET; extern const int NUMBER_OF_COLUMNS_DOESNT_MATCH; + extern const int COP_BAD_DAG_REQUEST; } @@ -256,6 +260,41 @@ void Set::createFromAST(const DataTypes & types, ASTPtr node, const Context & co insertFromBlock(block, fill_set_elements); } +void Set::createFromDAGExpr(const DataTypes & types, const tipb::Expr & expr, bool fill_set_elements) +{ + /// Will form a block with values from the set. + + Block header; + size_t num_columns = types.size(); + if (num_columns != 1) + { + throw Exception("Incorrect element of set, tuple in is not supported yet", ErrorCodes::INCORRECT_ELEMENT_OF_SET); + } + for (size_t i = 0; i < num_columns; ++i) + header.insert(ColumnWithTypeAndName(types[i]->createColumn(), types[i], "_" + toString(i))); + setHeader(header); + + MutableColumns columns = header.cloneEmptyColumns(); + + for (int i = 1; i < expr.children_size(); i++) + { + auto & child = expr.children(i); + // todo support constant expression by constant folding + if (!isLiteralExpr(child)) + { + throw Exception("Only literal is supported in children of expr `in`", ErrorCodes::COP_BAD_DAG_REQUEST); + } + Field value = decodeLiteral(child); + DataTypePtr type = child.has_field_type() ? 
getDataTypeByFieldType(child.field_type()) : types[0]; + value = convertFieldToType(value, *type); + + if (!value.isNull()) + columns[0]->insert(value); + } + + Block block = header.cloneWithColumns(std::move(columns)); + insertFromBlock(block, fill_set_elements); +} ColumnPtr Set::execute(const Block & block, bool negative) const { diff --git a/dbms/src/Interpreters/Set.h b/dbms/src/Interpreters/Set.h index e27bdf58ec6..9600ed2065f 100644 --- a/dbms/src/Interpreters/Set.h +++ b/dbms/src/Interpreters/Set.h @@ -1,5 +1,10 @@ #pragma once +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#include +#pragma GCC diagnostic pop + #include #include #include @@ -48,6 +53,11 @@ class Set */ void createFromAST(const DataTypes & types, ASTPtr node, const Context & context, bool fill_set_elements); + /** + * Create a Set from DAG Expr, used when processing DAG Request + */ + void createFromDAGExpr(const DataTypes & types, const tipb::Expr & expr, bool fill_set_elements); + /** Create a Set from stream. * Call setHeader, then call insertFromBlock for each block. */