From 1e92ea54df7f16897e89f073dd0246935a3a8bee Mon Sep 17 00:00:00 2001 From: "Iskander (Alex) Sharipov" Date: Sun, 16 Jan 2022 19:34:41 +0300 Subject: [PATCH] ruleguard/typematch: improve function type matching (#358) --- analyzer/testdata/src/filtertest/f1.go | 29 ++++++++++++++++ analyzer/testdata/src/filtertest/rules.go | 8 +++++ ruleguard/typematch/patternop_string.go | 15 +++++---- ruleguard/typematch/typematch.go | 41 +++++++++++++++++++++-- ruleguard/typematch/typematch_test.go | 17 ++++++++++ 5 files changed, 101 insertions(+), 9 deletions(-) diff --git a/analyzer/testdata/src/filtertest/f1.go b/analyzer/testdata/src/filtertest/f1.go index a64ed7c5..a0d45dcf 100644 --- a/analyzer/testdata/src/filtertest/f1.go +++ b/analyzer/testdata/src/filtertest/f1.go @@ -39,6 +39,35 @@ func _() { fileTest("f1.go") // want `true` } +func detectFunc() { + var fn func() + + { + typeTest((func(int) bool)(nil), "is predicate func") // want `true` + typeTest((func(string) bool)(nil), "is predicate func") // want `true` + typeTest((func() bool)(nil), "is predicate func") + typeTest((func(int) string)(nil), "is predicate func") + typeTest(fn, "is predicate func") + typeTest(&fn, "is predicate func") + typeTest(10, "is predicate func") + typeTest("str", "is predicate func") + } + + { + typeTest((func(int) bool)(nil), "is func") // want `true` + typeTest((func(string) bool)(nil), "is func") // want `true` + typeTest((func() bool)(nil), "is func") // want `true` + typeTest((func(int) string)(nil), "is func") // want `true` + typeTest((func())(nil), "is func") // want `true` + typeTest(func() {}, "is func") // want `true` + typeTest(fn, "is func") // want `true` + typeTest(&fn, "is func") + typeTest(53, "is func") + typeTest([]int{1}, "is func") + + } +} + func detectObject() { var vec vector2D diff --git a/analyzer/testdata/src/filtertest/rules.go b/analyzer/testdata/src/filtertest/rules.go index 4bbd70c0..472e1129 100644 --- a/analyzer/testdata/src/filtertest/rules.go +++ b/analyzer/testdata/src/filtertest/rules.go @@ -251,4 +251,12 @@ func testRules(m dsl.Matcher) { m.Match(`typeTest($x, $y, "same type sizes")`). Where(m["x"].Type.Size == m["y"].Type.Size). Report(`true`) + + m.Match(`typeTest($x, "is predicate func")`). + Where(m["x"].Type.Is(`func ($_) bool`)). + Report(`true`) + + m.Match(`typeTest($x, "is func")`). + Where(m["x"].Type.Is(`func ($*_) $*_`)). + Report(`true`) } diff --git a/ruleguard/typematch/patternop_string.go b/ruleguard/typematch/patternop_string.go index a909e3e8..672b6b45 100644 --- a/ruleguard/typematch/patternop_string.go +++ b/ruleguard/typematch/patternop_string.go @@ -16,16 +16,17 @@ func _() { _ = x[opArray-5] _ = x[opMap-6] _ = x[opChan-7] - _ = x[opFunc-8] - _ = x[opStructNoSeq-9] - _ = x[opStruct-10] - _ = x[opAnyInterface-11] - _ = x[opNamed-12] + _ = x[opFuncNoSeq-8] + _ = x[opFunc-9] + _ = x[opStructNoSeq-10] + _ = x[opStruct-11] + _ = x[opAnyInterface-12] + _ = x[opNamed-13] } -const _patternOp_name = "opBuiltinTypeopPointeropVaropVarSeqopSliceopArrayopMapopChanopFuncopStructNoSeqopStructopAnyInterfaceopNamed" +const _patternOp_name = "opBuiltinTypeopPointeropVaropVarSeqopSliceopArrayopMapopChanopFuncNoSeqopFuncopStructNoSeqopStructopAnyInterfaceopNamed" -var _patternOp_index = [...]uint8{0, 13, 22, 27, 35, 42, 49, 54, 60, 66, 79, 87, 101, 108} +var _patternOp_index = [...]uint8{0, 13, 22, 27, 35, 42, 49, 54, 60, 71, 77, 90, 98, 112, 119} func (i patternOp) String() string { if i < 0 || i >= patternOp(len(_patternOp_index)-1) { diff --git a/ruleguard/typematch/typematch.go b/ruleguard/typematch/typematch.go index 4749106e..6c4b84d2 100644 --- a/ruleguard/typematch/typematch.go +++ b/ruleguard/typematch/typematch.go @@ -24,6 +24,7 @@ const ( opArray opMap opChan + opFuncNoSeq opFunc opStructNoSeq opStruct @@ -253,6 +254,7 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern { return parseExpr(ctx, e.X) case *ast.FuncType: + hasSeq := false var params []*pattern var results []*pattern if e.Params != nil { @@ -264,6 +266,9 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern { if len(field.Names) != 0 { return nil } + if p.op == opVarSeq { + hasSeq = true + } params = append(params, p) } } @@ -276,11 +281,18 @@ func parseExpr(ctx *Context, e ast.Expr) *pattern { if len(field.Names) != 0 { return nil } + if p.op == opVarSeq { + hasSeq = true + } results = append(results, p) } } + op := opFuncNoSeq + if hasSeq { + op = opFunc + } return &pattern{ - op: opFunc, + op: op, value: len(params), subs: append(params, results...), } @@ -485,7 +497,7 @@ func (p *Pattern) matchIdentical(sub *pattern, typ types.Type) bool { path := strings.SplitAfter(obj.Pkg().Path(), "/vendor/") return path[len(path)-1] == pkgPath && typeName == obj.Name() - case opFunc: + case opFuncNoSeq: typ, ok := typ.(*types.Signature) if !ok { return false @@ -511,6 +523,24 @@ func (p *Pattern) matchIdentical(sub *pattern, typ types.Type) bool { } return true + case opFunc: + typ, ok := typ.(*types.Signature) + if !ok { + return false + } + numParams := sub.value.(int) + params := sub.subs[:numParams] + results := sub.subs[numParams:] + adapter := tupleFielder{x: typ.Params()} + if !p.matchIdenticalFielder(params, &adapter) { + return false + } + adapter.x = typ.Results() + if !p.matchIdenticalFielder(results, &adapter) { + return false + } + return true + case opStructNoSeq: typ, ok := typ.(*types.Struct) if !ok { @@ -549,3 +579,10 @@ type fielder interface { Field(i int) *types.Var NumFields() int } + +type tupleFielder struct { + x *types.Tuple +} + +func (tup *tupleFielder) Field(i int) *types.Var { return tup.x.At(i) } +func (tup *tupleFielder) NumFields() int { return tup.x.Len() } diff --git a/ruleguard/typematch/typematch_test.go b/ruleguard/typematch/typematch_test.go index 2aa54402..6ae6f606 100644 --- a/ruleguard/typematch/typematch_test.go +++ b/ruleguard/typematch/typematch_test.go @@ -104,9 +104,18 @@ func TestIdentical(t *testing.T) { {`func($_) int`, types.NewSignature(nil, types.NewTuple(intVar), types.NewTuple(intVar), false)}, {`func($_) int`, types.NewSignature(nil, types.NewTuple(stringVar), types.NewTuple(intVar), false)}, + {`func($*_) int`, types.NewSignature(nil, types.NewTuple(stringVar), types.NewTuple(intVar), false)}, + {`func($*_) int`, types.NewSignature(nil, nil, types.NewTuple(intVar), false)}, + {`func($*_) $_`, types.NewSignature(nil, nil, types.NewTuple(intVar), false)}, + {`func($t, $t)`, types.NewSignature(nil, types.NewTuple(stringVar, stringVar), nil, false)}, {`func($t, $t)`, types.NewSignature(nil, types.NewTuple(intVar, intVar), nil, false)}, + // Any func. + {`func($*_) $*_`, types.NewSignature(nil, nil, nil, false)}, + {`func($*_) $*_`, types.NewSignature(nil, types.NewTuple(stringVar, stringVar), nil, false)}, + {`func($*_) $*_`, types.NewSignature(nil, types.NewTuple(stringVar), types.NewTuple(intVar), false)}, + {`struct{}`, typeEstruct}, {`struct{int}`, types.NewStruct([]*types.Var{intVar}, nil)}, {`struct{string; int}`, types.NewStruct([]*types.Var{stringVar, intVar}, nil)}, @@ -197,6 +206,14 @@ func TestIdenticalNegative(t *testing.T) { {`func($t, $t)`, types.NewSignature(nil, types.NewTuple(intVar, stringVar), nil, false)}, {`func($t, $t)`, types.NewSignature(nil, types.NewTuple(stringVar, intVar), nil, false)}, + {`func($*_) int`, types.NewSignature(nil, types.NewTuple(stringVar), types.NewTuple(stringVar), false)}, + {`func($*_) int`, types.NewSignature(nil, nil, nil, false)}, + {`func($*_) $_`, types.NewSignature(nil, nil, nil, false)}, + + // Any func negative. + {`func($*_) $*_`, typeInt}, + {`func($*_) $*_`, types.NewTuple(intVar)}, + {`struct{}`, typeInt}, {`struct{}`, types.NewStruct([]*types.Var{intVar}, nil)}, {`struct{int}`, typeEstruct},