From d08bc2634f1c8e77445346948380093a86699e5f Mon Sep 17 00:00:00 2001 From: SYC_ Date: Mon, 24 Apr 2023 03:42:42 +0000 Subject: [PATCH] fix: Origin(&ori) NOT works properly with struct methods see https://github.com/bytedance/mockey/issues/15 Change-Id: I1949762b22f78e9b84d9dec8771387f2994e94d4 --- mock.go | 72 ++++++++++++++++++++++++++++++++------------------- mock_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++ utils_test.go | 4 +-- 3 files changed, 112 insertions(+), 29 deletions(-) diff --git a/mock.go b/mock.go index 60ccfe1..e83d29e 100644 --- a/mock.go +++ b/mock.go @@ -47,15 +47,13 @@ type Mocker struct { } type MockBuilder struct { - target interface{} // 目标函数 - hook interface{} // mock函数 - proxy interface{} // mock之后,原函数地址 - when interface{} // 条件函数 - filterGoroutine FilterGoroutineType - gId int64 - missHookReceiver bool - missWhenReceiver bool - unsafe bool + target interface{} // 目标函数 + hook interface{} // mock函数 + proxyCaller interface{} // mock之后,原函数地址 + when interface{} // 条件函数 + filterGoroutine FilterGoroutineType + gId int64 + unsafe bool } func Mock(target interface{}) *MockBuilder { @@ -78,13 +76,13 @@ func MockUnsafe(target interface{}) *MockBuilder { } func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder { - tool.Assert(builder.proxy == nil, "re-set builder origin") + tool.Assert(builder.proxyCaller == nil, "re-set builder origin") return builder.origin(funcPtr) } func (builder *MockBuilder) origin(funcPtr interface{}) *MockBuilder { tool.AssertPtr(funcPtr) - builder.proxy = funcPtr + builder.proxyCaller = funcPtr return builder } @@ -159,7 +157,7 @@ func (builder *MockBuilder) Build() *Mocker { func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool { hType := reflect.TypeOf(hook) - tool.Assert(hType.Kind() == reflect.Func, "Param a is not a func") + tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind()) tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook) // has receiver if tool.CheckFuncArgs(target, hType, 0) { @@ -175,24 +173,44 @@ func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool func (mocker *Mocker) buildHook(builder *MockBuilder) { when := builder.when hook := builder.hook - proxy := builder.proxy - var p reflect.Value + proxy := reflect.New(mocker.target.Type()) - if proxy == nil { - proxy := mocker.target - p = reflect.New(proxy.Type()) - } else { - p = reflect.ValueOf(proxy) - } + var missWhenReceiver, missHookReceiver, missProxyReceiver bool if when != nil { - builder.missWhenReceiver = mocker.checkReceiver(mocker.target.Type(), when) + missWhenReceiver = mocker.checkReceiver(mocker.target.Type(), when) } if hook != nil { - builder.missHookReceiver = mocker.checkReceiver(mocker.target.Type(), hook) + missHookReceiver = mocker.checkReceiver(mocker.target.Type(), hook) } + + proxyCallerSetter := func(args []reflect.Value) {} + + if builder.proxyCaller != nil { + pVal := reflect.ValueOf(builder.proxyCaller) + tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer") + pElem := pVal.Elem() + missProxyReceiver = mocker.checkReceiver(mocker.target.Type(), pElem.Interface()) + + if missProxyReceiver { + proxyCallerSetter = func(args []reflect.Value) { + pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) { + return tool.ReflectCall(proxy.Elem(), append(args[0:1], innerArgs...)) + })) + } + } else { + proxyCallerSetter = func(args []reflect.Value) { + pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) { + return tool.ReflectCall(proxy.Elem(), innerArgs) + })) + } + } + } + mockerHook := reflect.MakeFunc(mocker.target.Type(), func(args []reflect.Value) []reflect.Value { - origin := p.Elem() + proxyCallerSetter(args) // 设置origin调用proxy + + origin := proxy.Elem() mocker.access() switch builder.filterGoroutine { @@ -212,13 +230,13 @@ func (mocker *Mocker) buildHook(builder *MockBuilder) { if when != nil { wVal := reflect.ValueOf(when) - ret := tool.ReflectCallWithShiftOne(wVal, args, builder.missWhenReceiver) + ret := tool.ReflectCallWithShiftOne(wVal, args, missWhenReceiver) b := ret[0].Bool() if b && hook != nil { hVal = reflect.ValueOf(hook) mocker.mock() - return tool.ReflectCallWithShiftOne(hVal, args, builder.missHookReceiver) + return tool.ReflectCallWithShiftOne(hVal, args, missHookReceiver) } return tool.ReflectCall(origin, args) } @@ -228,10 +246,10 @@ func (mocker *Mocker) buildHook(builder *MockBuilder) { hVal = reflect.ValueOf(hook) } mocker.mock() - return tool.ReflectCallWithShiftOne(hVal, args, builder.missHookReceiver) + return tool.ReflectCallWithShiftOne(hVal, args, missHookReceiver) }) mocker.hook = mockerHook - mocker.proxy = p.Interface() + mocker.proxy = proxy.Interface() } func (mocker *Mocker) Patch() *Mocker { diff --git a/mock_test.go b/mock_test.go index b7ce6bb..2a734b7 100644 --- a/mock_test.go +++ b/mock_test.go @@ -320,3 +320,68 @@ func TestMockUnsafe(t *testing.T) { So(func() { ShortFun() }, ShouldPanicWith, "in hook") }) } + +type foo struct{ i int } + +func (f *foo) Name(i int) string { return fmt.Sprintf("Fn-%v-%v", f.i, i) } + +func (f *foo) Foo() int { return f.i } + +func TestMockOrigin(t *testing.T) { + PatchConvey("struct-origin", t, func() { + PatchConvey("with receiver", func() { + var ori1 func(*foo, int) string + var ori2 func(*foo, int) string + mocker := Mock((*foo).Name).To(func(f *foo, i int) string { + if i == 1 { + return ori1(f, i) + } + return ori2(f, i) + }).Origin(&ori1).Build() + + ori2 = func(f *foo, i int) string { return fmt.Sprintf("Fn-mock2-%v", i) } + So((&foo{100}).Name(1), ShouldEqual, "Fn-100-1") + So((&foo{200}).Name(1), ShouldEqual, "Fn-200-1") + So((&foo{100}).Name(2), ShouldEqual, "Fn-mock2-2") + So((&foo{200}).Name(2), ShouldEqual, "Fn-mock2-2") + + ori1 = func(f *foo, i int) string { return fmt.Sprintf("Fn-mock1-%v", i) } + mocker.Origin(&ori2) + So((&foo{100}).Name(1), ShouldEqual, "Fn-mock1-1") + So((&foo{200}).Name(1), ShouldEqual, "Fn-mock1-1") + So((&foo{100}).Name(2), ShouldEqual, "Fn-100-2") + So((&foo{200}).Name(2), ShouldEqual, "Fn-200-2") + }) + PatchConvey("without receiver", func() { + var ori1 func(int) string + var ori2 func(int) string + mocker := Mock((*foo).Name).To(func(i int) string { + if i == 1 { + return ori1(i) + } + return ori2(i) + }).Origin(&ori1).Build() + + ori2 = func(i int) string { return fmt.Sprintf("Fn-mock2-%v", i) } + So((&foo{100}).Name(1), ShouldEqual, "Fn-100-1") + So((&foo{200}).Name(1), ShouldEqual, "Fn-200-1") + So((&foo{100}).Name(2), ShouldEqual, "Fn-mock2-2") + So((&foo{200}).Name(2), ShouldEqual, "Fn-mock2-2") + + ori1 = func(i int) string { return fmt.Sprintf("Fn-mock1-%v", i) } + mocker.Origin(&ori2) + So((&foo{100}).Name(1), ShouldEqual, "Fn-mock1-1") + So((&foo{200}).Name(1), ShouldEqual, "Fn-mock1-1") + So((&foo{100}).Name(2), ShouldEqual, "Fn-100-2") + So((&foo{200}).Name(2), ShouldEqual, "Fn-200-2") + }) + }) + PatchConvey("issue https://github.com/bytedance/mockey/issues/15", t, func() { + var origin func() int + f := &foo{} + Mock(GetMethod(f, "Foo")).To(func() int { return origin() + 1 }).Origin(&origin).Build() + So((&foo{1}).Foo(), ShouldEqual, 2) + So((&foo{2}).Foo(), ShouldEqual, 3) + So((&foo{3}).Foo(), ShouldEqual, 4) + }) +} diff --git a/utils_test.go b/utils_test.go index 30aa8b7..d9ec822 100644 --- a/utils_test.go +++ b/utils_test.go @@ -258,7 +258,7 @@ type testOuter struct { } type testInner struct { - i int + _ int } func (testInner) FooNested() { @@ -266,7 +266,7 @@ func (testInner) FooNested() { } type testInnerP struct { - s string + _ string } func (*testInnerP) FooNested() {