Skip to content

Commit

Permalink
correctly detect indirect uses of reflection
Browse files Browse the repository at this point in the history
Fixes #554
  • Loading branch information
lu4p committed Jun 21, 2022
1 parent 2d12f41 commit 1f10d49
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 95 deletions.
319 changes: 224 additions & 95 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"unicode"
"unicode/utf8"

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/mod/modfile"
"golang.org/x/mod/module"
Expand Down Expand Up @@ -1111,80 +1110,244 @@ func loadCachedOutputs() error {
return nil
}

func (tf *transformer) findReflectFunctions(files []*ast.File) {
seenReflectParams := make(map[*types.Var]bool)
visitFuncDecl := func(funcDecl *ast.FuncDecl) {
funcObj := tf.info.Defs[funcDecl.Name].(*types.Func)
funcType := funcObj.Type().(*types.Signature)
funcParams := funcType.Params()
type potentialReflectMap map[*types.Var]potentialReflectParam

type potentialReflectParam struct {
is bool
related *types.Var
}

maps.Clear(seenReflectParams)
for i := 0; i < funcParams.Len(); i++ {
seenReflectParams[funcParams.At(i)] = false
func flagParamReflected(obj *types.Var, potentialReflectParams potentialReflectMap) {
if param, ok := potentialReflectParams[obj]; ok {
if param.related != nil {
flagParamReflected(param.related, potentialReflectParams)
}

ast.Inspect(funcDecl, func(node ast.Node) bool {
call, ok := node.(*ast.CallExpr)
if !ok {
return true
}
param.is = true

potentialReflectParams[obj] = param
}
}

func getIdent(node ast.Node) *ast.Ident {
name, ok := node.(*ast.Ident)
if !ok {
sel, ok := node.(*ast.SelectorExpr)
if !ok {
return nil
}

return getIdent(sel.X)
}
return name
}

func (tf *transformer) ignoreReflectedTypes(node ast.Node) {
visit := func(node ast.Node) bool {
call, ok := node.(*ast.CallExpr)
if !ok {
return true
}

ident, ok := call.Fun.(*ast.Ident)
if !ok {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return true
}
calledFunc, _ := tf.info.Uses[sel.Sel].(*types.Func)
if calledFunc == nil || calledFunc.Pkg() == nil {
return true

ident = sel.Sel
}

fnType, _ := tf.info.Uses[ident].(*types.Func)
if fnType == nil || fnType.Pkg() == nil {
return true
}

fullName := fnType.FullName()
for _, reflectParam := range cachedOutput.KnownReflectAPIs[fullName] {
argStart := reflectParam.Position
argEnd := argStart + 1
if reflectParam.Variadic {
argEnd = len(call.Args)
}
for _, arg := range call.Args[argStart:argEnd] {
argType := tf.info.TypeOf(arg)

fullName := calledFunc.FullName()
for _, reflectParam := range cachedOutput.KnownReflectAPIs[fullName] {
// We need a range to handle any number of variadic arguments,
// which could be 0 or multiple.
// The non-variadic case is always one argument,
// but we still use the range to deduplicate code.
argStart := reflectParam.Position
argEnd := argStart + 1
if reflectParam.Variadic {
argEnd = len(call.Args)
}
for _, arg := range call.Args[argStart:argEnd] {
ident, ok := arg.(*ast.Ident)
if !ok {
continue
}
obj, _ := tf.info.Uses[ident].(*types.Var)
if obj == nil {
continue
}
if _, ok := seenReflectParams[obj]; ok {
seenReflectParams[obj] = true
}
}
tf.recursivelyRecordAsNotObfuscated(argType)
}
}

var reflectParams []reflectParameter
for i := 0; i < funcParams.Len(); i++ {
if seenReflectParams[funcParams.At(i)] {
reflectParams = append(reflectParams, reflectParameter{
Position: i,
Variadic: funcType.Variadic() && i == funcParams.Len()-1,
})
}
return true
}

ast.Inspect(node, visit)
}

func (tf *transformer) reflectAddAssign(funcDecl *ast.FuncDecl, potentialReflectParams potentialReflectMap) {
appendValue := func(value ast.Expr, name *ast.Ident) {
objName, _ := tf.info.ObjectOf(name).(*types.Var)
if objName == nil {
return
}

ident := getIdent(value)
if ident == nil {
return
}

objVar, _ := tf.info.Uses[ident].(*types.Var)
if objVar == nil {
return
}

// Check if the Rhs is a function parameter.
_, ok := potentialReflectParams[objVar]
if !ok {
return
}
// This function paramter gets assigned to another variable.

if named := namedType(objName.Type()); named != nil {
typeObj := named.Obj()
if recordedAsNotObfuscated(typeObj) {
// The type of the Lhs is already flagged as being reflected,
// therefore also flag this parameter as being reflected
flagParamReflected(objVar, potentialReflectParams)

return
}
if len(reflectParams) > 0 {
cachedOutput.KnownReflectAPIs[funcObj.FullName()] = reflectParams
}

// Keep track of this assignment.
potentialReflectParams[objName] = potentialReflectParam{
is: false,
related: objVar,
}
}

inspectAssign := func(node ast.Node) {
spec, ok := node.(*ast.AssignStmt)
if !ok {
return
}

if len(spec.Rhs) != len(spec.Lhs) {
return
}

for i, lhs := range spec.Lhs {
name := getIdent(lhs)
if name == nil {
continue
}

value := spec.Rhs[i]

appendValue(value, name)
}
}

inspectSpec := func(node ast.Node) {
spec, ok := node.(*ast.ValueSpec)
if !ok {
return
}

if len(spec.Values) != len(spec.Names) {
return
}

for i, name := range spec.Names {
value := spec.Values[i]

appendValue(value, name)
}
}

ast.Inspect(funcDecl, func(node ast.Node) bool {
inspectAssign(node)
inspectSpec(node)

return true
})
}

func (tf *transformer) visitReflectFuncDecl(funcDecl *ast.FuncDecl) {
funcObj := tf.info.Defs[funcDecl.Name].(*types.Func)
funcType := funcObj.Type().(*types.Signature)
funcParams := funcType.Params()

potentialReflectParams := make(potentialReflectMap, funcParams.Len())
for i := 0; i < funcParams.Len(); i++ {
potentialReflectParams[funcParams.At(i)] = potentialReflectParam{
is: false,
}
}

tf.reflectAddAssign(funcDecl, potentialReflectParams)

ast.Inspect(funcDecl, func(node ast.Node) bool {
call, ok := node.(*ast.CallExpr)
if !ok {
return true
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return true
})
}
calledFunc, _ := tf.info.Uses[sel.Sel].(*types.Func)
if calledFunc == nil || calledFunc.Pkg() == nil {
return true
}

fullName := calledFunc.FullName()
for _, reflectParam := range cachedOutput.KnownReflectAPIs[fullName] {
// We need a range to handle any number of variadic arguments,
// which could be 0 or multiple.
// The non-variadic case is always one argument,
// but we still use the range to deduplicate code.
argStart := reflectParam.Position
argEnd := argStart + 1
if reflectParam.Variadic {
argEnd = len(call.Args)
}
for _, arg := range call.Args[argStart:argEnd] {
ident := getIdent(arg)

obj, _ := tf.info.Uses[ident].(*types.Var)
if obj == nil {
continue
}

flagParamReflected(obj, potentialReflectParams)
}
}

return true
})

var reflectParams []reflectParameter
for i := 0; i < funcParams.Len(); i++ {
if potentialReflectParams[funcParams.At(i)].is {
reflectParams = append(reflectParams, reflectParameter{
Position: i,
Variadic: funcType.Variadic() && i == funcParams.Len()-1,
})
}
}
if len(reflectParams) > 0 {
cachedOutput.KnownReflectAPIs[funcObj.FullName()] = reflectParams
}
}

func (tf *transformer) findReflectFunctions(files []*ast.File) {
lenPrevKnownReflectAPIs := len(cachedOutput.KnownReflectAPIs)
for _, file := range files {
tf.ignoreReflectedTypes(file)
for _, decl := range file.Decls {
if decl, ok := decl.(*ast.FuncDecl); ok {
visitFuncDecl(decl)
tf.visitReflectFuncDecl(decl)
}
}
}
Expand Down Expand Up @@ -1247,46 +1410,6 @@ func (tf *transformer) prefillObjectMaps(files []*ast.File) error {
}
tf.linkerVariableStrings[obj] = stringValue
})

visit := func(node ast.Node) bool {
call, ok := node.(*ast.CallExpr)
if !ok {
return true
}

ident, ok := call.Fun.(*ast.Ident)
if !ok {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return true
}

ident = sel.Sel
}

fnType, _ := tf.info.Uses[ident].(*types.Func)
if fnType == nil || fnType.Pkg() == nil {
return true
}

fullName := fnType.FullName()
for _, reflectParam := range cachedOutput.KnownReflectAPIs[fullName] {
argStart := reflectParam.Position
argEnd := argStart + 1
if reflectParam.Variadic {
argEnd = len(call.Args)
}
for _, arg := range call.Args[argStart:argEnd] {
argType := tf.info.TypeOf(arg)
tf.recursivelyRecordAsNotObfuscated(argType)
}
}

return true
}
for _, file := range files {
ast.Inspect(file, visit)
}
return nil
}

Expand Down Expand Up @@ -1432,10 +1555,16 @@ func recordedObjectString(obj types.Object) objectString {
return fmt.Sprintf("%s.%s - %s:%d", obj.Pkg().Path(), obj.Name(),
filepath.Base(pos.Filename), pos.Line)
}

pkg := obj.Pkg()
if pkg == nil {
return ""
}

// Names which are not at the top level cannot be imported,
// so we don't need to record them either.
// Note that this doesn't apply to fields, which are never top-level.
if obj.Pkg().Scope().Lookup(obj.Name()) != obj {
if pkg.Scope().Lookup(obj.Name()) != obj {
return ""
}
// For top-level exported names, "pkgpath.Name" is unique.
Expand Down
Loading

0 comments on commit 1f10d49

Please sign in to comment.