Skip to content

Commit

Permalink
remove duplicate agg funcs (#283)
Browse files Browse the repository at this point in the history
* 1. Remove duplicate aggregate functions. 2. For column-ref expressions, rename column_id to column_index, since the value stored in a column-ref expression is an index into the input columns, not a column ID.

* bug fix
  • Loading branch information
windtalker authored Oct 15, 2019
1 parent bc075c5 commit fbcbdc0
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 50 deletions.
37 changes: 19 additions & 18 deletions dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ static String genFuncString(const String & func_name, const Names & argument_nam
return ss.str();
}

DAGExpressionAnalyzer::DAGExpressionAnalyzer(const std::vector<NameAndTypePair> && source_columns_, const Context & context_)
: source_columns(source_columns_),
DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector<NameAndTypePair> && source_columns_, const Context & context_)
: source_columns(std::move(source_columns_)),
context(context_),
after_agg(false),
implicit_cast_count(0),
Expand All @@ -68,7 +68,6 @@ void DAGExpressionAnalyzer::appendAggregation(
initChain(chain, getCurrentInputColumns());
ExpressionActionsChain::Step & step = chain.steps.back();

Names agg_argument_names;
for (const tipb::Expr & expr : agg.agg_func())
{
const String & agg_func_name = getAggFunctionName(expr);
Expand All @@ -78,13 +77,24 @@ void DAGExpressionAnalyzer::appendAggregation(
for (Int32 i = 0; i < expr.children_size(); i++)
{
String arg_name = getActions(expr.children(i), step.actions);
agg_argument_names.push_back(arg_name);
types[i] = step.actions->getSampleBlock().getByName(arg_name).type;
aggregate.argument_names[i] = arg_name;
step.required_output.push_back(arg_name);
}
String func_string = genFuncString(agg_func_name, agg_argument_names);
String func_string = genFuncString(agg_func_name, aggregate.argument_names);
bool duplicate = false;
for (const auto & pre_agg : aggregate_descriptions)
{
if (pre_agg.column_name == func_string)
{
aggregated_columns.emplace_back(func_string, pre_agg.function->getReturnType());
duplicate = true;
break;
}
}
if (duplicate)
continue;
aggregate.column_name = func_string;
//todo de-duplicate aggregation column
aggregate.parameters = Array();
aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types);
aggregate_descriptions.push_back(aggregate);
Expand All @@ -93,8 +103,6 @@ void DAGExpressionAnalyzer::appendAggregation(
aggregated_columns.emplace_back(func_string, result_type);
}

std::move(agg_argument_names.begin(), agg_argument_names.end(), std::back_inserter(step.required_output));

for (const tipb::Expr & expr : agg.group_by())
{
String name = getActions(expr, step.actions);
Expand Down Expand Up @@ -286,7 +294,7 @@ void DAGExpressionAnalyzer::appendAggSelect(
{
initChain(chain, getCurrentInputColumns());
bool need_update_aggregated_columns = false;
NamesAndTypesList updated_aggregated_columns;
std::vector<NameAndTypePair> updated_aggregated_columns;
ExpressionActionsChain::Step step = chain.steps.back();
bool need_append_timezone_cast = hasMeaningfulTZInfo(rqst);
tipb::Expr tz_expr;
Expand Down Expand Up @@ -344,12 +352,10 @@ void DAGExpressionAnalyzer::appendAggSelect(

if (need_update_aggregated_columns)
{
auto updated_agg_col_names = updated_aggregated_columns.getNames();
auto updated_agg_col_types = updated_aggregated_columns.getTypes();
aggregated_columns.clear();
for (size_t i = 0; i < updated_aggregated_columns.size(); i++)
{
aggregated_columns.emplace_back(updated_agg_col_names[i], updated_agg_col_types[i]);
aggregated_columns.emplace_back(updated_aggregated_columns[i].name, updated_aggregated_columns[i].type);
}
}
}
Expand Down Expand Up @@ -471,13 +477,8 @@ String DAGExpressionAnalyzer::getActions(const tipb::Expr & expr, ExpressionActi
}
else if (isColumnExpr(expr))
{
ColumnID column_id = getColumnID(expr);
if (column_id < 0 || column_id >= (ColumnID)getCurrentInputColumns().size())
{
throw Exception("column id out of bound", ErrorCodes::COP_BAD_DAG_REQUEST);
}
//todo check if the column type needs to be cast to field type
return getCurrentInputColumns()[column_id].name;
return getColumnNameForColumnExpr(expr, getCurrentInputColumns());
}
else if (isFunctionExpr(expr))
{
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DAGExpressionAnalyzer : private boost::noncopyable
Poco::Logger * log;

public:
DAGExpressionAnalyzer(const std::vector<NameAndTypePair> && source_columns_, const Context & context_);
DAGExpressionAnalyzer(std::vector<NameAndTypePair> && source_columns_, const Context & context_);
void appendWhere(ExpressionActionsChain & chain, const tipb::Selection & sel, String & filter_column_name);
void appendOrderBy(ExpressionActionsChain & chain, const tipb::TopN & topN, Strings & order_column_names);
void appendAggregation(ExpressionActionsChain & chain, const tipb::Aggregation & agg, Names & aggregate_keys,
Expand Down
10 changes: 3 additions & 7 deletions dbms/src/Flash/Coprocessor/DAGQueryInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@ namespace DB

struct DAGQueryInfo
{
DAGQueryInfo(const DAGQuerySource & dag_, DAGPreparedSets dag_sets_, std::vector<NameAndTypePair> & source_columns_)
: dag(dag_), dag_sets(std::move(dag_sets_))
{
for (auto & c : source_columns_)
source_columns.emplace_back(c.name, c.type);
};
DAGQueryInfo(const DAGQuerySource & dag_, DAGPreparedSets dag_sets_, const std::vector<NameAndTypePair> & source_columns_)
: dag(dag_), dag_sets(std::move(dag_sets_)), source_columns(source_columns_){};
const DAGQuerySource & dag;
DAGPreparedSets dag_sets;
NamesAndTypesList source_columns;
const std::vector<NameAndTypePair> & source_columns;
};
} // namespace DB
18 changes: 8 additions & 10 deletions dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ const String & getFunctionName(const tipb::Expr & expr)
String exprToString(const tipb::Expr & expr, const std::vector<NameAndTypePair> & input_col)
{
std::stringstream ss;
Int64 column_id = 0;
String func_name;
Field f;
switch (expr.tp())
Expand Down Expand Up @@ -94,12 +93,7 @@ String exprToString(const tipb::Expr & expr, const std::vector<NameAndTypePair>
return std::to_string(TiDB::DatumFlat(t, static_cast<TiDB::TP>(expr.field_type().tp())).field().get<Int64>());
}
case tipb::ExprType::ColumnRef:
column_id = decodeDAGInt64(expr.val());
if (column_id < 0 || column_id >= (ColumnID)input_col.size())
{
throw Exception("Column id out of bound", ErrorCodes::COP_BAD_DAG_REQUEST);
}
return input_col[column_id].name;
return getColumnNameForColumnExpr(expr, input_col);
case tipb::ExprType::Count:
case tipb::ExprType::Sum:
case tipb::ExprType::Avg:
Expand Down Expand Up @@ -247,10 +241,14 @@ Field decodeLiteral(const tipb::Expr & expr)
}
}

ColumnID getColumnID(const tipb::Expr & expr)
String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector<NameAndTypePair> & input_col)
{
auto column_id = decodeDAGInt64(expr.val());
return column_id;
auto column_index = decodeDAGInt64(expr.val());
if (column_index < 0 || column_index >= (Int64)input_col.size())
{
throw Exception("Column index out of bound", ErrorCodes::COP_BAD_DAG_REQUEST);
}
return input_col[column_index].name;
}

/// Returns true when `name` is exactly one of the IN-family operator names:
/// "in", "notIn", "globalIn" or "globalNotIn". The comparison is case-sensitive.
bool isInOrGlobalInOperator(const String & name)
{
    // Local variants first, then the global variants.
    if (name == "in" || name == "notIn")
        return true;
    return name == "globalIn" || name == "globalNotIn";
}
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ bool isAggFunctionExpr(const tipb::Expr & expr);
const String & getFunctionName(const tipb::Expr & expr);
const String & getAggFunctionName(const tipb::Expr & expr);
bool isColumnExpr(const tipb::Expr & expr);
ColumnID getColumnID(const tipb::Expr & expr);
String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector<NameAndTypePair> & input_col);
const String & getTypeName(const tipb::Expr & expr);
String exprToString(const tipb::Expr & expr, const std::vector<NameAndTypePair> & input_col);
bool isInOrGlobalInOperator(const String & name);
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Flash/Coprocessor/InterpreterDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ void InterpreterDAG::executeTS(const tipb::TableScan & ts, Pipeline & pipeline)
SelectQueryInfo query_info;
// set query to avoid unexpected NPE
query_info.query = dag.getAST();
query_info.dag_query = std::make_unique<DAGQueryInfo>(dag, analyzer->getPreparedSets(), source_columns);
query_info.dag_query = std::make_unique<DAGQueryInfo>(dag, analyzer->getPreparedSets(), analyzer->getCurrentInputColumns());
query_info.mvcc_query_info = std::make_unique<MvccQueryInfo>();
query_info.mvcc_query_info->resolve_locks = true;
query_info.mvcc_query_info->read_tso = settings.read_tso;
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/Storages/MergeTree/KeyCondition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ KeyCondition::KeyCondition(

if (query_info.fromAST())
{
RPNBuilder<ASTPtr, PreparedSets> rpn_builder(key_expr_, key_columns, all_columns);
RPNBuilder<ASTPtr, PreparedSets> rpn_builder(key_expr_, key_columns, {});
PreparedSets sets(query_info.sets);

/** Evaluation of expressions that depend only on constants.
Expand Down
13 changes: 4 additions & 9 deletions dbms/src/Storages/MergeTree/RPNBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,14 @@ const String getFuncName(const ASTPtr & node)
return "";
}

const String getColumnName(const tipb::Expr & node, const NamesAndTypesList & source_columns)
const String getColumnName(const tipb::Expr & node, const std::vector<NameAndTypePair> & source_columns)
{
if (node.tp() == tipb::ExprType::ColumnRef)
{
auto col_id = getColumnID(node);
if (col_id < 0 || col_id >= (Int64)source_columns.size())
return "";
return source_columns.getNames()[col_id];
}
if (isColumnExpr(node))
return getColumnNameForColumnExpr(node, source_columns);
return "";
}

const String getColumnName(const ASTPtr & node, const NamesAndTypesList &) { return node->getColumnName(); }
const String getColumnName(const ASTPtr & node, const std::vector<NameAndTypePair> &) { return node->getColumnName(); }

bool isFuncNode(const ASTPtr & node) { return typeid_cast<const ASTFunction *>(node.get()); }

Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Storages/MergeTree/RPNBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ template <typename NodeT, typename PreparedSetsT>
class RPNBuilder
{
public:
RPNBuilder(const ExpressionActionsPtr & key_expr_, ColumnIndices & key_columns_, const NamesAndTypesList & source_columns_)
RPNBuilder(const ExpressionActionsPtr & key_expr_, ColumnIndices & key_columns_, const std::vector<NameAndTypePair> & source_columns_)
: key_expr(key_expr_), key_columns(key_columns_), source_columns(source_columns_)
{}

Expand Down Expand Up @@ -62,6 +62,6 @@ class RPNBuilder
protected:
const ExpressionActionsPtr & key_expr;
ColumnIndices & key_columns;
const NamesAndTypesList & source_columns;
const std::vector<NameAndTypePair> & source_columns;
};
} // namespace DB
6 changes: 6 additions & 0 deletions tests/mutable-test/txn_dag/aggregation.test
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
│ 1 │ 777 │
└──────────────┴───────┘

=> DBGInvoke dag('select count(col_1),count(col_1) from default.test group by col_2')
┌─count(col_1)─┬─count(col_1)─┬─col_2─┐
│ 2 │ 2 │ 666 │
│ 1 │ 1 │ 777 │
└──────────────┴──────────────┴───────┘

# DAG read by explicitly specifying region id, where + group by.
=> DBGInvoke dag('select count(col_1) from default.test where col_2 = 666 group by col_2', 4)
┌─count(col_1)─┬─col_2─┐
Expand Down

0 comments on commit fbcbdc0

Please sign in to comment.