Skip to content

Commit

Permalink
fix: aggregation empty row on join with grouping and aggregations (#1…
Browse files Browse the repository at this point in the history
…0480) (#10554)

Signed-off-by: Harshit Gangal <[email protected]>
  • Loading branch information
harshit-gangal authored Jun 21, 2022
1 parent d32e6eb commit 47eec73
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 40 deletions.
13 changes: 13 additions & 0 deletions go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,17 @@ func TestEmptyTableAggr(t *testing.T) {
mcmp.AssertMatches(" select /*vt+ PLANNER=gen4 */ t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
})
}

mcmp.Exec("insert into t1(t1_id, `name`, `value`, shardkey) values(1,'a1','foo',100), (2,'b1','foo',200), (3,'c1','foo',300), (3,'a1','foo',100), (3,'b1','bar',200)")

for _, workload := range []string{"oltp", "olap"} {
t.Run(workload, func(t *testing.T) {
utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", workload))
mcmp.AssertMatches(" select /*vt+ PLANNER=gen4 */ count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select /*vt+ PLANNER=gen4 */ count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo'", "[[INT64(0)]]")
mcmp.AssertMatches(" select /*vt+ PLANNER=gen4 */ t1.`name`, count(*) from t1 inner join t2 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
mcmp.AssertMatches(" select /*vt+ PLANNER=gen4 */ t1.`name`, count(*) from t2 inner join t1 on (t1.t1_id = t2.id) where t1.value = 'foo' group by t1.`name`", "[]")
})
}

}
20 changes: 19 additions & 1 deletion go/vt/vtgate/planbuilder/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package planbuilder

import (
"strconv"

vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
Expand Down Expand Up @@ -162,7 +164,7 @@ func pushAggrOnRoute(
pos = newOffset(groupingCols[idx])
}

if ctx.SemTable.NeedsWeightString(expr.Inner) {
if expr.WeightStrExpr != nil && ctx.SemTable.NeedsWeightString(expr.Inner) {
wsExpr := weightStringFor(expr.WeightStrExpr)
wsCol, _, err := addExpressionToRoute(ctx, plan, &sqlparser.AliasedExpr{Expr: wsExpr}, true)
if err != nil {
Expand Down Expand Up @@ -274,6 +276,22 @@ func (hp *horizonPlanning) pushAggrOnJoin(
return nil, nil, err
}

// If the rhs has no grouping column then a count(*) will return 0 from the query and will get mapped to the record from left hand side.
// This is an incorrect behaviour as the join condition has not matched, so we add a literal 1 to the select query and also group by on it.
// So that only if join condition matches the records will be mapped and returned.
if len(rhsGrouping) == 0 && len(rhsAggrs) != 0 {
l := sqlparser.NewIntLiteral("1")
aExpr := &sqlparser.AliasedExpr{
Expr: l,
}
offset, _, err := pushProjection(ctx, aExpr, join.Right, true, true, false)
if err != nil {
return nil, nil, err
}
l = sqlparser.NewIntLiteral(strconv.Itoa(offset + 1))
rhsGrouping = append(rhsGrouping, abstract.GroupBy{Inner: l})
}

// Next we push the aggregations to both sides
newLHS, lhsOffsets, lhsAggrOffsets, _, err := hp.filteredPushAggregation(ctx, join.Left, lhsGrouping, lhsAggrs, true)
if err != nil {
Expand Down
42 changes: 21 additions & 21 deletions go/vt/vtgate/planbuilder/testdata/aggr_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3332,7 +3332,7 @@ Gen4 plan same as above
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,R:0",
"JoinColumnIndexes": "L:0,R:1",
"TableName": "`user`_user_extra",
"Inputs": [
{
Expand All @@ -3353,8 +3353,8 @@ Gen4 plan same as above
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*) from user_extra where 1 != 1",
"Query": "select count(*) from user_extra",
"FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1",
"Query": "select 1, count(*) from user_extra group by 1",
"Table": "user_extra"
}
]
Expand Down Expand Up @@ -3443,7 +3443,7 @@ Gen4 plan same as above
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:1,L:2,L:0,R:0",
"JoinColumnIndexes": "L:1,L:2,L:0,R:1",
"TableName": "`user`_user_extra",
"Inputs": [
{
Expand All @@ -3465,8 +3465,8 @@ Gen4 plan same as above
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*) from user_extra where 1 != 1",
"Query": "select count(*) from user_extra",
"FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1",
"Query": "select 1, count(*) from user_extra group by 1",
"Table": "user_extra"
}
]
Expand Down Expand Up @@ -3501,7 +3501,7 @@ Gen4 plan same as above
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:1,L:2,L:0,R:0",
"JoinColumnIndexes": "L:1,L:2,L:0,R:1",
"TableName": "`user`_user_extra",
"Inputs": [
{
Expand All @@ -3523,8 +3523,8 @@ Gen4 plan same as above
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(user_extra.a) from user_extra where 1 != 1",
"Query": "select count(user_extra.a) from user_extra",
"FieldQuery": "select 1, count(user_extra.a) from user_extra where 1 != 1 group by 1",
"Query": "select 1, count(user_extra.a) from user_extra group by 1",
"Table": "user_extra"
}
]
Expand Down Expand Up @@ -4013,7 +4013,7 @@ Gen4 plan same as above
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,L:1,L:2,R:0,R:1",
"JoinColumnIndexes": "L:0,L:1,L:2,R:1,R:2",
"JoinVars": {
"user_col": 0
},
Expand All @@ -4038,8 +4038,8 @@ Gen4 plan same as above
"Name": "user",
"Sharded": true
},
"FieldQuery": "select min(user_extra.foo), max(user_extra.bar) from user_extra where 1 != 1",
"Query": "select min(user_extra.foo), max(user_extra.bar) from user_extra where user_extra.bar = :user_col",
"FieldQuery": "select 1, min(user_extra.foo), max(user_extra.bar) from user_extra where 1 != 1 group by 1",
"Query": "select 1, min(user_extra.foo), max(user_extra.bar) from user_extra where user_extra.bar = :user_col group by 1",
"Table": "user_extra"
}
]
Expand Down Expand Up @@ -4110,7 +4110,7 @@ Gen4 error: aggregate functions take a single argument 'count(distinct user_id,
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,L:1,L:2,R:0",
"JoinColumnIndexes": "L:0,L:1,L:2,R:1",
"TableName": "`user`_user_extra",
"Inputs": [
{
Expand All @@ -4131,8 +4131,8 @@ Gen4 error: aggregate functions take a single argument 'count(distinct user_id,
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*) from user_extra where 1 != 1",
"Query": "select count(*) from user_extra",
"FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1",
"Query": "select 1, count(*) from user_extra group by 1",
"Table": "user_extra"
}
]
Expand Down Expand Up @@ -4598,7 +4598,7 @@ Gen4 error: aggregate functions take a single argument 'count(distinct user_id,
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:1,L:2,L:0,R:0",
"JoinColumnIndexes": "L:1,L:2,L:0,R:1",
"JoinVars": {
"user_col": 0
},
Expand All @@ -4623,8 +4623,8 @@ Gen4 error: aggregate functions take a single argument 'count(distinct user_id,
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*) from user_extra where 1 != 1",
"Query": "select count(*) from user_extra where user_extra.col = :user_col",
"FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1",
"Query": "select 1, count(*) from user_extra where user_extra.col = :user_col group by 1",
"Table": "user_extra"
}
]
Expand Down Expand Up @@ -4734,7 +4734,7 @@ Gen4 error: aggregate functions take a single argument 'count(distinct user_id,
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,L:1,L:0,R:0",
"JoinColumnIndexes": "L:0,L:1,L:0,R:1",
"TableName": "`user`_user_extra",
"Inputs": [
{
Expand All @@ -4756,8 +4756,8 @@ Gen4 error: aggregate functions take a single argument 'count(distinct user_id,
"Name": "user",
"Sharded": true
},
"FieldQuery": "select count(*) from user_extra where 1 != 1",
"Query": "select count(*) from user_extra",
"FieldQuery": "select 1, count(*) from user_extra where 1 != 1 group by 1",
"Query": "select 1, count(*) from user_extra group by 1",
"Table": "user_extra"
}
]
Expand Down
36 changes: 18 additions & 18 deletions go/vt/vtgate/planbuilder/testdata/tpch_cases.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Gen4 error: unsupported: cross-shard correlated subquery
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:3,L:5,L:4,L:6,L:1,R:0",
"JoinColumnIndexes": "L:3,L:5,L:4,L:6,L:1,R:1",
"JoinVars": {
"o_custkey": 0
},
Expand Down Expand Up @@ -98,8 +98,8 @@ Gen4 error: unsupported: cross-shard correlated subquery
"Name": "main",
"Sharded": true
},
"FieldQuery": "select count(*) from customer where 1 != 1",
"Query": "select count(*) from customer where c_mktsegment = 'BUILDING' and c_custkey = :o_custkey",
"FieldQuery": "select 1, count(*) from customer where 1 != 1 group by 1",
"Query": "select 1, count(*) from customer where c_mktsegment = 'BUILDING' and c_custkey = :o_custkey group by 1",
"Table": "customer",
"Values": [
":o_custkey"
Expand Down Expand Up @@ -318,7 +318,7 @@ Gen4 error: unsupported: cross-shard correlated subquery
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:3,L:4,L:1,R:0",
"JoinColumnIndexes": "L:3,L:4,L:1,R:1",
"JoinVars": {
"n_regionkey": 0
},
Expand Down Expand Up @@ -346,8 +346,8 @@ Gen4 error: unsupported: cross-shard correlated subquery
"Name": "main",
"Sharded": true
},
"FieldQuery": "select count(*) from region where 1 != 1",
"Query": "select count(*) from region where r_name = 'ASIA' and r_regionkey = :n_regionkey",
"FieldQuery": "select 1, count(*) from region where 1 != 1 group by 1",
"Query": "select 1, count(*) from region where r_name = 'ASIA' and r_regionkey = :n_regionkey group by 1",
"Table": "region",
"Values": [
":n_regionkey"
Expand Down Expand Up @@ -664,7 +664,7 @@ Gen4 error: aggregation on columns from different sources not supported yet
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:1,L:1,L:4,L:2,R:0",
"JoinColumnIndexes": "L:1,L:1,L:4,L:2,R:1",
"JoinVars": {
"o_orderkey": 0
},
Expand All @@ -688,8 +688,8 @@ Gen4 error: aggregation on columns from different sources not supported yet
"Name": "main",
"Sharded": true
},
"FieldQuery": "select sum(l_extendedprice * (1 - l_discount)) as revenue from lineitem where 1 != 1",
"Query": "select sum(l_extendedprice * (1 - l_discount)) as revenue from lineitem where l_returnflag = 'R' and l_orderkey = :o_orderkey",
"FieldQuery": "select 1, sum(l_extendedprice * (1 - l_discount)) as revenue from lineitem where 1 != 1 group by 1",
"Query": "select 1, sum(l_extendedprice * (1 - l_discount)) as revenue from lineitem where l_returnflag = 'R' and l_orderkey = :o_orderkey group by 1",
"Table": "lineitem",
"Values": [
":o_orderkey"
Expand Down Expand Up @@ -998,7 +998,7 @@ Gen4 error: unsupported: group by on: *planbuilder.joinGen4
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:4,R:0",
"JoinColumnIndexes": "L:4,R:1",
"JoinVars": {
"l_partkey": 0,
"l_quantity": 1,
Expand All @@ -1025,8 +1025,8 @@ Gen4 error: unsupported: group by on: *planbuilder.joinGen4
"Name": "main",
"Sharded": true
},
"FieldQuery": "select count(*) from part where 1 != 1",
"Query": "select count(*) from part where p_partkey = :l_partkey and p_brand = 'Brand#12' and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') and :l_quantity \u003e= 1 and :l_quantity \u003c= 1 + 10 and p_size between 1 and 5 and :l_shipmode in ('AIR', 'AIR REG') and :l_shipinstruct = 'DELIVER IN PERSON' or p_partkey = :l_partkey and p_brand = 'Brand#23' and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') and :l_quantity \u003e= 10 and :l_quantity \u003c= 10 + 10 and p_size between 1 and 10 and :l_shipmode in ('AIR', 'AIR REG') and :l_shipinstruct = 'DELIVER IN PERSON' or p_partkey = :l_partkey and p_brand = 'Brand#34' and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') and :l_quantity \u003e= 20 and :l_quantity \u003c= 20 + 10 and p_size between 1 and 15 and :l_shipmode in ('AIR', 'AIR REG') and :l_shipinstruct = 'DELIVER IN PERSON'",
"FieldQuery": "select 1, count(*) from part where 1 != 1 group by 1",
"Query": "select 1, count(*) from part where p_partkey = :l_partkey and p_brand = 'Brand#12' and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') and :l_quantity \u003e= 1 and :l_quantity \u003c= 1 + 10 and p_size between 1 and 5 and :l_shipmode in ('AIR', 'AIR REG') and :l_shipinstruct = 'DELIVER IN PERSON' or p_partkey = :l_partkey and p_brand = 'Brand#23' and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') and :l_quantity \u003e= 10 and :l_quantity \u003c= 10 + 10 and p_size between 1 and 10 and :l_shipmode in ('AIR', 'AIR REG') and :l_shipinstruct = 'DELIVER IN PERSON' or p_partkey = :l_partkey and p_brand = 'Brand#34' and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') and :l_quantity \u003e= 20 and :l_quantity \u003c= 20 + 10 and p_size between 1 and 15 and :l_shipmode in ('AIR', 'AIR REG') and :l_shipinstruct = 'DELIVER IN PERSON' group by 1",
"Table": "part"
}
]
Expand Down Expand Up @@ -1089,7 +1089,7 @@ Gen4 error: unsupported: cross-shard correlated subquery
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:1,L:1,L:4,L:2,R:0",
"JoinColumnIndexes": "L:1,L:1,L:4,L:2,R:1",
"JoinVars": {
"l1_l_orderkey": 0
},
Expand All @@ -1113,8 +1113,8 @@ Gen4 error: unsupported: cross-shard correlated subquery
"Name": "main",
"Sharded": true
},
"FieldQuery": "select count(*) as numwait from orders where 1 != 1",
"Query": "select count(*) as numwait from orders where o_orderstatus = 'F' and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate \u003e l3.l_commitdate limit 1) and o_orderkey = :l1_l_orderkey",
"FieldQuery": "select 1, count(*) as numwait from orders where 1 != 1 group by 1",
"Query": "select 1, count(*) as numwait from orders where o_orderstatus = 'F' and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate \u003e l3.l_commitdate limit 1) and o_orderkey = :l1_l_orderkey group by 1",
"Table": "orders",
"Values": [
":l1_l_orderkey"
Expand All @@ -1126,7 +1126,7 @@ Gen4 error: unsupported: cross-shard correlated subquery
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:3,L:4,L:1,R:0",
"JoinColumnIndexes": "L:3,L:4,L:1,R:1",
"JoinVars": {
"s_nationkey": 0
},
Expand Down Expand Up @@ -1154,8 +1154,8 @@ Gen4 error: unsupported: cross-shard correlated subquery
"Name": "main",
"Sharded": true
},
"FieldQuery": "select count(*) as numwait from nation where 1 != 1",
"Query": "select count(*) as numwait from nation where n_name = 'SAUDI ARABIA' and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate \u003e l3.l_commitdate limit 1) and n_nationkey = :s_nationkey",
"FieldQuery": "select 1, count(*) as numwait from nation where 1 != 1 group by 1",
"Query": "select 1, count(*) as numwait from nation where n_name = 'SAUDI ARABIA' and exists (select 1 from lineitem as l2 where l2.l_orderkey = l1.l_orderkey and l2.l_suppkey != l1.l_suppkey limit 1) and not exists (select 1 from lineitem as l3 where l3.l_orderkey = l1.l_orderkey and l3.l_suppkey != l1.l_suppkey and l3.l_receiptdate \u003e l3.l_commitdate limit 1) and n_nationkey = :s_nationkey group by 1",
"Table": "nation",
"Values": [
":s_nationkey"
Expand Down

0 comments on commit 47eec73

Please sign in to comment.