Skip to content

Commit

Permalink
feat: change column type
Browse files Browse the repository at this point in the history
Bufixes and improvements:
- pgdialect.Inspector canonicalizes all default expressions (lowercase)
to make sure they are always comparable with the model definition.
- sqlschema.SchemaInspector canonicalizes all default expressions (lowercase)
- pgdialect and sqlschema now support type-equivalence, which prevents unnecessary
migrations like CHAR -> CHARACTER from being created.

Changing PRIMARY KEY and UNIQUE-ness are outside of this commit's scope, because
those constraints can span multiple columns.
  • Loading branch information
bevzzz committed Aug 11, 2024
1 parent ce75416 commit 4e80ee1
Show file tree
Hide file tree
Showing 11 changed files with 747 additions and 235 deletions.
60 changes: 59 additions & 1 deletion dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgdialect
import (
"context"
"fmt"
"log"

"github.com/uptrace/bun"
"github.com/uptrace/bun/internal"
Expand Down Expand Up @@ -61,6 +62,8 @@ func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) er
b, err = m.addForeignKey(fmter, b, change)
case *migrate.RenameConstraint:
b, err = m.renameConstraint(fmter, b, change)
case *migrate.ChangeColumnType:
b, err = m.changeColumnType(fmter, b, change)
default:
return fmt.Errorf("apply changes: unknown operation %T", change)
}
Expand All @@ -69,7 +72,7 @@ func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) er
}

query := internal.String(b)
// log.Println("exec query: " + query)
log.Println("exec query: " + query)
if _, err = conn.ExecContext(ctx, query); err != nil {
return fmt.Errorf("apply changes: %w", err)
}
Expand Down Expand Up @@ -174,3 +177,58 @@ func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate.

return b, nil
}

func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *migrate.ChangeColumnType) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := colDef.FQN()
if b, err = fqn.AppendQuery(fmter, b); err != nil {
return b, err
}

var i int
appendAlterColumn := func() {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, " ALTER COLUMN "...)
b, err = bun.Ident(colDef.Column).AppendQuery(fmter, b)
i++
}

got, want := colDef.From, colDef.To

if want.SQLType != got.SQLType {
if appendAlterColumn(); err != nil {
return b, err
}
b = append(b, " SET DATA TYPE "...)
if b, err = want.AppendQuery(fmter, b); err != nil {
return b, err
}
}

if want.IsNullable != got.IsNullable {
if appendAlterColumn(); err != nil {
return b, err
}
if !want.IsNullable {
b = append(b, " SET NOT NULL"...)
} else {
b = append(b, " DROP NOT NULL"...)
}
}

if want.DefaultValue != got.DefaultValue {
if appendAlterColumn(); err != nil {
return b, err
}
if want.DefaultValue == "" {
b = append(b, " DROP DEFAULT"...)
} else {
b = append(b, " SET DEFAULT "...)
b = append(b, want.DefaultValue...)
}
}

return b, nil
}
46 changes: 22 additions & 24 deletions dialect/pgdialect/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package pgdialect

import (
"context"
"fmt"
"strings"

"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/migrate/sqlschema"
)

Expand Down Expand Up @@ -52,23 +50,21 @@ func (in *Inspector) Inspect(ctx context.Context) (sqlschema.State, error) {
}
colDefs := make(map[string]sqlschema.Column)
for _, c := range columns {
dataType := fromDatabaseType(c.DataType)
if strings.EqualFold(dataType, sqltype.VarChar) && c.VarcharLen > 0 {
dataType = fmt.Sprintf("%s(%d)", dataType, c.VarcharLen)
}

def := c.Default
if c.IsSerial || c.IsIdentity {
def = ""
} else if !c.IsDefaultLiteral {
def = strings.ToLower(def)
}

colDefs[c.Name] = sqlschema.Column{
SQLType: strings.ToLower(dataType),
SQLType: c.DataType,
VarcharLen: c.VarcharLen,
DefaultValue: def,
IsPK: c.IsPK,
IsNullable: c.IsNullable,
IsAutoIncrement: c.IsSerial,
IsIdentity: c.IsIdentity,
DefaultValue: def,
}
}

Expand Down Expand Up @@ -96,21 +92,22 @@ type InformationSchemaTable struct {
}

type InformationSchemaColumn struct {
Schema string `bun:"table_schema"`
Table string `bun:"table_name"`
Name string `bun:"column_name"`
DataType string `bun:"data_type"`
VarcharLen int `bun:"varchar_len"`
IsArray bool `bun:"is_array"`
ArrayDims int `bun:"array_dims"`
Default string `bun:"default"`
IsPK bool `bun:"is_pk"`
IsIdentity bool `bun:"is_identity"`
IndentityType string `bun:"identity_type"`
IsSerial bool `bun:"is_serial"`
IsNullable bool `bun:"is_nullable"`
IsUnique bool `bun:"is_unique"`
UniqueGroup []string `bun:"unique_group,array"`
Schema string `bun:"table_schema"`
Table string `bun:"table_name"`
Name string `bun:"column_name"`
DataType string `bun:"data_type"`
VarcharLen int `bun:"varchar_len"`
IsArray bool `bun:"is_array"`
ArrayDims int `bun:"array_dims"`
Default string `bun:"default"`
IsDefaultLiteral bool `bun:"default_is_literal_expr"`
IsPK bool `bun:"is_pk"`
IsIdentity bool `bun:"is_identity"`
IndentityType string `bun:"identity_type"`
IsSerial bool `bun:"is_serial"`
IsNullable bool `bun:"is_nullable"`
IsUnique bool `bun:"is_unique"`
UniqueGroup []string `bun:"unique_group,array"`
}

type ForeignKey struct {
Expand Down Expand Up @@ -153,6 +150,7 @@ SELECT
WHEN "c".column_default ~ '^''.*''::.*$' THEN substring("c".column_default FROM '^''(.*)''::.*$')
ELSE "c".column_default
END AS "default",
"c".column_default ~ '^''.*''::.*$' OR "c".column_default ~ '^[0-9\.]+$' AS default_is_literal_expr,
'p' = ANY("c".constraint_type) AS is_pk,
"c".is_identity = 'YES' AS is_identity,
"c".column_default = format('nextval(''%s_%s_seq''::regclass)', "c".table_name, "c".column_name) AS is_serial,
Expand Down
74 changes: 63 additions & 11 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@ import (
"strings"

"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)

const (
// Date / Time
pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone
pgTypeDate = "DATE" // Date
pgTypeTime = "TIME" // Time without a time zone
pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone
pgTypeInterval = "INTERVAL" // Time Interval
pgTypeTimestamp = "TIMESTAMP" // Timestamp
pgTypeTimestampWithTz = "TIMESTAMP WITH TIME ZONE" // Timestamp with a time zone
pgTypeTimestampTz = "TIMESTAMPTZ" // Timestamp with a time zone (alias)
pgTypeDate = "DATE" // Date
pgTypeTime = "TIME" // Time without a time zone
pgTypeTimeTz = "TIME WITH TIME ZONE" // Time with a time zone
pgTypeInterval = "INTERVAL" // Time interval

// Network Addresses
pgTypeInet = "INET" // IPv4 or IPv6 hosts and networks
Expand All @@ -30,6 +33,7 @@ const (

// Character Types
pgTypeChar = "CHAR" // fixed length string (blank padded)
pgTypeCharacter = "CHARACTER" // alias for CHAR
pgTypeText = "TEXT" // variable length string without limit
pgTypeVarchar = "VARCHAR" // variable length string with optional limit
pgTypeCharacterVarying = "CHARACTER VARYING" // alias for VARCHAR
Expand Down Expand Up @@ -114,11 +118,59 @@ func sqlType(typ reflect.Type) string {
return sqlType
}

// fromDatabaseType converts Postgres-specific type to a more generic `sqltype`.
func fromDatabaseType(dbType string) string {
switch strings.ToUpper(dbType) {
case pgTypeChar, pgTypeVarchar, pgTypeCharacterVarying, pgTypeText:
return sqltype.VarChar
var (
char = newAliases(pgTypeChar, pgTypeCharacter)
varchar = newAliases(pgTypeVarchar, pgTypeCharacterVarying)
timestampTz = newAliases(sqltype.Timestamp, pgTypeTimestampTz, pgTypeTimestampWithTz)
)

func (d *Dialect) EquivalentType(col1, col2 sqlschema.Column) bool {
if col1.SQLType == col2.SQLType {
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
}

typ1, typ2 := strings.ToUpper(col1.SQLType), strings.ToUpper(col2.SQLType)

switch {
case char.IsAlias(typ1) && char.IsAlias(typ2):
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
case varchar.IsAlias(typ1) && varchar.IsAlias(typ2):
return checkVarcharLen(col1, col2, d.DefaultVarcharLen())
case timestampTz.IsAlias(typ1) && timestampTz.IsAlias(typ2):
return true
}
return false
}

// checkVarcharLen returns true if columns have the same VarcharLen, or,
// if one specifies no VarcharLen and the other one has the default lenght for pgdialect.
// We assume that the types are otherwise equivalent and that any non-character column
// would have VarcharLen == 0;
func checkVarcharLen(col1, col2 sqlschema.Column, defaultLen int) bool {
if col1.VarcharLen == col2.VarcharLen {
return true
}

if (col1.VarcharLen == 0 && col2.VarcharLen == defaultLen) || (col1.VarcharLen == defaultLen && col2.VarcharLen == 0) {
return true
}
return false
}

// typeAlias defines aliases for common data types. It is a lightweight string set implementation.
type typeAlias map[string]struct{}

// IsAlias checks if typ1 and typ2 are aliases of the same data type.
func (t typeAlias) IsAlias(typ string) bool {
_, ok := t[typ]
return ok
}

// newAliases creates a set of aliases.
func newAliases(aliases ...string) typeAlias {
types := make(typeAlias)
for _, a := range aliases {
types[a] = struct{}{}
}
return dbType
return types
}
84 changes: 84 additions & 0 deletions dialect/pgdialect/sqltype_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package pgdialect

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/migrate/sqlschema"
)

func TestInspectorDialect_EquivalentType(t *testing.T) {
d := New()

t.Run("common types", func(t *testing.T) {
for _, tt := range []struct {
typ1, typ2 string
want bool
}{
{"text", "text", true}, // identical types

{sqltype.VarChar, pgTypeVarchar, true},
{sqltype.VarChar, pgTypeCharacterVarying, true},
{sqltype.VarChar, pgTypeChar, false},
{sqltype.VarChar, pgTypeCharacter, false},
{pgTypeCharacterVarying, pgTypeVarchar, true},
{pgTypeCharacter, pgTypeChar, true},
{sqltype.VarChar, pgTypeText, false},
{pgTypeChar, pgTypeText, false},
{pgTypeVarchar, pgTypeText, false},

// SQL standards require that TIMESTAMP be default alias for "TIMESTAMP WITH TIME ZONE"
{sqltype.Timestamp, pgTypeTimestampTz, true},
{sqltype.Timestamp, pgTypeTimestampWithTz, true},
{sqltype.Timestamp, pgTypeTimestamp, true}, // Still, TIMESTAMP == TIMESTAMP
{sqltype.Timestamp, pgTypeTimeTz, false},
{pgTypeTimestampTz, pgTypeTimestampWithTz, true},
} {
eq := " ~ "
if !tt.want {
eq = " !~ "
}
t.Run(tt.typ1+eq+tt.typ2, func(t *testing.T) {
got := d.EquivalentType(
sqlschema.Column{SQLType: tt.typ1},
sqlschema.Column{SQLType: tt.typ2},
)
require.Equal(t, tt.want, got)
})
}

})

t.Run("custom varchar length", func(t *testing.T) {
for _, tt := range []struct {
name string
col1, col2 sqlschema.Column
want bool
}{
{
name: "varchars of different length are not equivalent",
col1: sqlschema.Column{SQLType: "varchar", VarcharLen: 10},
col2: sqlschema.Column{SQLType: "varchar"},
want: false,
},
{
name: "varchar with no explicit length is equivalent to varchar of default length",
col1: sqlschema.Column{SQLType: "varchar", VarcharLen: d.DefaultVarcharLen()},
col2: sqlschema.Column{SQLType: "varchar"},
want: true,
},
{
name: "characters with equal custom length",
col1: sqlschema.Column{SQLType: "character varying", VarcharLen: 200},
col2: sqlschema.Column{SQLType: "varchar", VarcharLen: 200},
want: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
got := d.EquivalentType(tt.col1, tt.col2)
require.Equal(t, tt.want, got)
})
}
})
}
Loading

0 comments on commit 4e80ee1

Please sign in to comment.