Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ddl: check expr restriction for hash partitioned table #10273

Merged
merged 9 commits into from
May 8, 2019
23 changes: 13 additions & 10 deletions ddl/db_partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) {

sql4 := `create table t4 (
a int not null,
b int not null
b int not null
)
partition by range( id ) (
partition p1 values less than maxvalue,
partition p2 values less than (1991),
partition p3 values less than (1995)
partition p2 values less than (1991),
partition p3 values less than (1995)
);`
assertErrorCode(c, tk, sql4, tmysql.ErrPartitionMaxvalue)

Expand All @@ -121,10 +121,10 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) {
c INT NOT NULL
)
partition by range columns(a,b,c) (
partition p0 values less than (10,5,1),
partition p2 values less than (50,maxvalue,10),
partition p3 values less than (65,30,13),
partition p4 values less than (maxvalue,30,40)
partition p0 values less than (10,5,1),
partition p2 values less than (50,maxvalue,10),
partition p3 values less than (65,30,13),
partition p4 values less than (maxvalue,30,40)
);`)
c.Assert(err, IsNil)

Expand All @@ -139,13 +139,13 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) {

sql7 := `create table t7 (
a int not null,
b int not null
b int not null
)
partition by range( id ) (
partition p1 values less than (1991),
partition p2 values less than maxvalue,
partition p3 values less than maxvalue,
partition p4 values less than (1995),
partition p3 values less than maxvalue,
partition p4 values less than (1995),
partition p5 values less than maxvalue
);`
assertErrorCode(c, tk, sql7, tmysql.ErrPartitionMaxvalue)
Expand Down Expand Up @@ -230,6 +230,9 @@ func (s *testIntegrationSuite9) TestCreateTableWithPartition(c *C) {
assertErrorCode(c, tk, `create table t31 (a int not null) partition by range( a );`, tmysql.ErrPartitionsMustBeDefined)
assertErrorCode(c, tk, `create table t32 (a int not null) partition by range columns( a );`, tmysql.ErrPartitionsMustBeDefined)
assertErrorCode(c, tk, `create table t33 (a int, b int) partition by hash(a) partitions 0;`, tmysql.ErrNoParts)
assertErrorCode(c, tk, `create table t33 (a timestamp, b int) partition by hash(a) partitions 30;`, tmysql.ErrFieldTypeNotAllowedAsPartitionField)
// TODO: fix this one
// assertErrorCode(c, tk, `create table t33 (a timestamp, b int) partition by hash(unix_timestamp(a)) partitions 30;`, tmysql.ErrPartitionFuncNotAllowed)
}

func (s *testIntegrationSuite7) TestCreateTableWithHashPartition(c *C) {
Expand Down
30 changes: 15 additions & 15 deletions ddl/ddl_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ func buildTableInfoWithCheck(ctx sessionctx.Context, d *ddl, s *ast.CreateTableS
err = checkPartitionByRangeColumn(ctx, tbInfo, pi, s)
}
case model.PartitionTypeHash:
err = checkPartitionByHash(pi)
err = checkPartitionByHash(ctx, pi, s, cols, tbInfo)
}
if err != nil {
return nil, errors.Trace(err)
Expand Down Expand Up @@ -1365,41 +1365,41 @@ func buildViewInfoWithTableColumns(ctx sessionctx.Context, s *ast.CreateViewStmt
return viewInfo, tableColumns
}

func checkPartitionByHash(pi *model.PartitionInfo) error {
func checkPartitionByHash(ctx sessionctx.Context, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, tbInfo *model.TableInfo) error {
if err := checkAddPartitionTooManyPartitions(pi.Num); err != nil {
return errors.Trace(err)
return err
}
if err := checkNoHashPartitions(pi.Num); err != nil {
return errors.Trace(err)
if err := checkNoHashPartitions(ctx, pi.Num); err != nil {
return err
}
return nil
if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil {
return err
}
return checkPartitionFuncType(ctx, s, cols, tbInfo)
}

func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt, cols []*table.Column, newConstraints []*ast.Constraint) error {
if err := checkPartitionNameUnique(tbInfo, pi); err != nil {
return errors.Trace(err)
return err
}

if err := checkCreatePartitionValue(ctx, tbInfo, pi, cols); err != nil {
return errors.Trace(err)
return err
}

if err := checkAddPartitionTooManyPartitions(uint64(len(pi.Definitions))); err != nil {
return errors.Trace(err)
return err
}

if err := checkNoRangePartitions(len(pi.Definitions)); err != nil {
return errors.Trace(err)
return err
}

if err := checkPartitionFuncValid(ctx, tbInfo, s.Partition.Expr); err != nil {
return errors.Trace(err)
return err
}

if err := checkPartitionFuncType(ctx, s, cols, tbInfo); err != nil {
return errors.Trace(err)
}
return nil
return checkPartitionFuncType(ctx, s, cols, tbInfo)
}

func checkPartitionByRangeColumn(ctx sessionctx.Context, tbInfo *model.TableInfo, pi *model.PartitionInfo, s *ast.CreateTableStmt) error {
Expand Down
12 changes: 9 additions & 3 deletions ddl/partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func checkPartitionFuncType(ctx sessionctx.Context, s *ast.CreateTableStmt, cols
buf := new(bytes.Buffer)
s.Partition.Expr.Format(buf)
exprStr := buf.String()
if s.Partition.Tp == model.PartitionTypeRange {
if s.Partition.Tp == model.PartitionTypeRange || s.Partition.Tp == model.PartitionTypeHash {
// if partition by columnExpr, check the column type
if _, ok := s.Partition.Expr.(*ast.ColumnNameExpr); ok {
for _, col := range cols {
Expand All @@ -215,13 +215,19 @@ func checkPartitionFuncType(ctx sessionctx.Context, s *ast.CreateTableStmt, cols
}
}

e, err := expression.ParseSimpleExprWithTableInfo(ctx, buf.String(), tblInfo)
e, err := expression.ParseSimpleExprWithTableInfo(ctx, exprStr, tblInfo)
if err != nil {
return errors.Trace(err)
}
if e.GetType().EvalType() == types.ETInt {
return nil
}
if s.Partition.Tp == model.PartitionTypeHash {
if _, ok := s.Partition.Expr.(*ast.ColumnNameExpr); ok {
return ErrNotAllowedTypeInPartition.GenWithStackByArgs(exprStr)
}
}

return ErrPartitionFuncNotAllowed.GenWithStackByArgs("PARTITION")
}

Expand Down Expand Up @@ -428,7 +434,7 @@ func checkAddPartitionTooManyPartitions(piDefs uint64) error {
return nil
}

func checkNoHashPartitions(partitionNum uint64) error {
func checkNoHashPartitions(ctx sessionctx.Context, partitionNum uint64) error {
if partitionNum == 0 {
return ErrNoParts.GenWithStackByArgs("partitions")
}
Expand Down