Skip to content

Commit

Permalink
Add provider exec generation
Browse files Browse the repository at this point in the history
  • Loading branch information
yogeshlonkar committed Jun 6, 2023
1 parent 5fd115a commit 46016fc
Show file tree
Hide file tree
Showing 16 changed files with 193 additions and 73 deletions.
7 changes: 7 additions & 0 deletions extension/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,21 @@ import (
//
// //go:gocode mockery --name Provider --output ../../mocks.
type Provider interface {
AddDatasource(string, string)
AddResource(string, string)
Datasources() map[string]string
Description() string
ExecFilename() string
ExecGoName() string
Filename() string
GoName() string
HasServiceClient() bool
Model() Model
ModelGoName() string
Option() *pb.Provider
PackageData() PackageData
Resources() map[string]string
SetHasServiceClient(bool)
TerraformName() string
}

Expand Down
12 changes: 6 additions & 6 deletions extensionimpl/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,18 @@ func (b *block) Description() string {
return b.option.Description
}

func (b *block) Filename() string {
return toSnakeCase(b.GoName()) + ".pb.go"
func (b *block) ExecGoName() string {
return fmt.Sprintf("%sExec", b.GoName())
}

func (b *block) ExecFilename() string {
return toSnakeCase(b.GoName()) + "_exec.go"
}

func (b *block) Filename() string {
return toSnakeCase(b.GoName()) + ".pb.go"
}

func (b *block) GoName() string {
if b._type == pb.E_Resource {
return fmt.Sprintf("%sResource", *b.option.Name)
Expand Down Expand Up @@ -82,10 +86,6 @@ func (b *block) Option() *pb.Block {
return b.option
}

func (b *block) ExecGoName() string {
return fmt.Sprintf("%sExec", b.GoName())
}

func (b *block) TerraformName() string {
return toSnakeCase(*b.option.Name)
}
Expand Down
9 changes: 2 additions & 7 deletions extensionimpl/misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"

"github.com/travix/protoc-gen-gotf/pb"
)
Expand Down Expand Up @@ -64,18 +63,14 @@ func deferToComment(direct *string, comments protogen.CommentSet) *string {
return &str
}

func getPkgName(options protoreflect.ProtoMessage) string {
fileOpt, _ := options.(*descriptorpb.FileOptions)
goPkg := fileOpt.GetGoPackage()
func getPkgName(goPkg string) string {
if i := strings.Index(goPkg, ";"); i >= 0 {
return goPkg[i+1:]
}
return GoSanitized(path.Base(goPkg))
}

func getImportPath(options protoreflect.ProtoMessage) string {
fileOpt, _ := options.(*descriptorpb.FileOptions)
goPkg := fileOpt.GetGoPackage()
func getImportPath(goPkg string) string {
if i := strings.Index(goPkg, ";"); i >= 0 {
return goPkg[:i]
}
Expand Down
4 changes: 3 additions & 1 deletion extensionimpl/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"unicode/utf8"

"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/types/descriptorpb"

"github.com/travix/protoc-gen-gotf/extension"
"github.com/travix/protoc-gen-gotf/pb"
Expand All @@ -23,7 +24,8 @@ type model struct {

func NewModel(synth synthesizer, msg *protogen.Message, explicit bool) (extension.Model, error) {
m := &model{message: msg}
m.pkgName = getPkgName(msg.Desc.ParentFile().Options())
fileOpt, _ := msg.Desc.ParentFile().Options().(*descriptorpb.FileOptions)
m.pkgName = getPkgName(fileOpt.GetGoPackage())
for _, field := range msg.Fields {
attr, err := synth.Attribute(field, explicit)
if err != nil {
Expand Down
53 changes: 42 additions & 11 deletions extensionimpl/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@ import (
var _ extension.Provider = &provider{}

type provider struct {
model extension.Model
module string
option *pb.Provider
packageData extension.PackageData
datasources map[string]string
hasServiceClient bool
model extension.Model
module string
option *pb.Provider
packageData extension.PackageData
resources map[string]string
}

func NewProvider(synth extension.Synthesizer, msg *protogen.Message) (extension.Provider, error) {
option := synth.ProviderOption(msg.Desc)
if option == nil {
return nil, nil
}
p := &provider{option: option}
p := &provider{option: option, datasources: map[string]string{}, resources: map[string]string{}}
if p.option.Name == "" {
p.option.Name = msg.GoIdent.GoName
}
Expand All @@ -39,14 +42,14 @@ func NewProvider(synth extension.Synthesizer, msg *protogen.Message) (extension.
if p.option.ProviderPackage == "" {
p.option.ProviderPackage = filepath.Join(p.module, "providerpb")
}
p.packageData.ProviderPackageName = protogen.GoPackageName(filepath.Base(p.option.ProviderPackage))
p.packageData.ProviderImportPath = protogen.GoImportPath(p.option.ProviderPackage)
p.packageData.ProviderPackageName = protogen.GoPackageName(getPkgName(p.option.ProviderPackage))
p.packageData.ProviderImportPath = protogen.GoImportPath(getImportPath(p.option.ProviderPackage))
if !strings.HasPrefix(string(p.packageData.ProviderImportPath), p.module) {
p.packageData.ProviderImportPath = protogen.GoImportPath(filepath.Join(p.module, string(p.packageData.ProviderImportPath)))
}
if p.option.ExecPackage != nil {
p.packageData.ExecImportPath = protogen.GoImportPath(*p.option.ExecPackage)
p.packageData.ExecPackageName = protogen.GoPackageName(filepath.Base(*p.option.ExecPackage))
p.packageData.ExecPackageName = protogen.GoPackageName(getPkgName(*p.option.ExecPackage))
p.packageData.ExecImportPath = protogen.GoImportPath(getImportPath(*p.option.ExecPackage))
}
var err error
p.model, err = synth.Model(msg, false)
Expand All @@ -56,14 +59,38 @@ func NewProvider(synth extension.Synthesizer, msg *protogen.Message) (extension.
return p, nil
}

func (p *provider) AddDatasource(name string, exec string) {
p.datasources[name] = exec
}

func (p *provider) AddResource(name string, exec string) {
p.resources[name] = exec
}

func (p *provider) Datasources() map[string]string {
return p.datasources
}

func (p *provider) Description() string {
return p.option.Description
}

func (p *provider) ExecFilename() string {
return "provider_exec.go"
}

func (p *provider) ExecGoName() string {
return fmt.Sprintf("%sExec", p.GoName())
}

func (p *provider) Filename() string {
return "provider.go"
}

func (p *provider) HasServiceClient() bool {
return p.hasServiceClient
}

func (p *provider) GoName() string {
return toCamelCase(p.option.Name)
}
Expand All @@ -88,8 +115,12 @@ func (p *provider) PackageData() extension.PackageData {
return p.packageData
}

func (p *provider) ExecGoName() string {
return fmt.Sprintf("%sExec", p.GoName())
func (p *provider) Resources() map[string]string {
return p.resources
}

func (p *provider) SetHasServiceClient(has bool) {
p.hasServiceClient = has
}

func (p *provider) TerraformName() string {
Expand Down
7 changes: 5 additions & 2 deletions extensionimpl/synthesizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/descriptorpb"

"github.com/travix/protoc-gen-gotf/extension"
"github.com/travix/protoc-gen-gotf/pb"
Expand All @@ -16,11 +17,13 @@ type synthesizer struct {
}

func (s synthesizer) MessagePackageName(msg *protogen.Message) protogen.GoPackageName {
return protogen.GoPackageName(getPkgName(msg.Desc.ParentFile().Options()))
fileOpt, _ := msg.Desc.ParentFile().Options().(*descriptorpb.FileOptions)
return protogen.GoPackageName(getPkgName(fileOpt.GetGoPackage()))
}

func (s synthesizer) MessageImportPath(msg *protogen.Message) protogen.GoImportPath {
return protogen.GoImportPath(getImportPath(msg.Desc.ParentFile().Options()))
fileOpt, _ := msg.Desc.ParentFile().Options().(*descriptorpb.FileOptions)
return protogen.GoImportPath(getImportPath(fileOpt.GetGoPackage()))
}

func (s synthesizer) Module() string {
Expand Down
18 changes: 15 additions & 3 deletions gocode/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,27 @@ func (w *writer) execData(block extension.Block, importsArg []_import) map[strin
entry{"Imports", w.importStrings(imports)})
}

func (w *writer) providerData(provider extension.Provider, hasServiceClient bool, importsArg []_import) map[string]any {
func (w *writer) providerExecData(provider extension.Provider, importsArg []_import) map[string]any {
imports := make([]_import, len(importsArg))
copy(imports, importsArg)
// nolint:makezero // https://github.com/ashanbrown/makezero/issues/12
imports = append(imports,
_import{path: string(w.PbImportPath), string: string(w.PbPackageName)},
_import{path: string(w.ProviderImportPath), string: string(w.ProviderPackageName)},
)
return w.data(
entry{"Provider", provider},
entry{"Imports", w.importStrings(imports)})
}

func (w *writer) providerData(provider extension.Provider, importsArg []_import) map[string]any {
imports := make([]_import, len(importsArg))
copy(imports, importsArg)
// nolint:makezero // https://github.com/ashanbrown/makezero/issues/12
imports = append(imports, _import{path: string(w.PbImportPath), string: string(w.PbPackageName)})
return w.data(
entry{"Provider", provider},
entry{"Imports", w.importStrings(imports)},
entry{"HasServiceClient", hasServiceClient})
entry{"Imports", w.importStrings(imports)})
}

func (w *writer) dependencyData(models []extension.Model, defaultImports []_import) map[string]any {
Expand Down
2 changes: 2 additions & 0 deletions gocode/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/daixiang0/gci/pkg/gci"
"github.com/daixiang0/gci/pkg/log"
"github.com/daixiang0/gci/pkg/section"
zlog "github.com/rs/zerolog/log"
)

type srcFile struct {
Expand All @@ -25,6 +26,7 @@ func (f srcFile) Path() string {
func (w *writer) Format(src []byte, path string) ([]byte, error) {
dst, err := format.Source(src)
if err != nil {
zlog.Trace().Msgf("source:\n%s", src)
return nil, fmt.Errorf("failed to gofmt file %s: %w", path, err)
}
log.InitLogger()
Expand Down
4 changes: 2 additions & 2 deletions gocode/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ var defaultProviderImports = []_import{
{path: "github.com/travix/gotf/prvdr"},
}

func (w *writer) WriteProvider(filename string, file *protogen.GeneratedFile, provider extension.Provider, hasServiceClient bool) error {
data := w.providerData(provider, hasServiceClient, defaultProviderImports)
func (w *writer) WriteProvider(filename string, file *protogen.GeneratedFile, provider extension.Provider) error {
data := w.providerData(provider, defaultProviderImports)
code := &bytes.Buffer{}
if err := w.templates.ExecuteTemplate(code, providerTemplate, data); err != nil {
return fmt.Errorf("failed to execute %s template: %w", providerTemplate, err)
Expand Down
27 changes: 27 additions & 0 deletions gocode/provider_exec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package gocode

import (
"bytes"
"fmt"

"google.golang.org/protobuf/compiler/protogen"

"github.com/travix/protoc-gen-gotf/extension"
)

var defaultProviderExecImports = []_import{
{path: "context"},
{path: "github.com/hashicorp/terraform-plugin-framework/datasource"},
{path: "github.com/hashicorp/terraform-plugin-framework/resource"},
{path: "github.com/hashicorp/terraform-plugin-framework/diag"},
{path: "google.golang.org/grpc"},
}

func (w *writer) WriteProviderExec(filename string, file *protogen.GeneratedFile, provider extension.Provider) error {
data := w.providerExecData(provider, defaultProviderExecImports)
code := &bytes.Buffer{}
if err := w.templates.ExecuteTemplate(code, providerExecTemplate, data); err != nil {
return fmt.Errorf("failed to execute %s template: %w", providerExecTemplate, err)
}
return w.formatAndWrite(filename, file, code.Bytes())
}
19 changes: 4 additions & 15 deletions gocode/tmpls/data_source_exec.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@ import (
var _ {{ .ProviderPackageName }}.{{ .Block.ExecGoName }} = &{{ .Block.ExecGoName }}{}

type {{ .Block.ExecGoName }} struct {
{{ $numClients := len .Block.Clients }}{{- if gt $numClients 1 }}
{{- range $index, $client := .Block.Clients }}
client{{ $index }} {{ $.PbPackageName }}.{{ $client }}
{{- end }}
{{- else }}
client {{ $.PbPackageName }}.{{ index .Block.Clients 0 }}
{{ ClientVarName $client }} {{ $.PbPackageName }}.{{ $client }}
{{- end }}
}

Expand All @@ -26,15 +22,8 @@ func (e *{{ .Block.ExecGoName }}) Read(ctx context.Context, req datasource.ReadR
panic("implement me")
}

{{ if gt $numClients 1 }}{{- range $index, $client := .Block.Clients }}
func (e *{{ .Block.ExecGoName }}) Set{{ $client }}(client {{ $.PbPackageName }}.{{ $client }}) {
e.client{{ $index }} = client
{{ range $index, $client := .Block.Clients }}
func (e *{{ $.Block.ExecGoName }}) Set{{ $client }}(client {{ $.PbPackageName }}.{{ $client }}) {
e.{{ ClientVarName $client }} = client
}

{{- end }}
{{- else }}
func (e *{{ .Block.ExecGoName }}) Set{{ index .Block.Clients 0 }}(client {{ $.PbPackageName }}.{{ index .Block.Clients 0 }}) {
e.client = client
}

{{- end }}
6 changes: 3 additions & 3 deletions gocode/tmpls/provider.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ var _ provider.Provider = &{{ .Provider.GoName }}Provider{}

type {{ .Provider.GoName }}Exec interface {
prvdr.Provider
{{- if .HasServiceClient }}
{{- if .Provider.HasServiceClient }}
prvdr.CanConfigureGrpc[*{{ .Provider.ModelGoName }}]
{{- end }}
}
Expand Down Expand Up @@ -48,7 +48,7 @@ func (p *{{ .Provider.GoName }}Provider) Schema(ctx context.Context, req provide
func (p *{{ .Provider.GoName }}Provider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
if _exec, ok := p.exec.(prvdr.CanConfigure); ok {
_exec.Configure(ctx, req, resp)
{{- if .HasServiceClient }}
{{- if .Provider.HasServiceClient }}
if resp.DataSourceData != nil {
resp.Diagnostics.AddWarning("resp.DataSourceData not set", "DataSourceData should be set to grpc.ClientConnInterface by Configure method found nil")
}
Expand All @@ -58,7 +58,7 @@ func (p *{{ .Provider.GoName }}Provider) Configure(ctx context.Context, req prov
{{- end }}
return
}
{{- if .HasServiceClient }}
{{- if .Provider.HasServiceClient }}
data, diagnostics := gotf.GetModel[{{ .Provider.ModelGoName }}](ctx, req.Config.Raw, req.Config.Get)
if diagnostics.HasError() {
resp.Diagnostics.Append(diagnostics...)
Expand Down
Loading

0 comments on commit 46016fc

Please sign in to comment.