Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Origin(&ori) NOT works properly with struct methods #21

Merged
merged 1 commit into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 45 additions & 27 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,13 @@ type Mocker struct {
}

type MockBuilder struct {
target interface{} // 目标函数
hook interface{} // mock函数
proxy interface{} // mock之后,原函数地址
when interface{} // 条件函数
filterGoroutine FilterGoroutineType
gId int64
missHookReceiver bool
missWhenReceiver bool
unsafe bool
target interface{} // 目标函数
hook interface{} // mock函数
proxyCaller interface{} // mock之后,原函数地址
when interface{} // 条件函数
filterGoroutine FilterGoroutineType
gId int64
unsafe bool
}

func Mock(target interface{}) *MockBuilder {
Expand All @@ -78,13 +76,13 @@ func MockUnsafe(target interface{}) *MockBuilder {
}

func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder {
tool.Assert(builder.proxy == nil, "re-set builder origin")
tool.Assert(builder.proxyCaller == nil, "re-set builder origin")
return builder.origin(funcPtr)
}

func (builder *MockBuilder) origin(funcPtr interface{}) *MockBuilder {
tool.AssertPtr(funcPtr)
builder.proxy = funcPtr
builder.proxyCaller = funcPtr
return builder
}

Expand Down Expand Up @@ -159,7 +157,7 @@ func (builder *MockBuilder) Build() *Mocker {

func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool {
hType := reflect.TypeOf(hook)
tool.Assert(hType.Kind() == reflect.Func, "Param a is not a func")
tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
// has receiver
if tool.CheckFuncArgs(target, hType, 0) {
Expand All @@ -175,24 +173,44 @@ func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool
func (mocker *Mocker) buildHook(builder *MockBuilder) {
when := builder.when
hook := builder.hook
proxy := builder.proxy

var p reflect.Value
proxy := reflect.New(mocker.target.Type())

if proxy == nil {
proxy := mocker.target
p = reflect.New(proxy.Type())
} else {
p = reflect.ValueOf(proxy)
}
var missWhenReceiver, missHookReceiver, missProxyReceiver bool
if when != nil {
builder.missWhenReceiver = mocker.checkReceiver(mocker.target.Type(), when)
missWhenReceiver = mocker.checkReceiver(mocker.target.Type(), when)
}
if hook != nil {
builder.missHookReceiver = mocker.checkReceiver(mocker.target.Type(), hook)
missHookReceiver = mocker.checkReceiver(mocker.target.Type(), hook)
}

proxyCallerSetter := func(args []reflect.Value) {}

if builder.proxyCaller != nil {
pVal := reflect.ValueOf(builder.proxyCaller)
tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer")
pElem := pVal.Elem()
missProxyReceiver = mocker.checkReceiver(mocker.target.Type(), pElem.Interface())

if missProxyReceiver {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), append(args[0:1], innerArgs...))
}))
}
} else {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), innerArgs)
}))
}
}
}

mockerHook := reflect.MakeFunc(mocker.target.Type(), func(args []reflect.Value) []reflect.Value {
origin := p.Elem()
proxyCallerSetter(args) // 设置origin调用proxy

origin := proxy.Elem()
mocker.access()

switch builder.filterGoroutine {
Expand All @@ -212,13 +230,13 @@ func (mocker *Mocker) buildHook(builder *MockBuilder) {

if when != nil {
wVal := reflect.ValueOf(when)
ret := tool.ReflectCallWithShiftOne(wVal, args, builder.missWhenReceiver)
ret := tool.ReflectCallWithShiftOne(wVal, args, missWhenReceiver)
b := ret[0].Bool()

if b && hook != nil {
hVal = reflect.ValueOf(hook)
mocker.mock()
return tool.ReflectCallWithShiftOne(hVal, args, builder.missHookReceiver)
return tool.ReflectCallWithShiftOne(hVal, args, missHookReceiver)
}
return tool.ReflectCall(origin, args)
}
Expand All @@ -228,10 +246,10 @@ func (mocker *Mocker) buildHook(builder *MockBuilder) {
hVal = reflect.ValueOf(hook)
}
mocker.mock()
return tool.ReflectCallWithShiftOne(hVal, args, builder.missHookReceiver)
return tool.ReflectCallWithShiftOne(hVal, args, missHookReceiver)
})
mocker.hook = mockerHook
mocker.proxy = p.Interface()
mocker.proxy = proxy.Interface()
}

func (mocker *Mocker) Patch() *Mocker {
Expand Down
65 changes: 65 additions & 0 deletions mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,68 @@ func TestMockUnsafe(t *testing.T) {
So(func() { ShortFun() }, ShouldPanicWith, "in hook")
})
}

type foo struct{ i int }

func (f *foo) Name(i int) string { return fmt.Sprintf("Fn-%v-%v", f.i, i) }

func (f *foo) Foo() int { return f.i }

func TestMockOrigin(t *testing.T) {
PatchConvey("struct-origin", t, func() {
PatchConvey("with receiver", func() {
var ori1 func(*foo, int) string
var ori2 func(*foo, int) string
mocker := Mock((*foo).Name).To(func(f *foo, i int) string {
if i == 1 {
return ori1(f, i)
}
return ori2(f, i)
}).Origin(&ori1).Build()

ori2 = func(f *foo, i int) string { return fmt.Sprintf("Fn-mock2-%v", i) }
So((&foo{100}).Name(1), ShouldEqual, "Fn-100-1")
So((&foo{200}).Name(1), ShouldEqual, "Fn-200-1")
So((&foo{100}).Name(2), ShouldEqual, "Fn-mock2-2")
So((&foo{200}).Name(2), ShouldEqual, "Fn-mock2-2")

ori1 = func(f *foo, i int) string { return fmt.Sprintf("Fn-mock1-%v", i) }
mocker.Origin(&ori2)
So((&foo{100}).Name(1), ShouldEqual, "Fn-mock1-1")
So((&foo{200}).Name(1), ShouldEqual, "Fn-mock1-1")
So((&foo{100}).Name(2), ShouldEqual, "Fn-100-2")
So((&foo{200}).Name(2), ShouldEqual, "Fn-200-2")
})
PatchConvey("without receiver", func() {
var ori1 func(int) string
var ori2 func(int) string
mocker := Mock((*foo).Name).To(func(i int) string {
if i == 1 {
return ori1(i)
}
return ori2(i)
}).Origin(&ori1).Build()

ori2 = func(i int) string { return fmt.Sprintf("Fn-mock2-%v", i) }
So((&foo{100}).Name(1), ShouldEqual, "Fn-100-1")
So((&foo{200}).Name(1), ShouldEqual, "Fn-200-1")
So((&foo{100}).Name(2), ShouldEqual, "Fn-mock2-2")
So((&foo{200}).Name(2), ShouldEqual, "Fn-mock2-2")

ori1 = func(i int) string { return fmt.Sprintf("Fn-mock1-%v", i) }
mocker.Origin(&ori2)
So((&foo{100}).Name(1), ShouldEqual, "Fn-mock1-1")
So((&foo{200}).Name(1), ShouldEqual, "Fn-mock1-1")
So((&foo{100}).Name(2), ShouldEqual, "Fn-100-2")
So((&foo{200}).Name(2), ShouldEqual, "Fn-200-2")
})
})
PatchConvey("issue https://github.com/bytedance/mockey/issues/15", t, func() {
var origin func() int
f := &foo{}
Mock(GetMethod(f, "Foo")).To(func() int { return origin() + 1 }).Origin(&origin).Build()
So((&foo{1}).Foo(), ShouldEqual, 2)
So((&foo{2}).Foo(), ShouldEqual, 3)
So((&foo{3}).Foo(), ShouldEqual, 4)
})
}
4 changes: 2 additions & 2 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,15 @@ type testOuter struct {
}

type testInner struct {
i int
_ int
}

func (testInner) FooNested() {
panic("not here")
}

type testInnerP struct {
s string
_ string
}

func (*testInnerP) FooNested() {
Expand Down