From 0f41de64f86b88924a6b3fbf78f0f2afd8abea70 Mon Sep 17 00:00:00 2001 From: "Giau. Tran Minh" Date: Fri, 9 Jul 2021 01:26:11 +0700 Subject: [PATCH] fix: change cwd to load gqlgen config correctly (#103) Code copied from https://github.com/99designs/gqlgen/pull/1511 --- entgql/extension.go | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/entgql/extension.go b/entgql/extension.go index 4f542cf50..370021eb2 100644 --- a/entgql/extension.go +++ b/entgql/extension.go @@ -17,6 +17,7 @@ package entgql import ( "fmt" "io/ioutil" + "os" "path/filepath" "strings" @@ -31,7 +32,6 @@ import ( "github.com/graphql-go/graphql/language/printer" "github.com/graphql-go/graphql/language/source" "github.com/graphql-go/graphql/language/visitor" - gqlparserast "github.com/vektah/gqlparser/v2/ast" ) type ( @@ -84,18 +84,28 @@ func WithSchemaPath(path string) ExtensionOption { // Note that, enabling this option is recommended as it improves the // GraphQL integration, func WithConfigPath(path string) ExtensionOption { - return func(ex *Extension) error { - cfg, err := config.LoadConfig(path) + return func(ex *Extension) (err error) { + cwd, err := os.Getwd() + if err != nil { + return fmt.Errorf("unable to get working directory: %w", err) + } + if err := os.Chdir(filepath.Dir(path)); err != nil { + return fmt.Errorf("unable to enter config dir: %w", err) + } + defer func() { + if cerr := os.Chdir(cwd); cerr != nil { + err = fmt.Errorf("unable to restore working directory: %w", cerr) + } + }() + cfg, err := config.LoadConfig(filepath.Base(path)) if err != nil { return err } - if cfg.Schema == nil { if err := cfg.LoadSchema(); err != nil { return err } } - ex.cfg = cfg return nil } @@ -271,11 +281,8 @@ func (e *Extension) hasMapping(f *gen.Field) (string, bool) { // isInput reports if the given type is an input object. func (e *Extension) isInput(name string) bool { if t, ok := e.cfg.Schema.Types[name]; ok && t != nil { - return t.Kind == gqlparserast.Enum || - t.Kind == gqlparserast.Scalar || - t.Kind == gqlparserast.InputObject + return t.IsInputType() } - return false }