diff --git a/pkg/config/config.go b/pkg/config/config.go index 36c1d6c3..1546b0fc 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -24,6 +24,8 @@ type Config struct { Quiet bool Recursive bool SrcPkg string + // StructName overrides the name given to the mock struct and should only be nonempty + // when generating for an exact match (non regex expression in -name). StructName string Tags string TestOnly bool diff --git a/pkg/generator.go b/pkg/generator.go index df576d2f..cd48a621 100644 --- a/pkg/generator.go +++ b/pkg/generator.go @@ -40,27 +40,22 @@ func stripChars(str, chr string) string { // Generator is responsible for generating the string containing // imports and the mock struct that will later be written out as file. type Generator struct { + config.Config buf bytes.Buffer - ip bool iface *Interface pkg string localPackageName *string - importsWerePopulated bool - localizationCache map[string]string - packagePathToName map[string]string - nameToPackagePath map[string]string + localizationCache map[string]string + packagePathToName map[string]string + nameToPackagePath map[string]string packageRoots []string - - // structName overrides the name given to the mock struct and should only be nonempty - // when generating for an exact match (non regex expression in -name). - structName string } // NewGenerator builds a Generator. -func NewGenerator(ctx context.Context, iface *Interface, pkg string, inPackage bool, structName string) *Generator { +func NewGenerator(ctx context.Context, c config.Config, iface *Interface, pkg string) *Generator { var roots []string @@ -69,14 +64,13 @@ func NewGenerator(ctx context.Context, iface *Interface, pkg string, inPackage b } g := &Generator{ + Config: c, iface: iface, pkg: pkg, - ip: inPackage, localizationCache: make(map[string]string), packagePathToName: make(map[string]string), nameToPackagePath: make(map[string]string), packageRoots: roots, - structName: structName, } g.addPackageImportWithName(ctx, "github.com/stretchr/testify/mock", "mock") @@ -84,9 +78,10 @@ func NewGenerator(ctx context.Context, iface *Interface, pkg string, inPackage b } func (g *Generator) populateImports(ctx context.Context) { - if g.importsWerePopulated { - return - } + log := zerolog.Ctx(ctx) + + log.Debug().Msgf("populating imports") + for i := 0; i < g.iface.Type.NumMethods(); i++ { fn := g.iface.Type.Method(i) ftype := fn.Type().(*types.Signature) @@ -218,11 +213,11 @@ func (g *Generator) getLocalizedPath(ctx context.Context, path string) string { } func (g *Generator) mockName() string { - if g.structName != "" { - return g.structName + if g.StructName != "" { + return g.StructName } - if g.ip { + if g.InPackage { if ast.IsExported(g.iface.Name) { return "Mock" + g.iface.Name } @@ -258,12 +253,21 @@ func (g *Generator) sortedImportNames() (importNames []string) { return } -func (g *Generator) generateImports() { +func (g *Generator) generateImports(ctx context.Context) { + log := zerolog.Ctx(ctx) + + log.Debug().Msgf("generating imports") + log.Debug().Msgf("%v", g.nameToPackagePath) + pkgPath := g.nameToPackagePath[g.iface.Pkg.Name()] // Sort by import name so that we get a deterministic order for _, name := range g.sortedImportNames() { + logImport := log.With().Str(logging.LogKeyImport, g.nameToPackagePath[name]).Logger() + logImport.Debug().Msgf("found import") + path := g.nameToPackagePath[name] - if g.ip && path == pkgPath { + if g.InPackage && path == pkgPath { + logImport.Debug().Msgf("import (%s) equals interface's package path (%s), skipping", path, pkgPath) continue } g.printf("import %s \"%s\"\n", name, path) @@ -273,13 +277,13 @@ func (g *Generator) generateImports() { // GeneratePrologue generates the prologue of the mock. func (g *Generator) GeneratePrologue(ctx context.Context, pkg string) { g.populateImports(ctx) - if g.ip { + if g.InPackage { g.printf("package %s\n\n", g.iface.Pkg.Name()) } else { g.printf("package %v\n\n", pkg) } - g.generateImports() + g.generateImports(ctx) g.printf("\n") } @@ -340,7 +344,7 @@ func (g *Generator) renderType(ctx context.Context, typ types.Type) string { switch t := typ.(type) { case *types.Named: o := t.Obj() - if o.Pkg() == nil || o.Pkg().Name() == "main" || (g.ip && o.Pkg() == g.iface.Pkg) { + if o.Pkg() == nil || o.Pkg().Name() == "main" || (g.InPackage && o.Pkg() == g.iface.Pkg) { return o.Name() } return g.addPackageImport(ctx, o.Pkg()) + "." + o.Name() diff --git a/pkg/generator_test.go b/pkg/generator_test.go index 71d011b1..884764a0 100644 --- a/pkg/generator_test.go +++ b/pkg/generator_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/vektra/mockery/pkg/config" ) const pkg = "test" @@ -48,7 +49,10 @@ func (s *GeneratorSuite) getGenerator( filepath, interfaceName string, inPackage bool, structName string, ) *Generator { return NewGenerator( - s.ctx, s.getInterfaceFromFile(filepath, interfaceName), pkg, inPackage, structName, + s.ctx, config.Config{ + StructName: structName, + InPackage: inPackage, + }, s.getInterfaceFromFile(filepath, interfaceName), pkg, ) } @@ -74,7 +78,7 @@ func (s *GeneratorSuite) checkGeneration( s.Equal( expectedLines, actualLines, - "The generator produced the expected output.", + "The generator produced unexpected output.", ) return generator } @@ -85,7 +89,7 @@ func (s *GeneratorSuite) checkPrologueGeneration( generator.GeneratePrologue(ctx, "mocks") s.Equal( expected, generator.buf.String(), - "The generator produced the expected prologue.", + "The generator produced an unexpected prologue.", ) } @@ -954,7 +958,9 @@ func (_m *Example) B(_a0 string) fixtureshttp.MyStruct { func (s *GeneratorSuite) TestGeneratorWithImportSameAsLocalPackageInpkgNoCycle() { iface := s.getInterfaceFromFile("imports_same_as_package.go", "ImportsSameAsPackage") pkg := iface.QualifiedName - gen := NewGenerator(s.ctx, iface, pkg, true, "") + gen := NewGenerator(s.ctx, config.Config{ + InPackage: true, + }, iface, pkg) gen.GeneratePrologue(s.ctx, pkg) s.NotContains(gen.buf.String(), `import test "github.com/vektra/mockery/pkg/fixtures/test"`) } diff --git a/pkg/logging/logging.go b/pkg/logging/logging.go index 52f368eb..6fc1f97f 100644 --- a/pkg/logging/logging.go +++ b/pkg/logging/logging.go @@ -6,6 +6,7 @@ const ( LogKeyDryRun = "dry-run" LogKeyFile = "file" LogKeyInterface = "interface" + LogKeyImport = "import" LogKeyPath = "path" LogKeyQualifiedName = "qualified-name" ) diff --git a/pkg/walker.go b/pkg/walker.go index 75b4bb05..1b78a815 100644 --- a/pkg/walker.go +++ b/pkg/walker.go @@ -142,7 +142,7 @@ func (this *GeneratorVisitor) VisitWalk(ctx context.Context, iface *Interface) e } defer closer() - gen := NewGenerator(ctx, iface, pkg, this.InPackage, this.StructName) + gen := NewGenerator(ctx, this.Config, iface, pkg) gen.GeneratePrologueNote(this.Note) gen.GeneratePrologue(ctx, pkg)