diff --git a/convey.go b/convey.go index 5e3672a..5c8f307 100644 --- a/convey.go +++ b/convey.go @@ -50,8 +50,10 @@ func PatchConvey(items ...interface{}) { func addToGlobal(mocker mockerInstance) { tool.DebugPrintf("%v added\n", mocker.key()) - _, ok := gMocker[len(gMocker)-1][mocker.key()] - tool.Assert(!ok, "re-mock %v", mocker.name()) + last, ok := gMocker[len(gMocker)-1][mocker.key()] + if ok { + tool.Assert(!ok, "re-mock %v, previous mock at: %v", mocker.name(), last.caller()) + } gMocker[len(gMocker)-1][mocker.key()] = mocker } diff --git a/internal/tool/caller.go b/internal/tool/caller.go new file mode 100644 index 0000000..6907494 --- /dev/null +++ b/internal/tool/caller.go @@ -0,0 +1,75 @@ +/* + * Copyright 2022 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tool + +import ( + "fmt" + "runtime" + "strings" +) + +type CallerInfo runtime.Frame + +func (c CallerInfo) String() string { + return fmt.Sprintf("%s:%d", c.File, c.Line) +} + +// Caller gets non-current package caller of a function +// For example, assume we have 3 files: a/b/foo.go, a/c/bar.go and a/c/innerBar.go, +// a/b/foo.Foo calls a/c/bar.Bar, and a/c/bar.Bar calls a/c/innerBar.innerBar. +// Here is how innerBar looks like: +// +// func innerBar() CallerInfo { /*do some thing*/ return Caller() } +// +// The return value of innerBar should represent the line in a/b/foo.go where a/b/foo.Foo calls a/c/bar.Bar +func OuterCaller() CallerInfo { + caller, _, _, _ := runtime.Caller(1) + oriPkg, _ := getPackageAndFunction(caller) + + pc := make([]uintptr, 10) + n := runtime.Callers(2, pc) + pc = pc[:n] + frames := runtime.CallersFrames(pc) + for frame, more := frames.Next(); more; frame, more = frames.Next() { + curPkg, _ := getPackageAndFunction(frame.PC) + if curPkg != oriPkg { + return CallerInfo(frame) + } + } + return CallerInfo(runtime.Frame{File: "Nan"}) +} + +func Caller() CallerInfo { + caller, _, _, _ := runtime.Caller(1) + frame, _ := runtime.CallersFrames([]uintptr{caller}).Next() + return CallerInfo(frame) +} + +func getPackageAndFunction(pc uintptr) (string, string) { + parts := strings.Split(runtime.FuncForPC(pc).Name(), ".") + pl := len(parts) + packageName := "" + funcName := parts[pl-1] + + if parts[pl-2][0] == '(' { + funcName = parts[pl-2] + "." + funcName + packageName = strings.Join(parts[0:pl-2], ".") + } else { + packageName = strings.Join(parts[0:pl-1], ".") + } + return packageName, funcName +} diff --git a/mock.go b/mock.go index ecb8cb9..60ccfe1 100644 --- a/mock.go +++ b/mock.go @@ -42,6 +42,8 @@ type Mocker struct { lock sync.Mutex isPatched bool builder *MockBuilder + + outerCaller tool.CallerInfo // Mocker 的外部调用位置 } type MockBuilder struct { @@ -242,6 +244,7 @@ func (mocker *Mocker) Patch() *Mocker { mocker.isPatched = true addToGlobal(mocker) + mocker.outerCaller = tool.OuterCaller() return mocker } @@ -341,3 +344,7 @@ func (mocker *Mocker) name() string { func (mocker *Mocker) unPatch() { mocker.UnPatch() } + +func (mocker *Mocker) caller() tool.CallerInfo { + return mocker.outerCaller +} diff --git a/mock_var.go b/mock_var.go index 58ff578..910816c 100644 --- a/mock_var.go +++ b/mock_var.go @@ -27,6 +27,8 @@ type mockerInstance interface { key() uintptr name() string unPatch() + + caller() tool.CallerInfo } type MockerVar struct { @@ -36,6 +38,8 @@ type MockerVar struct { origin interface{} // 原始值 lock sync.Mutex isPatched bool + + outerCaller tool.CallerInfo } func MockValue(targetPtr interface{}) *MockerVar { @@ -72,6 +76,8 @@ func (mocker *MockerVar) Patch() *MockerVar { mocker.target.Set(mocker.hook) mocker.isPatched = true addToGlobal(mocker) + + mocker.outerCaller = tool.OuterCaller() } return mocker @@ -98,9 +104,16 @@ func (mocker *MockerVar) key() uintptr { } func (mocker *MockerVar) name() string { + if mocker.target.Kind() == reflect.String { + return "" + } return mocker.target.String() } func (mocker *MockerVar) unPatch() { mocker.UnPatch() } + +func (mocker *MockerVar) caller() tool.CallerInfo { + return mocker.outerCaller +} diff --git a/tests/remock_test.go b/tests/remock_test.go new file mode 100644 index 0000000..1c737b5 --- /dev/null +++ b/tests/remock_test.go @@ -0,0 +1,84 @@ +// Copyright 2023 2022 ByteDance Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package tests + +import ( + "fmt" + "strings" + "testing" + + "github.com/bytedance/mockey" + "github.com/bytedance/mockey/internal/tool" + "github.com/smartystreets/goconvey/convey" +) + +func callerFunc() { + panic("CallerFunc") +} + +type callerStruct struct { + _ int +} + +func (c *callerStruct) Foo() { + panic("CallerStruct") +} + +var callerValue string + +func TestReMockPanic(t *testing.T) { + mockey.PatchConvey("TestReMockPanic", t, func() { + mockey.PatchConvey("callerFunc", func() { + mocker := mockey.Mock(callerFunc).To(func() { fmt.Println("should not panic") }).Build() + mocker.To(func() { fmt.Println("should also not panic") }) + lastCaller := tool.Caller() + lastCaller.Line -= 1 + var err interface{} + func() { + defer func() { err = recover() }() + mockey.Mock(callerFunc).To(func() { fmt.Println("should panic, but recovered") }).Build() + }() + errString, ok := err.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(strings.Contains(errString, lastCaller.String()), convey.ShouldBeTrue) + }) + mockey.PatchConvey("callerStruct", func() { + mocker := mockey.Mock((*callerStruct).Foo).To(func() { fmt.Println("should not panic") }).Build() + mocker.To(func() { fmt.Println("should also not panic") }) + lastCaller := tool.Caller() + lastCaller.Line -= 1 + var err interface{} + func() { + defer func() { err = recover() }() + mockey.Mock((*callerStruct).Foo).To(func() { fmt.Println("should panic, but recovered") }).Build() + }() + errString, ok := err.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(strings.Contains(errString, lastCaller.String()), convey.ShouldBeTrue) + }) + mockey.PatchConvey("callerValue", func() { + mockey.MockValue(&callerValue).To("should not panic") + lastCaller := tool.Caller() + lastCaller.Line -= 1 + var err interface{} + func() { + defer func() { err = recover() }() + mockey.MockValue(&callerValue).To("should panic, but recovered") + }() + errString, ok := err.(string) + convey.So(ok, convey.ShouldBeTrue) + convey.So(strings.Contains(errString, lastCaller.String()), convey.ShouldBeTrue) + }) + }) +}