-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
beaa2da
commit 80b4523
Showing
4 changed files
with
348 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package structarg | ||
|
||
import ( | ||
"github.com/rrgmc/litsql" | ||
"github.com/rrgmc/litsql/sq" | ||
) | ||
|
||
// WithGetArgsValuesOption adds a [litsql.ArgValues] to be parsed by [sq.ParseArgs]. | ||
func WithGetArgsValuesOption(options ...Option) sq.GetArgValuesInstanceOption { | ||
return sq.WithGetArgValuesInstanceOptionCustom(func(values any) (litsql.ArgValues, error) { | ||
return New(values, options...), nil | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
package structarg | ||
|
||
import ( | ||
"reflect" | ||
"strings" | ||
) | ||
|
||
func (s *argValues) getStructFieldByName(value reflect.Value, fieldName string) (any, bool) { | ||
typ := value.Type() | ||
for i := 0; i < typ.NumField(); i++ { | ||
// Get the StructField first since this is a cheap operation. If the | ||
// field is unexported, then ignore it. | ||
f := typ.Field(i) | ||
if f.PkgPath != "" { | ||
continue | ||
} | ||
|
||
// Next get the actual value of this field and verify it is assignable | ||
// to the map value. | ||
v := value.Field(i) | ||
|
||
tagValue := f.Tag.Get(s.tagName) | ||
keyName := f.Name | ||
|
||
if f.Anonymous && reflect.Indirect(v).Kind() == reflect.Struct { | ||
// embedded struct | ||
eval, ok := s.getStructFieldByName(reflect.Indirect(v), fieldName) | ||
if ok { | ||
return eval, true | ||
} | ||
} else { | ||
// Determine the name of the key in the map | ||
if index := strings.Index(tagValue, ","); index != -1 { | ||
if tagValue[:index] == "-" { | ||
continue | ||
} | ||
|
||
if keyNameTagValue := tagValue[:index]; keyNameTagValue != "" { | ||
keyName = keyNameTagValue | ||
} | ||
} else if len(tagValue) > 0 { | ||
if tagValue == "-" { | ||
continue | ||
} | ||
keyName = tagValue | ||
} | ||
|
||
if s.mapperFunc != nil { | ||
keyName = s.mapperFunc(keyName) | ||
} | ||
|
||
if keyName != fieldName { | ||
continue | ||
} | ||
|
||
if v.Kind() == reflect.Ptr && v.IsNil() { | ||
// avoid sending a pointer to a nil | ||
return nil, true | ||
} | ||
return v.Interface(), true | ||
} | ||
} | ||
|
||
return nil, false | ||
} | ||
|
||
func getReflectValue(value any) reflect.Value { | ||
v := reflect.ValueOf(value) | ||
v = reflect.Indirect(v) | ||
if k := v.Kind(); k != reflect.Struct { | ||
return reflect.Value{} | ||
} | ||
return v | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
package structarg | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/rrgmc/litsql" | ||
"gotest.tools/v3/assert" | ||
) | ||
|
||
func TestReflect(t *testing.T) { | ||
type x struct { | ||
H string | ||
J int | ||
L int `r:"LA"` | ||
M int `r:"MM,omitempty,x=15"` | ||
O *int | ||
P *string | ||
} | ||
|
||
oval := 45 | ||
|
||
value := &x{ | ||
H: "99", | ||
J: 11, | ||
L: 45, | ||
M: 91, | ||
O: &oval, | ||
} | ||
|
||
a := New(value, WithTagName("r")) | ||
|
||
for _, test := range []struct { | ||
name string | ||
expected any | ||
expectedNotFound bool | ||
}{ | ||
{ | ||
name: "H", | ||
expected: "99", | ||
}, | ||
{ | ||
name: "J", | ||
expected: 11, | ||
}, | ||
{ | ||
name: "L", | ||
expectedNotFound: true, | ||
}, | ||
{ | ||
name: "LA", | ||
expected: 45, | ||
}, | ||
{ | ||
name: "M", | ||
expectedNotFound: true, | ||
}, | ||
{ | ||
name: "MM", | ||
expected: 91, | ||
}, | ||
{ | ||
name: "O", | ||
expected: 45, | ||
}, | ||
{ | ||
name: "P", | ||
expected: nil, | ||
}, | ||
} { | ||
t.Run(test.name, func(t *testing.T) { | ||
v, ok := a.Get(test.name) | ||
if test.expectedNotFound { | ||
assert.Assert(t, !ok) | ||
} else { | ||
assert.Assert(t, ok) | ||
assert.DeepEqual(t, test.expected, v) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestReflectEmbed(t *testing.T) { | ||
type Xembed struct { | ||
A string | ||
B int | ||
} | ||
|
||
type x struct { | ||
Xembed | ||
H string | ||
J int | ||
} | ||
|
||
value := &x{ | ||
Xembed: Xembed{ | ||
A: "77", | ||
B: 88, | ||
}, | ||
H: "99", | ||
J: 11, | ||
} | ||
|
||
reflectValuesTest(t, New(value)) | ||
} | ||
|
||
func TestReflectEmbedPtr(t *testing.T) { | ||
type Xembed struct { | ||
A string | ||
B int | ||
} | ||
|
||
type x struct { | ||
*Xembed | ||
H string | ||
J int | ||
} | ||
|
||
value := &x{ | ||
Xembed: &Xembed{ | ||
A: "77", | ||
B: 88, | ||
}, | ||
H: "99", | ||
J: 11, | ||
} | ||
|
||
reflectValuesTest(t, New(value)) | ||
} | ||
|
||
func TestReflectEmbedPtrNil(t *testing.T) { | ||
type Xembed struct { | ||
A string | ||
B int | ||
} | ||
|
||
type x struct { | ||
*Xembed | ||
H string | ||
J int | ||
} | ||
|
||
value := &x{ | ||
Xembed: nil, | ||
H: "99", | ||
J: 11, | ||
} | ||
|
||
a := New(value) | ||
|
||
for _, test := range []struct { | ||
name string | ||
expected any | ||
}{ | ||
{ | ||
name: "H", | ||
expected: "99", | ||
}, | ||
{ | ||
name: "J", | ||
expected: 11, | ||
}, | ||
{ | ||
name: "A", | ||
}, | ||
{ | ||
name: "B", | ||
}, | ||
} { | ||
t.Run(test.name, func(t *testing.T) { | ||
v, ok := a.Get(test.name) | ||
if test.expected == nil { | ||
assert.Assert(t, !ok) | ||
} else { | ||
assert.Assert(t, ok) | ||
assert.DeepEqual(t, test.expected, v) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func reflectValuesTest(t *testing.T, a litsql.ArgValues) { | ||
for _, test := range []struct { | ||
name string | ||
expected any | ||
}{ | ||
{ | ||
name: "H", | ||
expected: "99", | ||
}, | ||
{ | ||
name: "J", | ||
expected: 11, | ||
}, | ||
{ | ||
name: "A", | ||
expected: "77", | ||
}, | ||
{ | ||
name: "B", | ||
expected: 88, | ||
}, | ||
} { | ||
t.Run(test.name, func(t *testing.T) { | ||
v, ok := a.Get(test.name) | ||
assert.Assert(t, ok) | ||
assert.DeepEqual(t, test.expected, v) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package structarg | ||
|
||
import ( | ||
"reflect" | ||
|
||
"github.com/rrgmc/litsql" | ||
) | ||
|
||
// New returns a [litsql.ArgValues] from struct fields. If value is not a struct, returns nil. | ||
func New(value any, options ...Option) litsql.ArgValues { | ||
v := getReflectValue(value) | ||
if !v.IsValid() { | ||
return nil | ||
} | ||
|
||
ret := &argValues{ | ||
tagName: "json", | ||
value: v, | ||
} | ||
for _, opt := range options { | ||
opt(ret) | ||
} | ||
return ret | ||
} | ||
|
||
type argValues struct { | ||
value reflect.Value | ||
|
||
tagName string | ||
mapperFunc func(string) string | ||
} | ||
|
||
func (s *argValues) Get(name string) (any, bool) { | ||
return s.getStructFieldByName(s.value, name) | ||
} | ||
|
||
|
||
type Option func(*argValues) | ||
|
||
// WithTagName sets the struct tag name to use. Default is "json". | ||
func WithTagName(tagName string) Option { | ||
return func(o *argValues) { | ||
o.tagName = tagName | ||
} | ||
} | ||
|
||
// WithMapperFunc sets the field name mapper function. | ||
func WithMapperFunc(mapperFunc func(string) string) Option { | ||
return func(o *argValues) { | ||
o.mapperFunc = mapperFunc | ||
} | ||
} |