Skip to content

Commit

Permalink
fix: Origin(&ori) NOT works properly with struct methods
Browse files Browse the repository at this point in the history
see #15

Change-Id: I1949762b22f78e9b84d9dec8771387f2994e94d4
  • Loading branch information
Sychorius committed Apr 24, 2023
1 parent f53d249 commit 3dba582
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 27 deletions.
72 changes: 45 additions & 27 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,13 @@ type Mocker struct {
}

type MockBuilder struct {
target interface{} // 目标函数
hook interface{} // mock函数
proxy interface{} // mock之后,原函数地址
when interface{} // 条件函数
filterGoroutine FilterGoroutineType
gId int64
missHookReceiver bool
missWhenReceiver bool
unsafe bool
target interface{} // 目标函数
hook interface{} // mock函数
proxyCaller interface{} // mock之后,原函数地址
when interface{} // 条件函数
filterGoroutine FilterGoroutineType
gId int64
unsafe bool
}

func Mock(target interface{}) *MockBuilder {
Expand All @@ -78,13 +76,13 @@ func MockUnsafe(target interface{}) *MockBuilder {
}

func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder {
tool.Assert(builder.proxy == nil, "re-set builder origin")
tool.Assert(builder.proxyCaller == nil, "re-set builder origin")
return builder.origin(funcPtr)
}

func (builder *MockBuilder) origin(funcPtr interface{}) *MockBuilder {
tool.AssertPtr(funcPtr)
builder.proxy = funcPtr
builder.proxyCaller = funcPtr
return builder
}

Expand Down Expand Up @@ -159,7 +157,7 @@ func (builder *MockBuilder) Build() *Mocker {

func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool {
hType := reflect.TypeOf(hook)
tool.Assert(hType.Kind() == reflect.Func, "Param a is not a func")
tool.Assert(hType.Kind() == reflect.Func, "Param(%v) a is not a func", hType.Kind())
tool.Assert(target.IsVariadic() == hType.IsVariadic(), "target:%v, hook:%v args not match", target, hook)
// has receiver
if tool.CheckFuncArgs(target, hType, 0) {
Expand All @@ -175,24 +173,44 @@ func (mocker *Mocker) checkReceiver(target reflect.Type, hook interface{}) bool
func (mocker *Mocker) buildHook(builder *MockBuilder) {
when := builder.when
hook := builder.hook
proxy := builder.proxy

var p reflect.Value
proxy := reflect.New(mocker.target.Type())

if proxy == nil {
proxy := mocker.target
p = reflect.New(proxy.Type())
} else {
p = reflect.ValueOf(proxy)
}
var missWhenReceiver, missHookReceiver, missProxyReceiver bool
if when != nil {
builder.missWhenReceiver = mocker.checkReceiver(mocker.target.Type(), when)
missWhenReceiver = mocker.checkReceiver(mocker.target.Type(), when)
}
if hook != nil {
builder.missHookReceiver = mocker.checkReceiver(mocker.target.Type(), hook)
missHookReceiver = mocker.checkReceiver(mocker.target.Type(), hook)
}

proxyCallerSetter := func(args []reflect.Value) {}

if builder.proxyCaller != nil {
pVal := reflect.ValueOf(builder.proxyCaller)
tool.Assert(pVal.Kind() == reflect.Ptr && pVal.Elem().Kind() == reflect.Func, "origin receiver must be a function pointer")
pElem := pVal.Elem()
missProxyReceiver = mocker.checkReceiver(mocker.target.Type(), pElem.Interface())

if missProxyReceiver {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), append(args[0:1], innerArgs...))
}))
}
} else {
proxyCallerSetter = func(args []reflect.Value) {
pElem.Set(reflect.MakeFunc(pElem.Type(), func(innerArgs []reflect.Value) (results []reflect.Value) {
return tool.ReflectCall(proxy.Elem(), innerArgs)
}))
}
}
}

mockerHook := reflect.MakeFunc(mocker.target.Type(), func(args []reflect.Value) []reflect.Value {
origin := p.Elem()
proxyCallerSetter(args) // 设置origin调用proxy

origin := proxy.Elem()
mocker.access()

switch builder.filterGoroutine {
Expand All @@ -212,13 +230,13 @@ func (mocker *Mocker) buildHook(builder *MockBuilder) {

if when != nil {
wVal := reflect.ValueOf(when)
ret := tool.ReflectCallWithShiftOne(wVal, args, builder.missWhenReceiver)
ret := tool.ReflectCallWithShiftOne(wVal, args, missWhenReceiver)
b := ret[0].Bool()

if b && hook != nil {
hVal = reflect.ValueOf(hook)
mocker.mock()
return tool.ReflectCallWithShiftOne(hVal, args, builder.missHookReceiver)
return tool.ReflectCallWithShiftOne(hVal, args, missHookReceiver)
}
return tool.ReflectCall(origin, args)
}
Expand All @@ -228,10 +246,10 @@ func (mocker *Mocker) buildHook(builder *MockBuilder) {
hVal = reflect.ValueOf(hook)
}
mocker.mock()
return tool.ReflectCallWithShiftOne(hVal, args, builder.missHookReceiver)
return tool.ReflectCallWithShiftOne(hVal, args, missHookReceiver)
})
mocker.hook = mockerHook
mocker.proxy = p.Interface()
mocker.proxy = proxy.Interface()
}

func (mocker *Mocker) Patch() *Mocker {
Expand Down
65 changes: 65 additions & 0 deletions mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,68 @@ func TestMockUnsafe(t *testing.T) {
So(func() { ShortFun() }, ShouldPanicWith, "in hook")
})
}

type foo struct{ i int }

func (f *foo) Name(i int) string { return fmt.Sprintf("Fn-%v-%v", f.i, i) }

func (f *foo) Foo() int { return f.i }

func TestMockOrigin(t *testing.T) {
PatchConvey("struct-origin", t, func() {
PatchConvey("with receiver", func() {
var ori1 func(*foo, int) string
var ori2 func(*foo, int) string
mocker := Mock((*foo).Name).To(func(f *foo, i int) string {
if i == 1 {
return ori1(f, i)
}
return ori2(f, i)
}).Origin(&ori1).Build()

ori2 = func(f *foo, i int) string { return fmt.Sprintf("Fn-mock2-%v", i) }
So((&foo{100}).Name(1), ShouldEqual, "Fn-100-1")
So((&foo{200}).Name(1), ShouldEqual, "Fn-200-1")
So((&foo{100}).Name(2), ShouldEqual, "Fn-mock2-2")
So((&foo{200}).Name(2), ShouldEqual, "Fn-mock2-2")

ori1 = func(f *foo, i int) string { return fmt.Sprintf("Fn-mock1-%v", i) }
mocker.Origin(&ori2)
So((&foo{100}).Name(1), ShouldEqual, "Fn-mock1-1")
So((&foo{200}).Name(1), ShouldEqual, "Fn-mock1-1")
So((&foo{100}).Name(2), ShouldEqual, "Fn-100-2")
So((&foo{200}).Name(2), ShouldEqual, "Fn-200-2")
})
PatchConvey("without receiver", func() {
var ori1 func(int) string
var ori2 func(int) string
mocker := Mock((*foo).Name).To(func(i int) string {
if i == 1 {
return ori1(i)
}
return ori2(i)
}).Origin(&ori1).Build()

ori2 = func(i int) string { return fmt.Sprintf("Fn-mock2-%v", i) }
So((&foo{100}).Name(1), ShouldEqual, "Fn-100-1")
So((&foo{200}).Name(1), ShouldEqual, "Fn-200-1")
So((&foo{100}).Name(2), ShouldEqual, "Fn-mock2-2")
So((&foo{200}).Name(2), ShouldEqual, "Fn-mock2-2")

ori1 = func(i int) string { return fmt.Sprintf("Fn-mock1-%v", i) }
mocker.Origin(&ori2)
So((&foo{100}).Name(1), ShouldEqual, "Fn-mock1-1")
So((&foo{200}).Name(1), ShouldEqual, "Fn-mock1-1")
So((&foo{100}).Name(2), ShouldEqual, "Fn-100-2")
So((&foo{200}).Name(2), ShouldEqual, "Fn-200-2")
})
})
PatchConvey("issue https://github.com/bytedance/mockey/issues/15", t, func() {
var origin func() int
f := &foo{}
Mock(GetMethod(f, "Foo")).To(func() int { return origin() + 1 }).Origin(&origin).Build()
So((&foo{1}).Foo(), ShouldEqual, 2)
So((&foo{2}).Foo(), ShouldEqual, 3)
So((&foo{3}).Foo(), ShouldEqual, 4)
})
}

0 comments on commit 3dba582

Please sign in to comment.