diff --git a/definition.go b/definition.go index 0fbbbd52..e37fa077 100644 --- a/definition.go +++ b/definition.go @@ -796,15 +796,19 @@ type Union struct { PrivateDescription string `json:"description"` ResolveType ResolveTypeFn - typeConfig UnionConfig - types []*Object - possibleTypes map[string]bool + typeConfig UnionConfig + initalizedTypes bool + types []*Object + possibleTypes map[string]bool err error } + +type UnionTypesThunk func() []*Object + type UnionConfig struct { - Name string `json:"name"` - Types []*Object `json:"types"` + Name string `json:"name"` + Types interface{} `json:"types"` ResolveType ResolveTypeFn Description string `json:"description"` } @@ -822,48 +826,80 @@ func NewUnion(config UnionConfig) *Union { objectType.PrivateDescription = config.Description objectType.ResolveType = config.ResolveType - if objectType.err = invariantf( - len(config.Types) > 0, - `Must provide Array of types for Union %v.`, config.Name, - ); objectType.err != nil { - return objectType + objectType.typeConfig = config + + return objectType +} + +func (ut *Union) Types() []*Object { + if ut.initalizedTypes { + return ut.types + } + + var unionTypes []*Object + switch utype := ut.typeConfig.Types.(type) { + case UnionTypesThunk: + unionTypes = utype() + case []*Object: + unionTypes = utype + case nil: + default: + ut.err = fmt.Errorf("Unknown Union.Types type: %T", ut.typeConfig.Types) + ut.initalizedTypes = true + return nil + } + + ut.types, ut.err = defineUnionTypes(ut, unionTypes) + ut.initalizedTypes = true + return ut.types +} + +func defineUnionTypes(objectType *Union, unionTypes []*Object) ([]*Object, error) { + definedUnionTypes := []*Object{} + + if err := invariantf( + len(unionTypes) > 0, + `Must provide Array of types for Union %v.`, objectType.Name(), + ); err != nil { + return definedUnionTypes, err } - for _, ttype := range config.Types { - if objectType.err = invariantf( + + for _, ttype := range unionTypes { + if err := invariantf( ttype != nil, `%v may only contain Object types, it cannot contain: %v.`, objectType, ttype, - ); objectType.err != nil { - return objectType + ); err != nil { + return definedUnionTypes, err } if objectType.ResolveType == nil { - if objectType.err = invariantf( + if err := invariantf( ttype.IsTypeOf != nil, `Union Type %v does not provide a "resolveType" function `+ `and possible Type %v does not provide a "isTypeOf" `+ `function. There is no way to resolve this possible type `+ `during execution.`, objectType, ttype, - ); objectType.err != nil { - return objectType + ); err != nil { + return definedUnionTypes, err } } + definedUnionTypes = append(definedUnionTypes, ttype) } - objectType.types = config.Types - objectType.typeConfig = config - return objectType -} -func (ut *Union) Types() []*Object { - return ut.types + return definedUnionTypes, nil } + func (ut *Union) String() string { return ut.PrivateName } + func (ut *Union) Name() string { return ut.PrivateName } + func (ut *Union) Description() string { return ut.PrivateDescription } + func (ut *Union) Error() error { return ut.err } diff --git a/definition_test.go b/definition_test.go index 12824f21..9141917c 100644 --- a/definition_test.go +++ b/definition_test.go @@ -519,6 +519,7 @@ func TestTypeSystem_DefinitionExample_ProhibitsNilTypeInUnions(t *testing.T) { Name: "BadUnion", Types: []*graphql.Object{nil}, }) + ttype.Types() expected := `BadUnion may only contain Object types, it cannot contain: .` if ttype.Error().Error() != expected { t.Fatalf(`expected %v , got: %v`, expected, ttype.Error()) @@ -666,3 +667,42 @@ func TestTypeSystem_DefinitionExample_CanAddInputObjectField(t *testing.T) { t.Fatal("Unexpected result, inputObject should have a field named 'newValue'") } } + +func TestTypeSystem_DefinitionExample_IncludesUnionTypesThunk(t *testing.T) { + someObject := graphql.NewObject(graphql.ObjectConfig{ + Name: "SomeObject", + Fields: graphql.Fields{ + "f": &graphql.Field{ + Type: graphql.Int, + }, + }, + }) + + someOtherObject := graphql.NewObject(graphql.ObjectConfig{ + Name: "SomeOtherObject", + Fields: graphql.Fields{ + "g": &graphql.Field{ + Type: graphql.Int, + }, + }, + }) + + someUnion := graphql.NewUnion(graphql.UnionConfig{ + Name: "SomeUnion", + Types: (graphql.UnionTypesThunk)(func() []*graphql.Object { + return []*graphql.Object{someObject, someOtherObject} + }), + ResolveType: func(p graphql.ResolveTypeParams) *graphql.Object { + return nil + }, + }) + + unionTypes := someUnion.Types() + + if someUnion.Error() != nil { + t.Fatalf("unexpected error, got: %v", someUnion.Error().Error()) + } + if len(unionTypes) != 2 { + t.Fatalf("Unexpected result, someUnion should have two unionTypes, has %d", len(unionTypes)) + } +}