Skip to content

Commit

Permalink
planner, expression: fix error when using IN combined with subquery (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
time-and-fate authored Mar 3, 2021
1 parent 4cf3284 commit 5049364
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 102 deletions.
123 changes: 67 additions & 56 deletions expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ func (c *inFunctionClass) getFunction(ctx sessionctx.Context, args []Expression)

type baseInSig struct {
baseBuiltinFunc
nonConstArgs []Expression
hasNull bool
// nonConstArgsIdx stores the indices of non-constant args in the baseBuiltinFunc.args (the first arg is not included).
// It works with builtinInXXXSig.hashset to accelerate 'eval'.
nonConstArgsIdx []int
hasNull bool
}

// builtinInIntSig see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_in
Expand All @@ -167,7 +169,7 @@ type builtinInIntSig struct {
}

func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error {
b.nonConstArgs = []Expression{b.args[0]}
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[int64]bool, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) {
Expand All @@ -181,7 +183,7 @@ func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error
}
b.hashSet[val] = mysql.HasUnsignedFlag(b.args[i].GetType().Flag)
} else {
b.nonConstArgs = append(b.nonConstArgs, b.args[i])
b.nonConstArgsIdx = append(b.nonConstArgsIdx, i)
}
}
return nil
Expand All @@ -190,10 +192,8 @@ func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error
func (b *builtinInIntSig) Clone() builtinFunc {
newSig := &builtinInIntSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs))
for _, arg := range b.nonConstArgs {
newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone())
}
newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx))
copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx)
newSig.hashSet = b.hashSet
newSig.hasNull = b.hasNull
return newSig
Expand All @@ -206,9 +206,8 @@ func (b *builtinInIntSig) evalInt(row chunk.Row) (int64, bool, error) {
}
isUnsigned0 := mysql.HasUnsignedFlag(b.args[0].GetType().Flag)

args := b.args
args := b.args[1:]
if len(b.hashSet) != 0 {
args = b.nonConstArgs
if isUnsigned, ok := b.hashSet[arg0]; ok {
if (isUnsigned0 && isUnsigned) || (!isUnsigned0 && !isUnsigned) {
return 1, false, nil
Expand All @@ -217,10 +216,14 @@ func (b *builtinInIntSig) evalInt(row chunk.Row) (int64, bool, error) {
return 1, false, nil
}
}
args = args[:0]
for _, i := range b.nonConstArgsIdx {
args = append(args, b.args[i])
}
}

hasNull := b.hasNull
for _, arg := range args[1:] {
for _, arg := range args {
evaledArg, isNull, err := arg.EvalInt(b.ctx, row)
if err != nil {
return 0, true, err
Expand Down Expand Up @@ -258,7 +261,7 @@ type builtinInStringSig struct {
}

func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) error {
b.nonConstArgs = []Expression{b.args[0]}
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = set.NewStringSet()
collator := collate.GetCollator(b.collation)
for i := 1; i < len(b.args); i++ {
Expand All @@ -273,7 +276,7 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er
}
b.hashSet.Insert(string(collator.Key(val))) // should do memory copy here
} else {
b.nonConstArgs = append(b.nonConstArgs, b.args[i])
b.nonConstArgsIdx = append(b.nonConstArgsIdx, i)
}
}

Expand All @@ -283,10 +286,8 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er
func (b *builtinInStringSig) Clone() builtinFunc {
newSig := &builtinInStringSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs))
for _, arg := range b.nonConstArgs {
newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone())
}
newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx))
copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx)
newSig.hashSet = b.hashSet
newSig.hasNull = b.hasNull
return newSig
Expand All @@ -298,17 +299,20 @@ func (b *builtinInStringSig) evalInt(row chunk.Row) (int64, bool, error) {
return 0, isNull0, err
}

args := b.args
args := b.args[1:]
collator := collate.GetCollator(b.collation)
if len(b.hashSet) != 0 {
args = b.nonConstArgs
if b.hashSet.Exist(string(collator.Key(arg0))) {
return 1, false, nil
}
args = args[:0]
for _, i := range b.nonConstArgsIdx {
args = append(args, b.args[i])
}
}

hasNull := b.hasNull
for _, arg := range args[1:] {
for _, arg := range args {
evaledArg, isNull, err := arg.EvalString(b.ctx, row)
if err != nil {
return 0, true, err
Expand All @@ -331,7 +335,7 @@ type builtinInRealSig struct {
}

func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) error {
b.nonConstArgs = []Expression{b.args[0]}
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = set.NewFloat64Set()
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) {
Expand All @@ -345,7 +349,7 @@ func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro
}
b.hashSet.Insert(val)
} else {
b.nonConstArgs = append(b.nonConstArgs, b.args[i])
b.nonConstArgsIdx = append(b.nonConstArgsIdx, i)
}
}

Expand All @@ -355,10 +359,8 @@ func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro
func (b *builtinInRealSig) Clone() builtinFunc {
newSig := &builtinInRealSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs))
for _, arg := range b.nonConstArgs {
newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone())
}
newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx))
copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx)
newSig.hashSet = b.hashSet
newSig.hasNull = b.hasNull
return newSig
Expand All @@ -369,15 +371,19 @@ func (b *builtinInRealSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull0 || err != nil {
return 0, isNull0, err
}
args := b.args
args := b.args[1:]
if len(b.hashSet) != 0 {
args = b.nonConstArgs
if b.hashSet.Exist(arg0) {
return 1, false, nil
}
args = args[:0]
for _, i := range b.nonConstArgsIdx {
args = append(args, b.args[i])
}
}

hasNull := b.hasNull
for _, arg := range args[1:] {
for _, arg := range args {
evaledArg, isNull, err := arg.EvalReal(b.ctx, row)
if err != nil {
return 0, true, err
Expand All @@ -400,7 +406,7 @@ type builtinInDecimalSig struct {
}

func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) error {
b.nonConstArgs = []Expression{b.args[0]}
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = set.NewStringSet()
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) {
Expand All @@ -418,7 +424,7 @@ func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) e
}
b.hashSet.Insert(string(key))
} else {
b.nonConstArgs = append(b.nonConstArgs, b.args[i])
b.nonConstArgsIdx = append(b.nonConstArgsIdx, i)
}
}

Expand All @@ -428,10 +434,8 @@ func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) e
func (b *builtinInDecimalSig) Clone() builtinFunc {
newSig := &builtinInDecimalSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs))
for _, arg := range b.nonConstArgs {
newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone())
}
newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx))
copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx)
newSig.hashSet = b.hashSet
newSig.hasNull = b.hasNull
return newSig
Expand All @@ -443,20 +447,23 @@ func (b *builtinInDecimalSig) evalInt(row chunk.Row) (int64, bool, error) {
return 0, isNull0, err
}

args := b.args
args := b.args[1:]
key, err := arg0.ToHashKey()
if err != nil {
return 0, true, err
}
if len(b.hashSet) != 0 {
args = b.nonConstArgs
if b.hashSet.Exist(string(key)) {
return 1, false, nil
}
args = args[:0]
for _, i := range b.nonConstArgsIdx {
args = append(args, b.args[i])
}
}

hasNull := b.hasNull
for _, arg := range args[1:] {
for _, arg := range args {
evaledArg, isNull, err := arg.EvalDecimal(b.ctx, row)
if err != nil {
return 0, true, err
Expand All @@ -479,7 +486,7 @@ type builtinInTimeSig struct {
}

func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) error {
b.nonConstArgs = []Expression{b.args[0]}
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[types.CoreTime]struct{}, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) {
Expand All @@ -493,7 +500,7 @@ func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro
}
b.hashSet[val.CoreTime()] = struct{}{}
} else {
b.nonConstArgs = append(b.nonConstArgs, b.args[i])
b.nonConstArgsIdx = append(b.nonConstArgsIdx, i)
}
}

Expand All @@ -503,10 +510,8 @@ func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro
func (b *builtinInTimeSig) Clone() builtinFunc {
newSig := &builtinInTimeSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs))
for _, arg := range b.nonConstArgs {
newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone())
}
newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx))
copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx)
newSig.hashSet = b.hashSet
newSig.hasNull = b.hasNull
return newSig
Expand All @@ -517,15 +522,19 @@ func (b *builtinInTimeSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull0 || err != nil {
return 0, isNull0, err
}
args := b.args
args := b.args[1:]
if len(b.hashSet) != 0 {
args = b.nonConstArgs
if _, ok := b.hashSet[arg0.CoreTime()]; ok {
return 1, false, nil
}
args = args[:0]
for _, i := range b.nonConstArgsIdx {
args = append(args, b.args[i])
}
}

hasNull := b.hasNull
for _, arg := range args[1:] {
for _, arg := range args {
evaledArg, isNull, err := arg.EvalTime(b.ctx, row)
if err != nil {
return 0, true, err
Expand All @@ -548,7 +557,7 @@ type builtinInDurationSig struct {
}

func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context) error {
b.nonConstArgs = []Expression{b.args[0]}
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[time.Duration]struct{}, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(b.ctx.GetSessionVars().StmtCtx) {
Expand All @@ -562,7 +571,7 @@ func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context)
}
b.hashSet[val.Duration] = struct{}{}
} else {
b.nonConstArgs = append(b.nonConstArgs, b.args[i])
b.nonConstArgsIdx = append(b.nonConstArgsIdx, i)
}
}

Expand All @@ -572,10 +581,8 @@ func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context)
func (b *builtinInDurationSig) Clone() builtinFunc {
newSig := &builtinInDurationSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.nonConstArgs = make([]Expression, 0, len(b.nonConstArgs))
for _, arg := range b.nonConstArgs {
newSig.nonConstArgs = append(newSig.nonConstArgs, arg.Clone())
}
newSig.nonConstArgsIdx = make([]int, len(b.nonConstArgsIdx))
copy(newSig.nonConstArgsIdx, b.nonConstArgsIdx)
newSig.hashSet = b.hashSet
newSig.hasNull = b.hasNull
return newSig
Expand All @@ -586,15 +593,19 @@ func (b *builtinInDurationSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull0 || err != nil {
return 0, isNull0, err
}
args := b.args
args := b.args[1:]
if len(b.hashSet) != 0 {
args = b.nonConstArgs
if _, ok := b.hashSet[arg0.Duration]; ok {
return 1, false, nil
}
args = args[:0]
for _, i := range b.nonConstArgsIdx {
args = append(args, b.args[i])
}
}

hasNull := b.hasNull
for _, arg := range args[1:] {
for _, arg := range args {
evaledArg, isNull, err := arg.EvalDuration(b.ctx, row)
if err != nil {
return 0, true, err
Expand Down
Loading

0 comments on commit 5049364

Please sign in to comment.