From 8578d510c1966a23f0729f3075755dec0bdc6027 Mon Sep 17 00:00:00 2001 From: Nicholas Wiersma Date: Sun, 4 Jun 2023 14:39:50 +0200 Subject: [PATCH] fix: defaults for ref schemas (#262) --- schema.go | 45 ++++++++++++++------------------------------- 1 file changed, 14 insertions(+), 31 deletions(-) diff --git a/schema.go b/schema.go index 307dacb1..9e2168d9 100644 --- a/schema.go +++ b/schema.go @@ -1261,42 +1261,38 @@ func validateDefault(name string, schema Schema, def any) (any, error) { if !ok { return nil, fmt.Errorf("avro: invalid default for field %s. %+v not a %s", name, def, schema.Type()) } - return def, nil } func isValidDefault(schema Schema, def any) (any, bool) { switch schema.Type() { + case Ref: + ref := schema.(*RefSchema) + return isValidDefault(ref.Schema(), def) case Null: return nullDefault, def == nil - case Enum: v, ok := def.(string) if !ok || len(v) == 0 { return def, false } - enumSchema := schema.(*EnumSchema) - found := false - for i := range enumSchema.symbols { - if def == enumSchema.symbols[i] { + var found bool + for _, sym := range schema.(*EnumSchema).symbols { + if def == sym { found = true break } } - return def, found - case String, Bytes, Fixed: if _, ok := def.(string); ok { return def, true } - case Boolean: if _, ok := def.(bool); ok { return def, true } - case Int: if i, ok := def.(int8); ok { return int(i), true @@ -1313,7 +1309,6 @@ func isValidDefault(schema Schema, def any) (any, bool) { if f, ok := def.(float64); ok { return int(f), true } - case Long: if _, ok := def.(int64); ok { return def, true @@ -1321,7 +1316,6 @@ func isValidDefault(schema Schema, def any) (any, bool) { if f, ok := def.(float64); ok { return int64(f), true } - case Float: if _, ok := def.(float32); ok { return def, true @@ -1329,59 +1323,51 @@ func isValidDefault(schema Schema, def any) (any, bool) { if f, ok := def.(float64); ok { return float32(f), true } - case Double: if _, ok := def.(float64); ok { return def, true } - case Array: arr, ok := def.([]any) if !ok { return nil, false } - arrSchema := schema.(*ArraySchema) + as := schema.(*ArraySchema) for i, v := range arr { - v, ok := isValidDefault(arrSchema.Items(), v) + v, ok := isValidDefault(as.Items(), v) if !ok { return nil, false } arr[i] = v } - return arr, true - case Map: m, ok := def.(map[string]any) if !ok { return nil, false } - mapSchema := schema.(*MapSchema) + ms := schema.(*MapSchema) for k, v := range m { - v, ok := isValidDefault(mapSchema.Values(), v) + v, ok := isValidDefault(ms.Values(), v) if !ok { return nil, false } m[k] = v } - return m, true - case Union: unionSchema := schema.(*UnionSchema) return isValidDefault(unionSchema.Types()[0], def) - case Record: m, ok := def.(map[string]any) if !ok { return nil, false } - recordSchema := schema.(*RecordSchema) - for _, field := range recordSchema.Fields() { + for _, field := range schema.(*RecordSchema).Fields() { fieldDef := field.Default() if newDef, ok := m[field.Name()]; ok { fieldDef = newDef @@ -1394,10 +1380,8 @@ func isValidDefault(schema Schema, def any) (any, bool) { m[field.Name()] = v } - return m, true } - return nil, false } @@ -1410,10 +1394,9 @@ func schemaTypeName(schema Schema) string { return n.FullName() } - name := string(schema.Type()) + sname := string(schema.Type()) if lt := getLogicalType(schema); lt != "" { - name += "." + string(lt) + sname += "." + string(lt) } - - return name + return sname }