Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(1265): Added schema validation for nested record types | Jeyam #1293

Merged
merged 1 commit into from
Jan 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions src/main/java/org/akhq/utils/AvroSerializer.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.commons.collections.CollectionUtils;

import java.math.BigDecimal;
import java.math.MathContext;
Expand Down Expand Up @@ -55,26 +54,53 @@ public class AvroSerializer {
.toFormatter();

public static GenericRecord recordSerializer(Map<String, Object> record, Schema schema) {
GenericRecord returnValue = new GenericData.Record(schema);
Set<String> schemaFields = schema.getFields().stream()
.map(Schema.Field::name).collect(Collectors.toSet());

Set<String> recordFields = record.keySet();

if (schemaFields.size() != recordFields.size()) {
Object[] missingFields = CollectionUtils.disjunction(schemaFields, recordFields).stream().toArray();
throw new IllegalArgumentException(" Record does not contain followings fields ".concat(Arrays.toString(missingFields)));
}
validateSchema(schema.getFields(), record);

GenericRecord returnValue = new GenericData.Record(schema);
schema
.getFields()
.forEach(field -> {
Object fieldValue = record.getOrDefault(field.name(), field.defaultVal());
returnValue.put(field.name(), AvroSerializer.objectSerializer(fieldValue, field.schema()));
});

return returnValue;
}

private static void validateSchema(List<Schema.Field> fields, Map<String, Object> record) {
for (Schema.Field field : fields) {
var schema = field.schema();
var type = schema.getType();
var value = Optional.ofNullable(record)
.filter(Objects::nonNull)
.map(r -> r.get(field.name()));
var hasEmptyValue = value.isEmpty();

validateSchemaHasDefaultValue(field, schema, hasEmptyValue);

if (Schema.Type.RECORD.getName().equals(type.getName()) && !hasEmptyValue) {
validateSchema(schema.getFields(), (Map<String, Object>) value.get());
}
else if (Schema.Type.ARRAY.getName().equals(type.getName()) && !hasEmptyValue) {
Schema elementType = schema.getElementType();
if (elementType.getType().equals(Schema.Type.RECORD)) {
for(Map<String, Object> val : (List<Map<String, Object>>) value.get()) {
validateSchema(elementType.getFields(), val);
}
}
}
}
}

private static void validateSchemaHasDefaultValue(Schema.Field field, Schema schema, boolean hasEmptyValue) {
var isFieldHasNullValue = field.hasDefaultValue() || schema.isNullable();

if ((!isFieldHasNullValue) && hasEmptyValue) {
var message = String.format("Field %s is missing in the payload", field.name());
throw new IllegalArgumentException(message);
}
}

@SuppressWarnings("unchecked")
private static Object objectSerializer(Object value, Schema schema) {
if (value == org.apache.avro.JsonProperties.NULL_VALUE) {
Expand Down
97 changes: 91 additions & 6 deletions src/test/java/org/akhq/modules/AvroSchemaSerializerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import io.confluent.kafka.schemaregistry.avro.AvroSchema;
import org.akhq.configs.SchemaRegistryType;
import org.akhq.modules.schemaregistry.AvroSerializer;

import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand All @@ -11,8 +13,7 @@

import java.nio.ByteBuffer;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.*;

@ExtendWith(MockitoExtension.class)
class AvroSchemaSerializerTest {
Expand All @@ -26,6 +27,60 @@ class AvroSchemaSerializerTest {
.name("rating").type().doubleType().noDefault()
.endRecord();

private final org.apache.avro.Schema NESTED_SCHEMA =
new Schema.Parser().parse("{\n" +
" \"type\": \"record\",\n" +
" \"name\": \"userInfo\",\n" +
" \"namespace\": \"org.akhq\",\n" +
" \"fields\": [\n" +
" {\n" +
" \"name\": \"username\",\n" +
" \"type\": \"string\",\n" +
" \"default\": \"NONE\"\n" +
" },\n" +
" {\n" +
" \"name\": \"age\",\n" +
" \"type\": \"int\",\n" +
" \"default\": -1\n" +
" },\n" +
" {\n" +
" \"name\": \"phone\",\n" +
" \"type\": \"string\"\n" +
" },\n" +
" {\n" +
" \"name\": \"address\",\n" +
" \"type\": {\n" +
" \"type\": \"record\",\n" +
" \"name\": \"mailing_address\",\n" +
" \"fields\": [\n" +
" {\n" +
" \"name\": \"street\",\n" +
" \"type\": \"string\"\n" +
" },\n" +
" {\n" +
" \"name\": \"detailaddress\",\n" +
" \"type\": {\n" +
" \"type\": \"record\",\n" +
" \"name\": \"homeaddress\",\n" +
" \"fields\": [\n" +
" {\n" +
" \"name\": \"houseNo\",\n" +
" \"type\": \"int\",\n" +
" \"default\": 1\n" +
" },\n" +
" {\n" +
" \"name\": \"roomNo\",\n" +
" \"type\": \"int\"\n" +
" }\n" +
" ]\n" +
" }\n" +
" }\n" +
" ]\n" +
" }\n" +
" }\n" +
" ]\n" +
"}");

public static final String VALID_JSON = "{\n" +
" \"title\": \"the-title\",\n" +
" \"release_year\": 123,\n" +
Expand All @@ -38,11 +93,34 @@ class AvroSchemaSerializerTest {
" \"rating\": 2.5\n" +
"}";

public static final String INVALID_NESTED_JSON = "{\n" +
" \"phone\": \"12345\",\n" +
" \"address\": {\n" +
" \"street\": \"Test Street\",\n" +
" \"detailaddress\" : {\n" +
" \n" +
" }\n" +
" }\n" +
"}";

public static final String VALID_NESTED_JSON = "{\n" +
" \"phone\": \"2312331\",\n" +
" \"address\": {\n" +
" \"street\": \"Test Street\",\n" +
" \"detailaddress\" : {\n" +
" \"houseNo\" : 1,\n" +
" \"roomNo\" : 2\n" +
" }\n" +
" }\n" +
"}";

private AvroSerializer avroSerializer;
private AvroSerializer avroDeepSerializer;

@BeforeEach
void setUp() {
avroSerializer = AvroSerializer.newInstance(SCHEMA_ID, new AvroSchema(SCHEMA), SchemaRegistryType.CONFLUENT);
avroDeepSerializer = AvroSerializer.newInstance(SCHEMA_ID, new AvroSchema(NESTED_SCHEMA), SchemaRegistryType.CONFLUENT);
}

@Test
Expand All @@ -59,9 +137,16 @@ void shouldSerializeSchemaId() {

@Test
void shouldFailIfDoesntMatchSchemaId() {
assertThrows(NullPointerException.class, () -> {
int schemaId = 3;
avroSerializer.serialize(INVALID_JSON);
});
assertThrows(IllegalArgumentException.class, () -> avroSerializer.serialize(INVALID_JSON));
}

@Test
void shouldThrowForDeepNestedInvalidJSON() {
assertThrows(IllegalArgumentException.class, () -> avroDeepSerializer.serialize(INVALID_NESTED_JSON));
}

@Test
void shouldNotThrowForValidNestedJSON() {
assertDoesNotThrow(() -> avroDeepSerializer.serialize(VALID_NESTED_JSON));
}
}