Skip to content

Commit

Permalink
Merge pull request #8 from bytedance/7-unsafe-mock
Browse files Browse the repository at this point in the history
feat: support unsafe mock
  • Loading branch information
ycydsxy authored Dec 29, 2022
2 parents 56e41c4 + 0d52431 commit e8a79c6
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 25 deletions.
4 changes: 2 additions & 2 deletions internal/monkey/inst/disasm_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"golang.org/x/arch/x86/x86asm"
)

func Disassemble(code []byte, required int) int {
func Disassemble(code []byte, required int, checkLen bool) int {
var pos int
var err error
var inst x86asm.Inst
Expand All @@ -30,7 +30,7 @@ func Disassemble(code []byte, required int) int {
inst, err = x86asm.Decode(code[pos:], 64)
tool.Assert(err == nil, err)
tool.DebugPrintf("Disassemble: inst: %v\n", inst)
tool.Assert(inst.Op != x86asm.RET, "function is too short to patch")
tool.Assert(inst.Op != x86asm.RET || !checkLen, "function is too short to patch")
pos += inst.Len
}
return pos
Expand Down
2 changes: 1 addition & 1 deletion internal/monkey/inst/disasm_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
"github.com/bytedance/mockey/internal/tool"
)

func Disassemble(code []byte, required int) int {
func Disassemble(code []byte, required int, checkLen bool) int {
tool.Assert(len(code) > required, "function is too short to patch")
return required
}
14 changes: 4 additions & 10 deletions internal/monkey/patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (p *Patch) Unpatch() {

// PatchValue replace the target function with a hook function, and stores the target function in the proxy function
// for future restore. Target and hook are values of function. Proxy is a value of proxy function pointer.
func PatchValue(target, hook, proxy reflect.Value) *Patch {
func PatchValue(target, hook, proxy reflect.Value, unsafe bool) *Patch {
tool.Assert(hook.Kind() == reflect.Func, "'%s' is not a function", hook.Kind())
tool.Assert(proxy.Kind() == reflect.Ptr, "'%v' is not a function pointer", proxy.Kind())
tool.Assert(hook.Type() == target.Type(), "'%v' and '%s' mismatch", hook.Type(), target.Type())
Expand All @@ -54,7 +54,7 @@ func PatchValue(target, hook, proxy reflect.Value) *Patch {
// construct the branch instruction, i.e. jump to the hook function
hookCode := inst.BranchInto(common.PtrAt(hook))
// search the cutting point of the target code, i.e. the minimum length of full instructions that is longer than the hookCode
cuttingIdx := inst.Disassemble(targetCodeBuf, len(hookCode))
cuttingIdx := inst.Disassemble(targetCodeBuf, len(hookCode), !unsafe)

// construct the proxy code
proxyCode := common.AllocatePage()
Expand All @@ -73,14 +73,8 @@ func PatchValue(target, hook, proxy reflect.Value) *Patch {
return &Patch{base: targetAddr, code: proxyCode, size: cuttingIdx}
}

func PatchFunc(fn, hook, proxy interface{}) *Patch {
func PatchFunc(fn, hook, proxy interface{}, unsafe bool) *Patch {
vv := reflect.ValueOf(fn)
tool.Assert(vv.Kind() == reflect.Func, "'%v' is not a function", fn)
return PatchValue(vv, reflect.ValueOf(hook), reflect.ValueOf(proxy))
}

func PatchMethod(val interface{}, method string, hook, proxy interface{}) *Patch {
m, ok := reflect.TypeOf(val).MethodByName(method)
tool.Assert(ok, "unknown method '%s'", method)
return PatchValue(m.Func, reflect.ValueOf(hook), reflect.ValueOf(proxy))
return PatchValue(vv, reflect.ValueOf(hook), reflect.ValueOf(proxy), unsafe)
}
59 changes: 48 additions & 11 deletions internal/monkey/patch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,68 @@
package monkey

import (
"reflect"
"testing"

"github.com/smartystreets/goconvey/convey"
)

//go:noinline
func A() int {
func Target() int {
return 0
}

func Proxy() int {
return 1
}

func Hook() int {
return 2
}

func UnsafeTarget() {}

func TestPatchFunc(t *testing.T) {
convey.Convey("TestPatchFunc", t, func() {
fun := Proxy
patch := PatchFunc(A, Hook, &fun)
convey.So(A(), convey.ShouldEqual, 2)
convey.So(fun(), convey.ShouldEqual, 0)
patch.Unpatch()
convey.So(A(), convey.ShouldEqual, 0)
convey.Convey("normal", func() {
var proxy func() int
patch := PatchFunc(Target, Hook, &proxy, false)
convey.So(Target(), convey.ShouldEqual, 2)
convey.So(proxy(), convey.ShouldEqual, 0)
patch.Unpatch()
convey.So(Target(), convey.ShouldEqual, 0)
})
convey.Convey("anonymous hook", func() {
var proxy func() int
patch := PatchFunc(Target, func() int { return 2 }, &proxy, false)
convey.So(Target(), convey.ShouldEqual, 2)
convey.So(proxy(), convey.ShouldEqual, 0)
patch.Unpatch()
convey.So(Target(), convey.ShouldEqual, 0)
})
convey.Convey("closure hook", func() {
var proxy func() int
hookBuilder := func(x int) func() int {
return func() int { return x }
}
patch := PatchFunc(Target, hookBuilder(2), &proxy, false)
convey.So(Target(), convey.ShouldEqual, 2)
convey.So(proxy(), convey.ShouldEqual, 0)
patch.Unpatch()
convey.So(Target(), convey.ShouldEqual, 0)
})
convey.Convey("reflect hook", func() {
var proxy func() int
hookVal := reflect.MakeFunc(reflect.TypeOf(Hook), func(args []reflect.Value) (results []reflect.Value) { return []reflect.Value{reflect.ValueOf(2)} })
patch := PatchFunc(Target, hookVal.Interface(), &proxy, false)
convey.So(Target(), convey.ShouldEqual, 2)
convey.So(proxy(), convey.ShouldEqual, 0)
patch.Unpatch()
convey.So(Target(), convey.ShouldEqual, 0)
})
convey.Convey("unsafe", func() {
var proxy func()
patch := PatchFunc(UnsafeTarget, func() { panic("good") }, &proxy, true)
convey.So(func() { UnsafeTarget() }, convey.ShouldPanicWith, "good")
convey.So(func() { proxy() }, convey.ShouldNotPanic)
patch.Unpatch()
convey.So(func() { UnsafeTarget() }, convey.ShouldNotPanic)
})
})
}
14 changes: 13 additions & 1 deletion mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type MockBuilder struct {
gId int64
missHookReceiver bool
missWhenReceiver bool
unsafe bool
}

func Mock(target interface{}) *MockBuilder {
Expand All @@ -63,6 +64,17 @@ func Mock(target interface{}) *MockBuilder {
}
}

// MockUnsafe has the full ability of the Mock function and removes some security restrictions. This is an alternative
// when the Mock function fails. It may cause some unknown problems, so we recommend using Mock under normal conditions.
func MockUnsafe(target interface{}) *MockBuilder {
tool.AssertFunc(target)

return &MockBuilder{
target: target,
unsafe: true,
}
}

func (builder *MockBuilder) Origin(funcPtr interface{}) *MockBuilder {
tool.Assert(builder.proxy == nil, "re-set builder origin")
return builder.origin(funcPtr)
Expand Down Expand Up @@ -226,7 +238,7 @@ func (mocker *Mocker) Patch() *Mocker {
if mocker.isPatched {
return mocker
}
mocker.patch = monkey.PatchValue(mocker.target, mocker.hook, reflect.ValueOf(mocker.proxy))
mocker.patch = monkey.PatchValue(mocker.target, mocker.hook, reflect.ValueOf(mocker.proxy), mocker.builder.unsafe)
mocker.isPatched = true
addToGlobal(mocker)

Expand Down
10 changes: 10 additions & 0 deletions mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ func VariantParam(a int, b ...int) (int, int) {
return a, b[0]
}

func ShortFun() {}

func TestNoConvey(t *testing.T) {
origin := Fun
mock := func(p string) string {
Expand Down Expand Up @@ -310,3 +312,11 @@ func TestRePatch(t *testing.T) {
fmt.Printf("re unpatch can be run")
})
}

func TestMockUnsafe(t *testing.T) {
Convey("TestMockUnsafe", t, func() {
mock := MockUnsafe(ShortFun).To(func() { panic("in hook") }).Build()
defer mock.UnPatch()
So(func() { ShortFun() }, ShouldPanicWith, "in hook")
})
}

0 comments on commit e8a79c6

Please sign in to comment.