Skip to content

Commit

Permalink
Merge pull request #321 from pingcap/siddontang/dev-union
Browse files Browse the repository at this point in the history
dev union
  • Loading branch information
siddontang committed Oct 8, 2015
2 parents 7de38a6 + d40ee39 commit 2e6fd07
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 25 deletions.
69 changes: 58 additions & 11 deletions parser/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ import (
TableRefs "table references"
TruncateTableStmt "TRANSACTION TABLE statement"
UnionOpt "Union Option(empty/ALL/DISTINCT)"
UnionSelect "Union select/(select)"
UnionStmt "Union statement"
UpdateStmt "UPDATE statement"
Username "Username"
Expand Down Expand Up @@ -1680,6 +1681,10 @@ InsertRest:
{
$$ = &stmts.InsertIntoStmt{ColNames: $2.([]string), Sel: $4.(*stmts.SelectStmt)}
}
| '(' ColumnNameListOpt ')' UnionStmt
{
$$ = &stmts.InsertIntoStmt{ColNames: $2.([]string), Sel: $4.(*stmts.UnionStmt)}
}
| ValueSym ExpressionListList %prec insertValues
{
$$ = &stmts.InsertIntoStmt{Lists: $2.([][]expression.Expression)}
Expand All @@ -1688,6 +1693,10 @@ InsertRest:
{
$$ = &stmts.InsertIntoStmt{Sel: $1.(*stmts.SelectStmt)}
}
| UnionStmt
{
$$ = &stmts.InsertIntoStmt{Sel: $1.(*stmts.UnionStmt)}
}
| "SET" ColumnSetValueList
{
$$ = &stmts.InsertIntoStmt{Setlist: $2.([]*expression.Assignment)}
Expand Down Expand Up @@ -2692,10 +2701,6 @@ QuickOptional:
$$ = true
}

semiOpt:
/* EMPTY */
| ';'


/***************************Prepared Statement Start******************************
* See: https://dev.mysql.com/doc/refman/5.7/en/prepare.html
Expand Down Expand Up @@ -2884,9 +2889,13 @@ TableFactor:
{
$$ = &rsets.TableSource{Source: $1, Name: $2.(string)}
}
| '(' SelectStmt semiOpt ')' AsOpt
| '(' SelectStmt ')' AsOpt
{
$$ = &rsets.TableSource{Source: $2, Name: $5.(string)}
$$ = &rsets.TableSource{Source: $2, Name: $4.(string)}
}
| '(' UnionStmt ')' AsOpt
{
$$ = &rsets.TableSource{Source: $2, Name: $4.(string)}
}
| '(' TableRefs ')'
{
Expand Down Expand Up @@ -3040,6 +3049,14 @@ SubSelect:
s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1])
$$ = &subquery.SubQuery{Stmt: s}
}
| '(' UnionStmt ')'
{
s := $2.(*stmts.UnionStmt)
src := yylex.(*lexer).src
// See the implemention of yyParse function
s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1])
$$ = &subquery.SubQuery{Stmt: s}
}

// See: https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html
SelectLockOpt:
Expand Down Expand Up @@ -3948,7 +3965,7 @@ StringList:
* See: https://dev.mysql.com/doc/refman/5.7/en/union.html
***********************************************************************************/
UnionStmt:
SelectStmt "UNION" UnionOpt SelectStmt
UnionSelect "UNION" UnionOpt SelectStmt
{
ds := []bool {$3.(bool)}
ss := []*stmts.SelectStmt{$1.(*stmts.SelectStmt), $4.(*stmts.SelectStmt)}
Expand All @@ -3957,14 +3974,44 @@ UnionStmt:
Selects: ss,
}
}
| UnionStmt "UNION" UnionOpt SelectStmt
| UnionSelect "UNION" UnionOpt '(' SelectStmt ')' SelectStmtOrder SelectStmtLimit
{
ds := []bool {$3.(bool)}
ss := []*stmts.SelectStmt{$1.(*stmts.SelectStmt), $5.(*stmts.SelectStmt)}
st := &stmts.UnionStmt{
Distincts: ds,
Selects: ss,
}
if $7 != nil {
st.OrderBy = $7.(*rsets.OrderByRset)
}

if $8 != nil {
ay := $8.([]interface{})
st.Limit = ay[0].(*rsets.LimitRset)
st.Offset = ay[1].(*rsets.OffsetRset)
}
$$ = st
}
| UnionSelect "UNION" UnionOpt UnionStmt
{
s := $1.(*stmts.UnionStmt)
s.Distincts = append(s.Distincts, $3.(bool))
s.Selects = append(s.Selects, $4.(*stmts.SelectStmt))
s := $4.(*stmts.UnionStmt)
s.Distincts = append([]bool {$3.(bool)}, s.Distincts...)
s.Selects = append([]*stmts.SelectStmt{$1.(*stmts.SelectStmt)}, s.Selects...)
$$ = s
}

UnionSelect:
SelectStmt
{
$$ = $1
}
| '(' SelectStmt ')'
{
$$ = $2
}


UnionOpt:
{
$$ = true
Expand Down
18 changes: 18 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,24 @@ func (s *testParserSuite) TestParser0(c *C) {
// For quote identifier
{"select `a`, `a.b`, `a b` from t", true},

// For union
{"select c1 from t1 union select c2 from t2", true},
{"select c1 from t1 union (select c2 from t2)", true},
{"select c1 from t1 union (select c2 from t2) order by c1", true},
{"select c1 from t1 union select c2 from t2 order by c2", true},
{"select c1 from t1 union (select c2 from t2) limit 1", true},
{"select c1 from t1 union (select c2 from t2) limit 1, 1", true},
{"select c1 from t1 union (select c2 from t2) order by c1 limit 1", true},
{"(select c1 from t1) union distinct select c2 from t2", true},
{"(select c1 from t1) union all select c2 from t2", true},
{"(select c1 from t1) union (select c2 from t2) order by c1 union select c3 from t3", false},
{"(select c1 from t1) union (select c2 from t2) limit 1 union select c3 from t3", false},
{"(select c1 from t1) union select c2 from t2 union (select c3 from t3) order by c1 limit 1", true},
{"select (select 1 union select 1) as a", true},
{"select * from (select 1 union select 2) as a", true},
{"insert into t select c1 from t1 union select c2 from t2", true},
{"insert into t (c) select c1 from t1 union select c2 from t2", true},

// For unquoted identifier
{"create table MergeContextTest$Simple (value integer not null, primary key (value))", true},

Expand Down
29 changes: 25 additions & 4 deletions plan/plans/union.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (p *UnionPlan) Explain(w format.Formatter) {

// GetFields implements plan.Plan GetFields interface.
func (p *UnionPlan) GetFields() []*field.ResultField {
return p.Srcs[0].GetFields()
return p.RFields
}

// Filter implements plan.Plan Filter interface.
Expand Down Expand Up @@ -101,6 +101,15 @@ func (p *UnionPlan) fetchAll(ctx context.Context) error {
}
}

// Update return result types
// e,g, for select 'abc' union select 'a', we can only get result type after
// executing first select.
for i, f := range p.RFields {
if f.Col.FieldType.Tp == 0 {
f.Col.FieldType = src.GetFields()[i].Col.FieldType
}
}

// Fetch results of the following select statements.
for i := range p.Distincts {
err = p.fetchSrc(ctx, i, t)
Expand All @@ -115,8 +124,7 @@ func (p *UnionPlan) fetchSrc(ctx context.Context, i int, t memkv.Temp) error {
src := p.Srcs[i+1]
distinct := p.Distincts[i]

// Use the ResultFields of the first select statement as the final ResultFields
rfs := p.Srcs[0].GetFields()
rfs := p.GetFields()
if len(src.GetFields()) != len(rfs) {
return errors.New("The used SELECT statements have a different number of columns")
}
Expand All @@ -125,6 +133,7 @@ func (p *UnionPlan) fetchSrc(ctx context.Context, i int, t memkv.Temp) error {
if row == nil || err != nil {
return errors.Trace(err)
}

srcRfs := src.GetFields()
for i := range row.Data {
// The column value should be casted as the same type of the first select statement in corresponding position
Expand All @@ -143,11 +152,23 @@ func (p *UnionPlan) fetchSrc(ctx context.Context, i int, t memkv.Temp) error {
if srcRf.Flen > rf.Col.Flen {
rf.Col.Flen = srcRf.Col.Flen
}
row.Data[i], err = types.Convert(row.Data[i], &rf.Col.FieldType)
if rf.Col.FieldType.Tp > 0 {
row.Data[i], err = types.Convert(row.Data[i], &rf.Col.FieldType)
} else {
// First select result doesn't contain enough type information, e,g, select null union select 1.
// We cannot get the proper data type for select null.
// Now we just use the first correct return data types with following select.
// TODO: Try to merge all data types for all select like select null union select 1 union select "abc"
if tp := srcRf.Col.FieldType.Tp; tp > 0 {
rf.Col.FieldType.Tp = tp
}
}

if err != nil {
return errors.Trace(err)
}
}

if distinct {
// distinct union, check duplicate
v, getErr := t.Get(row.Data)
Expand Down
4 changes: 4 additions & 0 deletions plan/plans/union_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ func (t *testUnionSuit) TestUnion(c *C) {
},
}

// Set Flen explicitly here, because following ColToResultField will change Flen to zero.
// TODO: remove this if ColToResultField update.
cols[1].FieldType.Flen = 100

pln := &plans.UnionPlan{
Srcs: []plan.Plan{
tblPlan,
Expand Down
2 changes: 1 addition & 1 deletion stmt/stmts/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ const (
type InsertIntoStmt struct {
ColNames []string
Lists [][]expression.Expression
Sel *SelectStmt
Sel plan.Planner
TableIdent table.Ident
Setlist []*expression.Assignment
Priority int
Expand Down
4 changes: 4 additions & 0 deletions stmt/stmts/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func (s *testStmtSuite) TestInsert(c *C) {
insertSelectSQL := `create table insert_test_1 (id int, c1 int); insert insert_test_1 select id, c1 from insert_test;`
mustExec(c, s.testDB, insertSelectSQL)

insertSelectSQL = `create table insert_test_2 (id int, c1 int);
insert insert_test_1 select id, c1 from insert_test union select id * 10, c1 * 10 from insert_test;`
mustExec(c, s.testDB, insertSelectSQL)

errInsertSelectSQL = `insert insert_test_1 select c1 from insert_test;`
tx = mustBegin(c, s.testDB)
_, err = tx.Exec(errInsertSelectSQL)
Expand Down
67 changes: 64 additions & 3 deletions stmt/stmts/union.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
package stmts

import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/field"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/plan/plans"
"github.com/pingcap/tidb/rset"
Expand All @@ -30,8 +32,10 @@ var _ stmt.Statement = (*UnionStmt)(nil)
type UnionStmt struct {
Distincts []bool
Selects []*SelectStmt

Text string
Limit *rsets.LimitRset
Offset *rsets.OffsetRset
OrderBy *rsets.OrderByRset
Text string
}

// Explain implements the stmt.Statement Explain interface.
Expand All @@ -57,14 +61,71 @@ func (s *UnionStmt) SetText(text string) {
// Plan implements the plan.Planner interface.
func (s *UnionStmt) Plan(ctx context.Context) (plan.Plan, error) {
srcs := make([]plan.Plan, 0, len(s.Selects))
columnCount := 0
for _, s := range s.Selects {
p, err := s.Plan(ctx)
if err != nil {
return nil, err
}
if columnCount > 0 && columnCount != len(p.GetFields()) {
return nil, errors.New("The used SELECT statements have a different number of columns")
}
columnCount = len(p.GetFields())

srcs = append(srcs, p)
}
return &plans.UnionPlan{Srcs: srcs, Distincts: s.Distincts}, nil

for i := len(s.Distincts) - 1; i >= 0; i-- {
if s.Distincts[i] {
// distinct overwrites all previous all
// e.g, select * from t1 union all select * from t2 union distinct select * from t3.
// The distinct will overwrite all for t1 and t2.
i--
for ; i >= 0; i-- {
s.Distincts[i] = true
}
break
}
}

fields := srcs[0].GetFields()
selectList := &plans.SelectList{}
selectList.ResultFields = make([]*field.ResultField, len(fields))
selectList.HiddenFieldOffset = len(fields)

// Union uses first select return column names and ignores table name.
// We only care result name and type here.
for i, f := range fields {
nf := &field.ResultField{}
nf.Name = f.Name
nf.FieldType = f.FieldType
selectList.ResultFields[i] = nf
}

var (
r plan.Plan
err error
)

r = &plans.UnionPlan{Srcs: srcs, Distincts: s.Distincts, RFields: selectList.ResultFields}

if s := s.OrderBy; s != nil {
if r, err = (&rsets.OrderByRset{By: s.By,
Src: r,
SelectList: selectList,
}).Plan(ctx); err != nil {
return nil, err
}
}

if s := s.Offset; s != nil {
r = &plans.OffsetDefaultPlan{Count: s.Count, Src: r, Fields: r.GetFields()}
}
if s := s.Limit; s != nil {
r = &plans.LimitDefaultPlan{Count: s.Count, Src: r, Fields: r.GetFields()}
}

return r, nil
}

// Exec implements the stmt.Statement Exec interface.
Expand Down
Loading

0 comments on commit 2e6fd07

Please sign in to comment.