diff --git a/instrumentation/testing/config/testing.go b/instrumentation/testing/config/testing.go new file mode 100644 index 00000000..fb880815 --- /dev/null +++ b/instrumentation/testing/config/testing.go @@ -0,0 +1,28 @@ +package config + +import ( + "sync" +) + +var ( + testsToSkip map[string]struct{} + + m sync.RWMutex +) + +func SetFqnToSkip(fqns ...string) { + m.Lock() + defer m.Unlock() + + testsToSkip = map[string]struct{}{} + for _, val := range fqns { + testsToSkip[val] = struct{}{} + } +} + +func GetSkipMap() map[string]struct{} { + m.RLock() + defer m.RUnlock() + + return testsToSkip +} diff --git a/instrumentation/testing/init.go b/instrumentation/testing/init.go index b3cb5ece..22bd244f 100644 --- a/instrumentation/testing/init.go +++ b/instrumentation/testing/init.go @@ -18,6 +18,9 @@ func Init(m *testing.M) { tests = append(tests, testing.InternalTest{ Name: test.Name, F: func(t *testing.T) { // Creating a new test function as an indirection of the original test + if shouldSkipTest(t, funcPointer) { + return + } addAutoInstrumentedTest(t) tStruct := StartTestFromCaller(t, funcPointer) defer tStruct.end() diff --git a/instrumentation/testing/testing.go b/instrumentation/testing/testing.go index 2692933d..99d5e67e 100644 --- a/instrumentation/testing/testing.go +++ b/instrumentation/testing/testing.go @@ -18,6 +18,7 @@ import ( "go.undefinedlabs.com/scopeagent/errors" "go.undefinedlabs.com/scopeagent/instrumentation" "go.undefinedlabs.com/scopeagent/instrumentation/logging" + "go.undefinedlabs.com/scopeagent/instrumentation/testing/config" "go.undefinedlabs.com/scopeagent/reflection" "go.undefinedlabs.com/scopeagent/runner" "go.undefinedlabs.com/scopeagent/tags" @@ -135,11 +136,19 @@ func (test *Test) Context() context.Context { // Runs an auto instrumented sub test func (test *Test) Run(name string, f func(t *testing.T)) bool { + pc, _, _, _ := runtime.Caller(1) if test.span == nil { // No span = not instrumented - return test.t.Run(name, f) + return test.t.Run(name, func(cT *testing.T) { + if shouldSkipTest(cT, pc) { + return + } + f(cT) + }) } - pc, _, _, _ := runtime.Caller(1) return test.t.Run(name, func(childT *testing.T) { + if shouldSkipTest(childT, pc) { + return + } addAutoInstrumentedTest(childT) childTest := StartTestFromCaller(childT, pc) defer childTest.end() @@ -252,3 +261,28 @@ func addAutoInstrumentedTest(t *testing.T) { defer autoInstrumentedTestsMutex.Unlock() autoInstrumentedTests[t] = true } + +// Should skip test +func shouldSkipTest(t *testing.T, pc uintptr) bool { + fullTestName := runner.GetOriginalTestName(t.Name()) + testNameSlash := strings.IndexByte(fullTestName, '/') + funcName := fullTestName + if testNameSlash >= 0 { + funcName = fullTestName[:testNameSlash] + } + + funcFullName := runtime.FuncForPC(pc).Name() + funcNameIndex := strings.LastIndex(funcFullName, funcName) + if funcNameIndex < 1 { + funcNameIndex = len(funcFullName) + } + packageName := funcFullName[:funcNameIndex-1] + + fqn := fmt.Sprintf("%s.%s", packageName, fullTestName) + skipMap := config.GetSkipMap() + if _, ok := skipMap[fqn]; ok { + reflection.SkipAndFinishTest(t) + return true + } + return false +} diff --git a/reflection/reflect.go b/reflection/reflect.go index 198794ad..1b6347c8 100644 --- a/reflection/reflect.go +++ b/reflection/reflect.go @@ -46,3 +46,17 @@ func GetIsParallel(t *testing.T) bool { } return false } + +func SkipAndFinishTest(t *testing.T) { + mu := GetTestMutex(t) + if mu != nil { + mu.Lock() + defer mu.Unlock() + } + if pointer, err := GetFieldPointerOf(t, "skipped"); err == nil { + *(*bool)(pointer) = true + } + if pointer, err := GetFieldPointerOf(t, "finished"); err == nil { + *(*bool)(pointer) = true + } +}