Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: defaults for ref schemas #262

Merged
merged 1 commit into from
Jun 4, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 14 additions & 31 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -1261,42 +1261,38 @@ func validateDefault(name string, schema Schema, def any) (any, error) {
if !ok {
return nil, fmt.Errorf("avro: invalid default for field %s. %+v not a %s", name, def, schema.Type())
}

return def, nil
}

func isValidDefault(schema Schema, def any) (any, bool) {
switch schema.Type() {
case Ref:
ref := schema.(*RefSchema)
return isValidDefault(ref.Schema(), def)
case Null:
return nullDefault, def == nil

case Enum:
v, ok := def.(string)
if !ok || len(v) == 0 {
return def, false
}

enumSchema := schema.(*EnumSchema)
found := false
for i := range enumSchema.symbols {
if def == enumSchema.symbols[i] {
var found bool
for _, sym := range schema.(*EnumSchema).symbols {
if def == sym {
found = true
break
}
}

return def, found

case String, Bytes, Fixed:
if _, ok := def.(string); ok {
return def, true
}

case Boolean:
if _, ok := def.(bool); ok {
return def, true
}

case Int:
if i, ok := def.(int8); ok {
return int(i), true
Expand All @@ -1313,75 +1309,65 @@ func isValidDefault(schema Schema, def any) (any, bool) {
if f, ok := def.(float64); ok {
return int(f), true
}

case Long:
if _, ok := def.(int64); ok {
return def, true
}
if f, ok := def.(float64); ok {
return int64(f), true
}

case Float:
if _, ok := def.(float32); ok {
return def, true
}
if f, ok := def.(float64); ok {
return float32(f), true
}

case Double:
if _, ok := def.(float64); ok {
return def, true
}

case Array:
arr, ok := def.([]any)
if !ok {
return nil, false
}

arrSchema := schema.(*ArraySchema)
as := schema.(*ArraySchema)
for i, v := range arr {
v, ok := isValidDefault(arrSchema.Items(), v)
v, ok := isValidDefault(as.Items(), v)
if !ok {
return nil, false
}
arr[i] = v
}

return arr, true

case Map:
m, ok := def.(map[string]any)
if !ok {
return nil, false
}

mapSchema := schema.(*MapSchema)
ms := schema.(*MapSchema)
for k, v := range m {
v, ok := isValidDefault(mapSchema.Values(), v)
v, ok := isValidDefault(ms.Values(), v)
if !ok {
return nil, false
}

m[k] = v
}

return m, true

case Union:
unionSchema := schema.(*UnionSchema)
return isValidDefault(unionSchema.Types()[0], def)

case Record:
m, ok := def.(map[string]any)
if !ok {
return nil, false
}

recordSchema := schema.(*RecordSchema)
for _, field := range recordSchema.Fields() {
for _, field := range schema.(*RecordSchema).Fields() {
fieldDef := field.Default()
if newDef, ok := m[field.Name()]; ok {
fieldDef = newDef
Expand All @@ -1394,10 +1380,8 @@ func isValidDefault(schema Schema, def any) (any, bool) {

m[field.Name()] = v
}

return m, true
}

return nil, false
}

Expand All @@ -1410,10 +1394,9 @@ func schemaTypeName(schema Schema) string {
return n.FullName()
}

name := string(schema.Type())
sname := string(schema.Type())
if lt := getLogicalType(schema); lt != "" {
name += "." + string(lt)
sname += "." + string(lt)
}

return name
return sname
}