Skip to content

Commit

Permalink
fix: defaults for ref schemas (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Jun 4, 2023
1 parent 0f17b28 commit 8578d51
Showing 1 changed file with 14 additions and 31 deletions.
45 changes: 14 additions & 31 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -1261,42 +1261,38 @@ func validateDefault(name string, schema Schema, def any) (any, error) {
if !ok {
return nil, fmt.Errorf("avro: invalid default for field %s. %+v not a %s", name, def, schema.Type())
}

return def, nil
}

func isValidDefault(schema Schema, def any) (any, bool) {
switch schema.Type() {
case Ref:
ref := schema.(*RefSchema)
return isValidDefault(ref.Schema(), def)
case Null:
return nullDefault, def == nil

case Enum:
v, ok := def.(string)
if !ok || len(v) == 0 {
return def, false
}

enumSchema := schema.(*EnumSchema)
found := false
for i := range enumSchema.symbols {
if def == enumSchema.symbols[i] {
var found bool
for _, sym := range schema.(*EnumSchema).symbols {
if def == sym {
found = true
break
}
}

return def, found

case String, Bytes, Fixed:
if _, ok := def.(string); ok {
return def, true
}

case Boolean:
if _, ok := def.(bool); ok {
return def, true
}

case Int:
if i, ok := def.(int8); ok {
return int(i), true
Expand All @@ -1313,75 +1309,65 @@ func isValidDefault(schema Schema, def any) (any, bool) {
if f, ok := def.(float64); ok {
return int(f), true
}

case Long:
if _, ok := def.(int64); ok {
return def, true
}
if f, ok := def.(float64); ok {
return int64(f), true
}

case Float:
if _, ok := def.(float32); ok {
return def, true
}
if f, ok := def.(float64); ok {
return float32(f), true
}

case Double:
if _, ok := def.(float64); ok {
return def, true
}

case Array:
arr, ok := def.([]any)
if !ok {
return nil, false
}

arrSchema := schema.(*ArraySchema)
as := schema.(*ArraySchema)
for i, v := range arr {
v, ok := isValidDefault(arrSchema.Items(), v)
v, ok := isValidDefault(as.Items(), v)
if !ok {
return nil, false
}
arr[i] = v
}

return arr, true

case Map:
m, ok := def.(map[string]any)
if !ok {
return nil, false
}

mapSchema := schema.(*MapSchema)
ms := schema.(*MapSchema)
for k, v := range m {
v, ok := isValidDefault(mapSchema.Values(), v)
v, ok := isValidDefault(ms.Values(), v)
if !ok {
return nil, false
}

m[k] = v
}

return m, true

case Union:
unionSchema := schema.(*UnionSchema)
return isValidDefault(unionSchema.Types()[0], def)

case Record:
m, ok := def.(map[string]any)
if !ok {
return nil, false
}

recordSchema := schema.(*RecordSchema)
for _, field := range recordSchema.Fields() {
for _, field := range schema.(*RecordSchema).Fields() {
fieldDef := field.Default()
if newDef, ok := m[field.Name()]; ok {
fieldDef = newDef
Expand All @@ -1394,10 +1380,8 @@ func isValidDefault(schema Schema, def any) (any, bool) {

m[field.Name()] = v
}

return m, true
}

return nil, false
}

Expand All @@ -1410,10 +1394,9 @@ func schemaTypeName(schema Schema) string {
return n.FullName()
}

name := string(schema.Type())
sname := string(schema.Type())
if lt := getLogicalType(schema); lt != "" {
name += "." + string(lt)
sname += "." + string(lt)
}

return name
return sname
}

0 comments on commit 8578d51

Please sign in to comment.