From 53c99f818b5597c6525575c46835bc9b059f66aa Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 25 Jul 2022 13:03:48 +0200 Subject: [PATCH] Mark aggregate functions callable These should also implement the iCallable interface as they are also callable as functions. Signed-off-by: Dirkjan Bussink --- go/vt/sqlparser/ast.go | 7 ++++++ go/vt/sqlparser/ast_clone.go | 12 ++++++++++ go/vt/sqlparser/ast_equals.go | 36 ++++++++++++++++++++++++++++++ go/vt/sqlparser/ast_rewrite.go | 12 ++++++++++ go/vt/sqlparser/ast_visit.go | 12 ++++++++++ go/vt/vtgate/engine/cached_size.go | 9 +++++++- 6 files changed, 87 insertions(+), 1 deletion(-) diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 76d991b2544..9baa989d967 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -3059,6 +3059,13 @@ func (*UpdateXMLExpr) iCallable() {} func (*PerformanceSchemaFuncExpr) iCallable() {} func (*GTIDFuncExpr) iCallable() {} +func (*Sum) iCallable() {} +func (*Min) iCallable() {} +func (*Max) iCallable() {} +func (*Avg) iCallable() {} +func (*CountStar) iCallable() {} +func (*Count) iCallable() {} + func (sum *Sum) GetArg() Expr { return sum.Arg } func (min *Min) GetArg() Expr { return min.Arg } func (max *Max) GetArg() Expr { return max.Arg } diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index b8294478cf4..c0b38e7d6e1 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -3215,12 +3215,18 @@ func CloneCallable(in Callable) Callable { switch in := in.(type) { case *ArgumentLessWindowExpr: return CloneRefOfArgumentLessWindowExpr(in) + case *Avg: + return CloneRefOfAvg(in) case *CharExpr: return CloneRefOfCharExpr(in) case *ConvertExpr: return CloneRefOfConvertExpr(in) case *ConvertUsingExpr: return CloneRefOfConvertUsingExpr(in) + case *Count: + return CloneRefOfCount(in) + case *CountStar: + return CloneRefOfCountStar(in) case *CurTimeFuncExpr: return CloneRefOfCurTimeFuncExpr(in) case *ExtractFuncExpr: @@ -3285,8 +3291,12 @@ func CloneCallable(in Callable) Callable { return CloneRefOfLocateExpr(in) case *MatchExpr: return CloneRefOfMatchExpr(in) + case *Max: + return CloneRefOfMax(in) case *MemberOfExpr: return CloneRefOfMemberOfExpr(in) + case *Min: + return CloneRefOfMin(in) case *NTHValueExpr: return CloneRefOfNTHValueExpr(in) case *NamedWindow: @@ -3305,6 +3315,8 @@ func CloneCallable(in Callable) Callable { return CloneRefOfRegexpSubstrExpr(in) case *SubstrExpr: return CloneRefOfSubstrExpr(in) + case *Sum: + return CloneRefOfSum(in) case *TimestampFuncExpr: return CloneRefOfTimestampFuncExpr(in) case *TrimFuncExpr: diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 549f1a5f7b7..7a5ef2b1007 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -4771,6 +4771,12 @@ func EqualsCallable(inA, inB Callable) bool { return false } return EqualsRefOfArgumentLessWindowExpr(a, b) + case *Avg: + b, ok := inB.(*Avg) + if !ok { + return false + } + return EqualsRefOfAvg(a, b) case *CharExpr: b, ok := inB.(*CharExpr) if !ok { @@ -4789,6 +4795,18 @@ func EqualsCallable(inA, inB Callable) bool { return false } return EqualsRefOfConvertUsingExpr(a, b) + case *Count: + b, ok := inB.(*Count) + if !ok { + return false + } + return EqualsRefOfCount(a, b) + case *CountStar: + b, ok := inB.(*CountStar) + if !ok { + return false + } + return EqualsRefOfCountStar(a, b) case *CurTimeFuncExpr: b, ok := inB.(*CurTimeFuncExpr) if !ok { @@ -4981,12 +4999,24 @@ func EqualsCallable(inA, inB Callable) bool { return false } return EqualsRefOfMatchExpr(a, b) + case *Max: + b, ok := inB.(*Max) + if !ok { + return false + } + return EqualsRefOfMax(a, b) case *MemberOfExpr: b, ok := inB.(*MemberOfExpr) if !ok { return false } return EqualsRefOfMemberOfExpr(a, b) + case *Min: + b, ok := inB.(*Min) + if !ok { + return false + } + return EqualsRefOfMin(a, b) case *NTHValueExpr: b, ok := inB.(*NTHValueExpr) if !ok { @@ -5041,6 +5071,12 @@ func EqualsCallable(inA, inB Callable) bool { return false } return EqualsRefOfSubstrExpr(a, b) + case *Sum: + b, ok := inB.(*Sum) + if !ok { + return false + } + return EqualsRefOfSum(a, b) case *TimestampFuncExpr: b, ok := inB.(*TimestampFuncExpr) if !ok { diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 831c4f412a7..1e0ac2bdba6 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -8302,12 +8302,18 @@ func (a *application) rewriteCallable(parent SQLNode, node Callable, replacer re switch node := node.(type) { case *ArgumentLessWindowExpr: return a.rewriteRefOfArgumentLessWindowExpr(parent, node, replacer) + case *Avg: + return a.rewriteRefOfAvg(parent, node, replacer) case *CharExpr: return a.rewriteRefOfCharExpr(parent, node, replacer) case *ConvertExpr: return a.rewriteRefOfConvertExpr(parent, node, replacer) case *ConvertUsingExpr: return a.rewriteRefOfConvertUsingExpr(parent, node, replacer) + case *Count: + return a.rewriteRefOfCount(parent, node, replacer) + case *CountStar: + return a.rewriteRefOfCountStar(parent, node, replacer) case *CurTimeFuncExpr: return a.rewriteRefOfCurTimeFuncExpr(parent, node, replacer) case *ExtractFuncExpr: @@ -8372,8 +8378,12 @@ func (a *application) rewriteCallable(parent SQLNode, node Callable, replacer re return a.rewriteRefOfLocateExpr(parent, node, replacer) case *MatchExpr: return a.rewriteRefOfMatchExpr(parent, node, replacer) + case *Max: + return a.rewriteRefOfMax(parent, node, replacer) case *MemberOfExpr: return a.rewriteRefOfMemberOfExpr(parent, node, replacer) + case *Min: + return a.rewriteRefOfMin(parent, node, replacer) case *NTHValueExpr: return a.rewriteRefOfNTHValueExpr(parent, node, replacer) case *NamedWindow: @@ -8392,6 +8402,8 @@ func (a *application) rewriteCallable(parent SQLNode, node Callable, replacer re return a.rewriteRefOfRegexpSubstrExpr(parent, node, replacer) case *SubstrExpr: return a.rewriteRefOfSubstrExpr(parent, node, replacer) + case *Sum: + return a.rewriteRefOfSum(parent, node, replacer) case *TimestampFuncExpr: return a.rewriteRefOfTimestampFuncExpr(parent, node, replacer) case *TrimFuncExpr: diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index 68aa5e936f3..b3224d1dfca 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -4102,12 +4102,18 @@ func VisitCallable(in Callable, f Visit) error { switch in := in.(type) { case *ArgumentLessWindowExpr: return VisitRefOfArgumentLessWindowExpr(in, f) + case *Avg: + return VisitRefOfAvg(in, f) case *CharExpr: return VisitRefOfCharExpr(in, f) case *ConvertExpr: return VisitRefOfConvertExpr(in, f) case *ConvertUsingExpr: return VisitRefOfConvertUsingExpr(in, f) + case *Count: + return VisitRefOfCount(in, f) + case *CountStar: + return VisitRefOfCountStar(in, f) case *CurTimeFuncExpr: return VisitRefOfCurTimeFuncExpr(in, f) case *ExtractFuncExpr: @@ -4172,8 +4178,12 @@ func VisitCallable(in Callable, f Visit) error { return VisitRefOfLocateExpr(in, f) case *MatchExpr: return VisitRefOfMatchExpr(in, f) + case *Max: + return VisitRefOfMax(in, f) case *MemberOfExpr: return VisitRefOfMemberOfExpr(in, f) + case *Min: + return VisitRefOfMin(in, f) case *NTHValueExpr: return VisitRefOfNTHValueExpr(in, f) case *NamedWindow: @@ -4192,6 +4202,8 @@ func VisitCallable(in Callable, f Visit) error { return VisitRefOfRegexpSubstrExpr(in, f) case *SubstrExpr: return VisitRefOfSubstrExpr(in, f) + case *Sum: + return VisitRefOfSum(in, f) case *TimestampFuncExpr: return VisitRefOfTimestampFuncExpr(in, f) case *TrimFuncExpr: diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 97eec881c16..4d6193deb6b 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -614,7 +614,7 @@ func (cached *Plan) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(128) + size += int64(144) } // field Original string size += hack.RuntimeAllocSize(int64(len(cached.Original))) @@ -631,6 +631,13 @@ func (cached *Plan) CachedSize(alloc bool) int64 { size += elem.CachedSize(true) } } + // field TablesUsed []string + { + size += hack.RuntimeAllocSize(int64(cap(cached.TablesUsed)) * int64(16)) + for _, elem := range cached.TablesUsed { + size += hack.RuntimeAllocSize(int64(len(elem))) + } + } return size } func (cached *Projection) CachedSize(alloc bool) int64 {