Skip to content

Commit

Permalink
feat(inserter): Allowing fields to be ignored on the inserter package (
Browse files Browse the repository at this point in the history
…#60)

updating the inserter package
  • Loading branch information
Jacobbrewer1 authored Oct 26, 2024
1 parent 0a9da7a commit 104d38e
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 42 deletions.
27 changes: 27 additions & 0 deletions inserter/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package inserter
import (
"database/sql"
"errors"

"github.com/jacobbrewer1/patcher"
)

var (
Expand Down Expand Up @@ -34,6 +36,31 @@ type SQLBatch struct {

// table is the table name to use in the SQL statement
table string

// ignoreFields is a list of fields to ignore when patching
ignoreFields []string

// ignoreFieldsFunc is a function that determines whether a field should be ignored
//
// This func should return true is the field is to be ignored
ignoreFieldsFunc patcher.IgnoreFieldsFunc
}

// newBatchDefaults returns a new SQLBatch with default values
func newBatchDefaults(opts ...BatchOpt) *SQLBatch {
b := &SQLBatch{
fields: nil,
args: nil,
db: nil,
tagName: patcher.DefaultDbTagName,
table: "",
}

for _, opt := range opts {
opt(b)
}

return b
}

func (b *SQLBatch) Fields() []string {
Expand Down
20 changes: 19 additions & 1 deletion inserter/batch_opts.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package inserter

import "database/sql"
import (
"database/sql"

"github.com/jacobbrewer1/patcher"
)

type BatchOpt func(*SQLBatch)

Expand All @@ -24,3 +28,17 @@ func WithDB(db *sql.DB) BatchOpt {
b.db = db
}
}

// WithIgnoreFields sets the fields to ignore when patching
func WithIgnoreFields(fields ...string) BatchOpt {
return func(b *SQLBatch) {
b.ignoreFields = fields
}
}

// WithIgnoreFieldsFunc sets the function that determines whether a field should be ignored
func WithIgnoreFieldsFunc(f patcher.IgnoreFieldsFunc) BatchOpt {
return func(b *SQLBatch) {
b.ignoreFieldsFunc = f
}
}
72 changes: 63 additions & 9 deletions inserter/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@ import (
"database/sql"
"fmt"
"reflect"
"slices"
"strings"
)

const (
// defaultTagName is the default tag name to look for in the struct
defaultTagName = "db"
"github.com/jacobbrewer1/patcher"
)

func NewBatch(resources []any, opts ...BatchOpt) *SQLBatch {
b := new(SQLBatch)
b.tagName = defaultTagName
b := newBatchDefaults(opts...)

for _, opt := range opts {
opt(b)
}
Expand Down Expand Up @@ -49,20 +47,34 @@ func (b *SQLBatch) genBatch(resources []any) {
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
tag := f.Tag.Get(b.tagName)
if tag == "-" {
if tag == patcher.TagOptSkip {
continue
}

// Skip unexported fields
if !f.IsExported() {
continue
}

// Skip fields that are to be ignored
if b.checkSkipField(f) {
continue
}

patcherOptsTag := f.Tag.Get(patcher.TagOptsName)
if patcherOptsTag != "" {
patcherOpts := strings.Split(patcherOptsTag, patcher.TagOptSeparator)
if slices.Contains(patcherOpts, patcher.TagOptSkip) {
continue
}
}

// if no tag is set, use the field name
if tag == "" {
tag = strings.ToLower(f.Name)
tag = f.Name
}

b.args = append(b.args, v.Field(i).Interface())
b.args = append(b.args, b.getFieldValue(v.Field(i), f))

// if the field is not unique, skip it
if _, ok := uniqueFields[tag]; ok {
Expand All @@ -76,6 +88,17 @@ func (b *SQLBatch) genBatch(resources []any) {
}
}

func (b *SQLBatch) getFieldValue(v reflect.Value, f reflect.StructField) any {
if f.Type.Kind() == reflect.Ptr {
if v.IsNil() {
return nil
}
return v.Elem().Interface()
}

return v.Interface()
}

func (b *SQLBatch) GenerateSQL() (string, []any, error) {
if err := b.validateSQLGen(); err != nil {
return "", nil, err
Expand Down Expand Up @@ -116,3 +139,34 @@ func (b *SQLBatch) Perform() (sql.Result, error) {

return b.db.Exec(sqlStr, args...)
}

func (b *SQLBatch) checkSkipField(field reflect.StructField) bool {
// The ignore fields tag takes precedence over the ignore fields list
if b.checkSkipTag(field) {
return true
}

return b.ignoredFieldsCheck(field)
}

func (b *SQLBatch) checkSkipTag(field reflect.StructField) bool {
val, ok := field.Tag.Lookup(patcher.TagOptsName)
if !ok {
return false
}

tags := strings.Split(val, patcher.TagOptSeparator)
return slices.Contains(tags, patcher.TagOptSkip)
}

func (b *SQLBatch) ignoredFieldsCheck(field reflect.StructField) bool {
return b.checkIgnoredFields(strings.ToLower(field.Name)) || b.checkIgnoreFunc(field)
}

func (b *SQLBatch) checkIgnoreFunc(field reflect.StructField) bool {
return b.ignoreFieldsFunc != nil && b.ignoreFieldsFunc(field)
}

func (b *SQLBatch) checkIgnoredFields(field string) bool {
return len(b.ignoreFields) > 0 && slices.Contains(b.ignoreFields, strings.ToLower(field))
}
87 changes: 83 additions & 4 deletions inserter/sql_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package inserter

import (
"reflect"
"testing"

"github.com/jacobbrewer1/patcher"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
)

Expand Down Expand Up @@ -219,7 +222,7 @@ func (s *generateSQLSuite) TestGenerateSQL_noDbTag() {
sql, args, err := NewBatch(resources, WithTable("temp")).GenerateSQL()
s.Require().NoError(err)

s.Require().Equal("INSERT INTO temp (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql)
s.Require().Equal("INSERT INTO temp (ID, Name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql)
s.Require().Len(args, 10)
}

Expand Down Expand Up @@ -357,9 +360,85 @@ func (s *generateSQLSuite) TestGenerateSQL_Success_WithPointedFields() {
sql, args, err := NewBatch(resources, WithTable("temp"), WithTagName("db")).GenerateSQL()
s.Require().NoError(err)

s.Require().Equal("INSERT INTO temp (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql)
s.Require().Len(args, 10)
s.Equal("INSERT INTO temp (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql)

expectedArgs := []any{resources[0].(*temp).ID, resources[0].(*temp).Name, (*int)(nil), resources[1].(*temp).Name, resources[2].(*temp).ID, resources[2].(*temp).Name, resources[3].(*temp).ID, resources[3].(*temp).Name, resources[4].(*temp).ID, resources[4].(*temp).Name}
expectedArgs := []any{1, "test", interface{}(nil), "test2", 3, "test3", 4, "test4", 5, "test5"}
s.Require().Equal(expectedArgs, args)
}

func (s *generateSQLSuite) TestGenerateSQL_Success_WithPointedFields_noDbTag() {
type temp struct {
ID *int
Name *string
unexported string
}

resources := []any{
&temp{ID: ptr(1), Name: ptr("test")},
&temp{ID: nil, Name: ptr("test2")},
&temp{ID: ptr(3), Name: ptr("test3")},
&temp{ID: ptr(4), Name: ptr("test4")},
&temp{ID: ptr(5), Name: ptr("test5"), unexported: "test"},
}

sql, args, err := NewBatch(resources, WithTable("temp")).GenerateSQL()
s.Require().NoError(err)

s.Equal("INSERT INTO temp (ID, Name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql)

expectedArgs := []any{1, "test", interface{}(nil), "test2", 3, "test3", 4, "test4", 5, "test5"}
s.Require().Equal(expectedArgs, args)
}

func (s *generateSQLSuite) TestGenerateSQL_Success_IgnoredFields() {
type temp struct {
ID int `db:"id"`
Name string `db:"name"`
unexported string `db:"unexported"`
}

resources := []any{
&temp{ID: 1, Name: "test"},
&temp{ID: 2, Name: "test2"},
&temp{ID: 3, Name: "test3"},
&temp{ID: 4, Name: "test4"},
&temp{ID: 5, Name: "test5", unexported: "test"},
}

b := NewBatch(resources, WithTable("temp"), WithTagName("db"), WithIgnoreFields("unexported"))

sql, args, err := b.GenerateSQL()
s.Require().NoError(err)

s.Equal("INSERT INTO temp (id, name) VALUES (?, ?), (?, ?), (?, ?), (?, ?), (?, ?)", sql)
s.Len(args, 10)
}

func (s *generateSQLSuite) TestGenerateSQL_Success_IgnoredFieldsFunc() {
type temp struct {
ID int `db:"id"`
Name string `db:"name"`
unexported string `db:"unexported"`
}

resources := []any{
&temp{ID: 1, Name: "test"},
&temp{ID: 2, Name: "test2"},
&temp{ID: 3, Name: "test3"},
&temp{ID: 4, Name: "test4"},
&temp{ID: 5, Name: "test5", unexported: "test"},
}

mif := patcher.NewMockIgnoreFieldsFunc(s.T())
mif.On("Execute", mock.Anything).Return(func(f reflect.StructField) bool {
return f.Name == "ID"
})

b := NewBatch(resources, WithTable("temp"), WithTagName("db"), WithIgnoreFieldsFunc(mif.Execute))

sql, args, err := b.GenerateSQL()
s.Require().NoError(err)

s.Equal("INSERT INTO temp (name) VALUES (?), (?), (?), (?), (?)", sql)
s.Len(args, 5)
}
25 changes: 0 additions & 25 deletions loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,6 @@ var (
ErrInvalidType = errors.New("invalid type: must pointer to struct")
)

func newPatchDefaults(opts ...PatchOpt) *SQLPatch {
// Default options
p := &SQLPatch{
fields: nil,
args: nil,
db: nil,
tagName: defaultDbTagName,
table: "",
whereSql: new(strings.Builder),
whereArgs: nil,
joinSql: new(strings.Builder),
joinArgs: nil,
includeZeroValues: false,
includeNilValues: false,
ignoreFields: nil,
ignoreFieldsFunc: nil,
}

for _, opt := range opts {
opt(p)
}

return p
}

// LoadDiff inserts the fields provided in the new struct pointer into the old struct pointer and injects the new
// values into the old struct
//
Expand Down
26 changes: 26 additions & 0 deletions patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,32 @@ type SQLPatch struct {
ignoreFieldsFunc IgnoreFieldsFunc
}

// newPatchDefaults creates a new SQLPatch with default options.
func newPatchDefaults(opts ...PatchOpt) *SQLPatch {
// Default options
p := &SQLPatch{
fields: nil,
args: nil,
db: nil,
tagName: DefaultDbTagName,
table: "",
whereSql: new(strings.Builder),
whereArgs: nil,
joinSql: new(strings.Builder),
joinArgs: nil,
includeZeroValues: false,
includeNilValues: false,
ignoreFields: nil,
ignoreFieldsFunc: nil,
}

for _, opt := range opts {
opt(p)
}

return p
}

func (s *SQLPatch) Fields() []string {
return s.fields
}
Expand Down
5 changes: 2 additions & 3 deletions sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

const (
defaultDbTagName = "db"
DefaultDbTagName = "db"
)

var (
Expand Down Expand Up @@ -49,8 +49,7 @@ func (s *SQLPatch) patchGen(resource any) {
tag := fType.Tag.Get(s.tagName)

// Skip unexported fields
if fType.PkgPath != "" {
// This is an unexported field
if !fType.IsExported() {
continue
}

Expand Down

0 comments on commit 104d38e

Please sign in to comment.