Skip to content

Commit

Permalink
mysql support
Browse files Browse the repository at this point in the history
* mysql support

* add column_values to readme

* fix for batch files

* add ColumnNames

* fix returning
  • Loading branch information
sxwebdev authored Jul 29, 2024
1 parent 14a21ff commit 44ccf98
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 24 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,15 @@ sqlc:
skip_columns:
- id
- updated_at
column_values:
created_at: now()
returning: "*"
update:
skip_columns:
- id
- created_at
column_values:
updated_at: now()
returning: "*"
find:
where:
Expand Down
2 changes: 1 addition & 1 deletion cmd/pgxgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

var (
appName = "pgxgen"
version = "v0.3.4"
version = "v0.3.5"
)

func main() {
Expand Down
14 changes: 12 additions & 2 deletions internal/assets/templates/constants.templ
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ func (s ColumnName) StructName() string {
v := stringy.New(string(s)).CamelCase().Get()
v = stringy.New(v).UcFirst()
return strings.ReplaceAll(v, "Id", "ID")
}
type ColumnNames []ColumnName
func (s ColumnNames) Strings() []string {
res := make([]string, len(s))
for idx, colName := range s {
res[idx] = colName.String()
}
return res
}`)
content.WriteString("\nconst (\n")

Expand All @@ -84,8 +94,8 @@ func (s ColumnName) StructName() string {
content.WriteString(")\n\n")

for _, tableName := range p.Tables {
content.WriteString(fmt.Sprintf("func %sColumnNames() []ColumnName {\n", stringy.New(stringy.New(tableName.Name).CamelCase().Get()).UcFirst()))
content.WriteString("return []ColumnName{\n")
content.WriteString(fmt.Sprintf("func %sColumnNames() ColumnNames {\n", stringy.New(stringy.New(tableName.Name).CamelCase().Get()).UcFirst()))
content.WriteString("return ColumnNames{\n")
for _, item := range p.GetColumnsForTable(tableName.Name) {
content.WriteString(fmt.Sprintf("ColumnName%s,\n", item.NamePreffix))
}
Expand Down
14 changes: 12 additions & 2 deletions internal/assets/templates/constants_templ.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions internal/config/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Method struct {
Where map[string]WhereParamsItem `yaml:"where"`
WhereAdditional []string `yaml:"where_additional"`
SkipColumns []string `yaml:"skip_columns"`
ColumnValues map[string]string `yaml:"column_values"`

// For find method
Limit bool `yaml:"limit"`
Expand Down
24 changes: 24 additions & 0 deletions internal/config/sqlc.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type GetPathsResponse struct {
QueriesPaths []string
OutPaths []string
SchemaPaths []string
Engines []string
}

func (s GetPathsResponse) GetModelPathByIndex(index int) string {
Expand All @@ -90,6 +91,7 @@ func (s *Sqlc) GetPaths() GetPathsResponse {
res.QueriesPaths = append(res.QueriesPaths, p.Queries)
res.OutPaths = append(res.OutPaths, p.Path)
res.SchemaPaths = append(res.SchemaPaths, p.Schema)
res.Engines = append(res.Engines, p.Engine)
}
}

Expand All @@ -105,6 +107,7 @@ func (s *Sqlc) GetPaths() GetPathsResponse {
res.QueriesPaths = append(res.QueriesPaths, p.Queries)
res.OutPaths = append(res.OutPaths, p.Gen.Go.Out)
res.SchemaPaths = append(res.SchemaPaths, p.Schema)
res.Engines = append(res.Engines, p.Engine)
}
}

Expand Down Expand Up @@ -169,3 +172,24 @@ func GetPathsByScheme(gpr GetPathsResponse, inSchemaDir string, pathType string)

return filteredModelPaths, nil
}

func GetEnginesByScheme(gpr GetPathsResponse, inSchemaDir string) ([]string, error) {
engines := []string{}
for index, item := range gpr.SchemaPaths {
absFirst, err := filepath.Abs(item)
if err != nil {
return nil, err
}

absSecond, err := filepath.Abs(inSchemaDir)
if err != nil {
return nil, err
}

if absFirst == absSecond {
engines = append(engines, gpr.Engines[index])
}
}

return engines, nil
}
93 changes: 75 additions & 18 deletions internal/crud/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (s *crud) Generate(_ context.Context, args []string) error {

// validate config
if err := cfg.Validate(); err != nil {
return err
return fmt.Errorf("config validation error: %w", err)
}

// get queries paths for current schema
Expand All @@ -53,10 +53,19 @@ func (s *crud) Generate(_ context.Context, args []string) error {
return fmt.Errorf("GetPathsByScheme error: %w", err)
}

engines, err := config.GetEnginesByScheme(s.config.Sqlc.GetPaths(), cfg.SchemaDir)
if err != nil {
return fmt.Errorf("GetEnginesByScheme error: %w", err)
}

if len(engines) != len(outputPaths) {
return fmt.Errorf("engines count does not match output paths count")
}

// get catalogs
allCatalogs, err := sqlc.GetCatalogs(s.config.ConfigPaths.SqlcConfigFilePath)
if err != nil {
return err
return fmt.Errorf("getCatalogs error: %w", err)
}

s.catalogs = make(map[string]cmd.GetCatalogResultItem, len(outputPaths))
Expand All @@ -68,8 +77,16 @@ func (s *crud) Generate(_ context.Context, args []string) error {
s.catalogs[path] = item
}

params := make([]generateSQLForEachTableParams, len(outputPaths))
for i, p := range outputPaths {
params[i] = generateSQLForEachTableParams{
outputPath: p,
engine: engines[i],
}
}

// get sql code for each tables
sqlData, err := s.generateSQLForEachTable(cfg.CrudParams, outputPaths)
sqlData, err := s.generateSQLForEachTable(cfg.CrudParams, params)
if err != nil {
return fmt.Errorf("generate sql for each tables error: %w", err)
}
Expand Down Expand Up @@ -130,17 +147,26 @@ func (s *crud) Generate(_ context.Context, args []string) error {
return nil
}

type generateSQLForEachTableParams struct {
outputPath string
engine string
}

// generateSQLForEachTable - generate sql queries for each tables
func (s *crud) generateSQLForEachTable(crudParams config.CrudParams, outputPaths []string) (map[string][]byte, error) {
result := make(map[string][]byte, len(outputPaths))
func (s *crud) generateSQLForEachTable(crudParams config.CrudParams, params []generateSQLForEachTableParams) (map[string][]byte, error) {
result := make(map[string][]byte, len(params))

for _, outPath := range outputPaths {
for _, param := range params {
// Get all tables from postgres
tablesData, err := s.getTableMeta(outPath)
tablesData, err := s.getTableMeta(param.outputPath)
if err != nil {
return nil, fmt.Errorf("getTableMeta error: %w", err)
}

if !engineType(param.engine).Valid() {
return nil, fmt.Errorf("invalid engine type: %s", param.engine)
}

// headText := fmt.Sprintf("-- Code generated by pgxgen. DO NOT EDIT.\n-- versions:\n-- pgxgen %s\n\n", s.config.Pgxgen.Version)
// builder.WriteString(headText)

Expand Down Expand Up @@ -177,6 +203,7 @@ func (s *crud) generateSQLForEachTable(crudParams config.CrudParams, outputPaths
*metaData,
methodParams,
tableParams,
engineType(param.engine),
}

var err error
Expand Down Expand Up @@ -270,17 +297,28 @@ func (s *crud) processCreate(cfg config.CrudParams, p processParams) error {
p.builder.WriteString(", ")
}

if name == "created_at" {
p.builder.WriteString("now()")
} else {
if len(p.methodParams.ColumnValues) > 0 {
if value, ok := p.methodParams.ColumnValues[name]; ok {
p.builder.WriteString(value)
continue
}
}

switch p.engine {
case EnginesPostgres:
p.builder.WriteString(fmt.Sprintf("$%d", lastIndex))
lastIndex++
case EnginesMysql:
p.builder.WriteString("?")
default:
return fmt.Errorf("engine %s is not supported", p.engine)
}

lastIndex++
}

p.builder.WriteString(")")
if p.methodParams.Returning != "" {
p.builder.WriteString("\n\tRETURNING *")
p.builder.WriteString("\n\tRETURNING " + p.methodParams.Returning)
}
p.builder.WriteString(";\n\n")

Expand Down Expand Up @@ -322,12 +360,23 @@ func (s *crud) processUpdate(cfg config.CrudParams, p processParams) error {
}
}

if name == "updated_at" {
p.builder.WriteString("updated_at=now()")
} else {
if len(p.methodParams.ColumnValues) > 0 {
if value, ok := p.methodParams.ColumnValues[name]; ok {
p.builder.WriteString(name + "=" + value)
continue
}
}

switch p.engine {
case EnginesPostgres:
p.builder.WriteString(fmt.Sprintf("%s=$%d", name, lastIndex))
lastIndex++
case EnginesMysql:
p.builder.WriteString(fmt.Sprintf("%s=?", name))
default:
return fmt.Errorf("engine %s is not supported", p.engine)
}

lastIndex++
}

p.builder.WriteString("\n\t")
Expand All @@ -338,7 +387,7 @@ func (s *crud) processUpdate(cfg config.CrudParams, p processParams) error {
}

if p.methodParams.Returning != "" {
p.builder.WriteString("\n\tRETURNING *")
p.builder.WriteString("\n\tRETURNING " + p.methodParams.Returning)
}
p.builder.WriteString(";\n\n")

Expand Down Expand Up @@ -499,7 +548,15 @@ func (s *crud) processWhereParam(p processParams, method config.MethodType, last
operator = "="
}

p.builder.WriteString(fmt.Sprintf("%s%s$%d", param, operator, *lastIndex))
switch p.engine {
case EnginesPostgres:
p.builder.WriteString(fmt.Sprintf("%s%s$%d", param, operator, *lastIndex))
case EnginesMysql:
p.builder.WriteString(fmt.Sprintf("%s%s?", param, operator))
default:
return fmt.Errorf("engine %s is not supported", p.engine)
}

*lastIndex++
} else {
p.builder.WriteString(param)
Expand Down
20 changes: 20 additions & 0 deletions internal/crud/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,30 @@ func (t tables) getTableMetaData(tableName string) *tableMetaData {
return nil
}

type engineType string

func (e engineType) String() string {
return string(e)
}

func (s engineType) Valid() bool {
switch s {
case EnginesPostgres, EnginesMysql:
return true
}
return false
}

const (
EnginesPostgres engineType = "postgresql"
EnginesMysql engineType = "mysql"
)

type processParams struct {
builder *strings.Builder
table string
metaData tableMetaData
methodParams config.Method
tableParams config.TableParams
engine engineType
}
2 changes: 1 addition & 1 deletion internal/sqlc/movemodels.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (s *sqlc) moveModels(
}

// replace imports in generated files by sqlc
if strings.HasSuffix(file.Name(), ".sql.go") || file.Name() == "querier.go" {
if strings.HasSuffix(file.Name(), ".sql.go") || file.Name() == "querier.go" || file.Name() == "batch.go" {
if err := s.replace(
filepath.Join(modelFileDir, file.Name()),
func(c config.Config, str string) (string, error) {
Expand Down
6 changes: 6 additions & 0 deletions internal/sqlc/replace.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ func replaceImports(str string, sqlcModelParam config.SqlcModels, modelFileStruc
break
}

re = regexp.MustCompile(fmt.Sprintf(`(?sm)\s+\w+\s+%s\s+`, item.Name))
if re.MatchString(str) {
existsSomeModelStruct = true
break
}

for _, field := range item.Fields {
re := regexp.MustCompile(fmt.Sprintf(`(?sm)\s+\w+\s+%s\s+`, field.Name))
if re.MatchString(str) {
Expand Down

0 comments on commit 44ccf98

Please sign in to comment.