Skip to content

Commit

Permalink
Bulk Load CDK: Nullable-field-to-union-null & union-merging mappers (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
johnny-schmidt authored Oct 23, 2024
1 parent 48707f3 commit 284c29c
Show file tree
Hide file tree
Showing 9 changed files with 376 additions and 78 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.data

/**
 * Schema mapper that simplifies every [UnionType] it visits.
 *
 * For each union, the options are mapped first and then merged:
 * - nested unions are flattened into the enclosing union,
 * - duplicate options are dropped (the options are collected into a set),
 * - all [ObjectType] options are combined into a single [ObjectType] holding the properties of
 * each of them.
 *
 * If exactly one option remains after merging, that option is returned in place of the union.
 */
class MergeUnions : AirbyteSchemaIdentityMapper {
    /**
     * Maps and merges the options of [schema].
     *
     * @return the single remaining option if merging leaves exactly one, otherwise a [UnionType]
     * of the merged options.
     * @throws IllegalArgumentException if two object options declare the same field name with
     * fields that are not equal.
     */
    override fun mapUnion(schema: UnionType): AirbyteType {
        // Map the options first so they're in their final form
        val mappedOptions = schema.options.map { map(it) }
        val mergedOptions = mergeOptions(mappedOptions)
        // singleOrNull() is null for both the empty and the multi-option case, which are exactly
        // the cases that stay a union.
        return mergedOptions.singleOrNull() ?: UnionType(mergedOptions.toList())
    }

    /** Merges [options] into a fresh set and returns it. */
    private fun mergeOptions(options: List<AirbyteType>): Set<AirbyteType> {
        val mergedOptions = mutableSetOf<AirbyteType>()
        mergeOptions(mergedOptions, options)
        return mergedOptions
    }

    /** Merges every option of [from] into the accumulator [into]. */
    private fun mergeOptions(into: MutableSet<AirbyteType>, from: List<AirbyteType>) {
        for (option in from) {
            when (option) {
                // If this is a union of a union, recursively merge the other union's options in
                is UnionType -> mergeOptions(into, option.options)
                is ObjectType -> {
                    val existingObjOption = into.filterIsInstance<ObjectType>().firstOrNull()
                    if (existingObjOption == null) {
                        // No other object in the set, so just add this one
                        into.add(option)
                    } else {
                        // Swap the previously collected object for the combined one
                        into.remove(existingObjOption)
                        into.add(mergeObjects(existingObjOption, option))
                    }
                }
                else -> into.add(option)
            }
        }
    }

    /**
     * Builds a new [ObjectType] containing the properties of [existing] followed by those of
     * [incoming].
     *
     * The properties are merged into a copy, so neither argument is modified.
     *
     * @throws IllegalArgumentException if both objects declare a field with the same name and the
     * two fields are not equal.
     */
    private fun mergeObjects(existing: ObjectType, incoming: ObjectType): ObjectType {
        val mergedProperties = LinkedHashMap<String, FieldType>(existing.properties)
        for ((name, field) in incoming.properties) {
            val existingField = mergedProperties[name]
            // Validate before writing so a clash never leaves a half-merged map behind
            if (existingField != null && existingField != field) {
                throw IllegalArgumentException(
                    "Cannot merge unions of objects with different types for the same field"
                )
            }
            mergedProperties[name] = field
        }
        return ObjectType(mergedProperties)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.data

/**
 * Schema mapper that expresses field nullability as a union with [NullType].
 *
 * A nullable field of type `X` becomes a non-nullable field of type `UnionType([X, NullType])`.
 * Fields that are not nullable are returned unchanged.
 */
class NullableToUnionNull : AirbyteSchemaIdentityMapper {
    override fun mapField(field: FieldType): FieldType =
        when {
            !field.nullable -> field
            // The original type is wrapped as-is, even if it is already a union.
            else -> FieldType(UnionType(listOf(field.type, NullType)), nullable = false)
        }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.command

import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.IntegerType
import io.airbyte.cdk.load.data.MergeUnions
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.StringType
import io.airbyte.cdk.load.data.UnionType
import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows

/** Tests for [MergeUnions], driven by input/expected schema pairs from [SchemaRecordBuilder]. */
class MergeUnionsTest {
    /** Runs [MergeUnions] over the input half of [schemas] and compares with the expected half. */
    private fun assertMergesAsExpected(schemas: Pair<ObjectType, ObjectType>) {
        val (given, wanted) = schemas
        val actual = MergeUnions().map(given)
        Assertions.assertEquals(wanted, actual)
    }

    /** Two object options with different field names collapse into one object with both fields. */
    @Test
    fun testBasicBehavior() {
        val mergedObject =
            ObjectType(
                properties =
                    linkedMapOf(
                        "foo" to FieldType(StringType, false),
                        "bar" to FieldType(IntegerType, false)
                    )
            )
        val schemas =
            SchemaRecordBuilder<Root>()
                .withUnion(expectedInstead = FieldType(mergedObject, nullable = false))
                .withRecord()
                .with(StringType, nameOverride = "foo")
                .endRecord()
                .withRecord()
                .with(IntegerType, nameOverride = "bar")
                .endRecord()
                .endUnion()
                .build()
        assertMergesAsExpected(schemas)
    }

    /** Object options that share a field name with different types cannot be merged. */
    @Test
    fun testNameClashFails() {
        val clashingSchema =
            SchemaRecordBuilder<Root>()
                .withUnion()
                .withRecord()
                .with(StringType, nameOverride = "foo")
                .endRecord()
                .withRecord()
                .with(IntegerType, nameOverride = "foo")
                .endRecord()
                .endUnion()
                .build()
                .first
        assertThrows<IllegalArgumentException> { MergeUnions().map(clashingSchema) }
    }

    /** A repeated option appears only once in the merged union. */
    @Test
    fun testMergeLikeTypes() {
        val dedupedUnion = FieldType(UnionType(listOf(StringType, IntegerType)), nullable = false)
        val schemas =
            SchemaRecordBuilder<Root>()
                .withUnion(expectedInstead = dedupedUnion)
                .with(StringType)
                .with(IntegerType)
                .with(IntegerType)
                .endUnion()
                .build()
        assertMergesAsExpected(schemas)
    }

    /** Unions nested inside a union are flattened into the outer union. */
    @Test
    fun testNestedUnion() {
        val flattenedUnion = FieldType(UnionType(listOf(StringType, IntegerType)), nullable = false)
        val innerUnion = UnionType(listOf(IntegerType, StringType))
        val schemas =
            SchemaRecordBuilder<Root>()
                .withUnion(expectedInstead = flattenedUnion)
                .with(StringType)
                .with(UnionType(listOf(StringType, innerUnion)))
                .with(IntegerType)
                .endUnion()
                .build()
        assertMergesAsExpected(schemas)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@

package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.test.util.SchemaTestBuilder
import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

class AirbyteSchemaIdentityMapperTest {
@Test
fun testIdMapping() {
val (inputSchema, expectedOutput) =
SchemaTestBuilder()
SchemaRecordBuilder<Root>()
.with(DateType)
.with(StringType)
.with(IntegerType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.message.DestinationRecord
import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import io.airbyte.cdk.load.test.util.ValueTestBuilder
import io.airbyte.protocol.models.v0.AirbyteRecordMessageMetaChange
import org.junit.jupiter.api.Assertions
Expand All @@ -14,7 +16,7 @@ class AirbyteValueIdentityMapperTest {
@Test
fun testIdentityMapping() {
val (inputValues, inputSchema, expectedValues) =
ValueTestBuilder()
ValueTestBuilder<SchemaRecordBuilder<Root>>()
.with(StringValue("a"), StringType)
.with(IntegerValue(1), IntegerType)
.with(BooleanValue(true), BooleanType)
Expand Down Expand Up @@ -47,7 +49,7 @@ class AirbyteValueIdentityMapperTest {
@Test
fun testIdentityMappingWithBadSchema() {
val (inputValues, inputSchema, _) =
ValueTestBuilder()
ValueTestBuilder<SchemaRecordBuilder<Root>>()
.with(StringValue("a"), StringType)
.with(
TimestampValue("2021-01-01T12:00:00Z"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.data

import io.airbyte.cdk.load.test.util.Root
import io.airbyte.cdk.load.test.util.SchemaRecordBuilder
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test

/** Tests for [NullableToUnionNull]. */
class NullableToUnionNullTest {
    /**
     * A non-nullable field is left alone; a nullable field becomes a non-nullable union of its
     * type and [NullType].
     */
    @Test
    fun testBasicBehavior() {
        val (inputSchema, expectedOutput) =
            SchemaRecordBuilder<Root>()
                .with(FieldType(StringType, nullable = false))
                .with(
                    FieldType(IntegerType, nullable = true),
                    FieldType(UnionType(listOf(IntegerType, NullType)), nullable = false)
                )
                .build()
        // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure
        // messages label the two values correctly.
        Assertions.assertEquals(expectedOutput, NullableToUnionNull().map(inputSchema))
    }

    /**
     * Fields whose type is already a union: non-nullable ones are left alone, nullable ones get
     * wrapped in an outer union with [NullType] (the existing union is nested, not flattened,
     * even when it already contains [NullType]).
     */
    @Test
    fun testWackyBehavior() {
        val (inputSchema, expectedOutput) =
            SchemaRecordBuilder<Root>()
                .with(FieldType(UnionType(listOf(StringType, IntegerType)), nullable = false))
                .with(
                    FieldType(UnionType(listOf(StringType, IntegerType)), nullable = true),
                    FieldType(
                        UnionType(listOf(UnionType(listOf(StringType, IntegerType)), NullType)),
                        nullable = false
                    )
                )
                .with(FieldType(UnionType(listOf(StringType, NullType)), nullable = false))
                .with(
                    FieldType(UnionType(listOf(StringType, NullType)), nullable = true),
                    FieldType(
                        UnionType(listOf(UnionType(listOf(StringType, NullType)), NullType)),
                        nullable = false
                    )
                )
                .build()
        Assertions.assertEquals(expectedOutput, NullableToUnionNull().map(inputSchema))
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Copyright (c) 2024 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.cdk.load.test.util

import io.airbyte.cdk.load.data.AirbyteType
import io.airbyte.cdk.load.data.FieldType
import io.airbyte.cdk.load.data.ObjectType
import io.airbyte.cdk.load.data.UnionType
import java.util.*
import kotlin.collections.LinkedHashMap

/**
 * Marker type for anything that can act as the parent of a nested schema builder. It is the upper
 * bound of the `T` type parameter on the builders in this file.
 */
sealed interface SchemaRecordBuilderType

/**
 * Parent type used for a top-level builder, i.e. `SchemaRecordBuilder<Root>()`, which has no
 * enclosing builder to return to.
 */
class Root : SchemaRecordBuilderType

/**
 * Fluent test helper that builds two object schemas side by side: the schema handed to the code
 * under test ([inputSchema]) and the schema that code is expected to produce ([expectedSchema]).
 *
 * Every field is registered under the same name in both schemas. Unless a name is supplied, a
 * random UUID string is used.
 *
 * @param T type of the enclosing builder that [endRecord] returns to; [Root] for a top-level
 * builder.
 * @property inputSchema object schema being accumulated as test input.
 * @property expectedSchema object schema being accumulated as the expected result.
 * @property parent enclosing builder, or `null` for a top-level builder.
 */
class SchemaRecordBuilder<T : SchemaRecordBuilderType>(
    val inputSchema: ObjectType = ObjectType(properties = LinkedHashMap()),
    val expectedSchema: ObjectType = ObjectType(properties = LinkedHashMap()),
    val parent: T? = null
) : SchemaRecordBuilderType {
    /**
     * Adds the field [given] to the input schema and [expected] to the expected schema, both
     * under the same name.
     */
    fun with(
        given: FieldType,
        expected: FieldType = given,
        nameOverride: String? = null
    ): SchemaRecordBuilder<T> = apply {
        val fieldName = nameOverride ?: UUID.randomUUID().toString()
        inputSchema.properties[fieldName] = given
        expectedSchema.properties[fieldName] = expected
    }

    /** Same as the [FieldType] overload, wrapping both types in non-nullable fields. */
    fun with(
        given: AirbyteType,
        expected: AirbyteType = given,
        nameOverride: String? = null
    ): SchemaRecordBuilder<T> =
        with(FieldType(given, false), FieldType(expected, false), nameOverride)

    /**
     * Adds an empty object field to both schemas and returns a child builder that fills it in.
     * Call [endRecord] on the child to get back to this builder.
     */
    fun withRecord(
        nullable: Boolean = false,
        nameOverride: String? = null
    ): SchemaRecordBuilder<SchemaRecordBuilder<T>> {
        val fieldName = nameOverride ?: UUID.randomUUID().toString()
        val child: SchemaRecordBuilder<SchemaRecordBuilder<T>> =
            SchemaRecordBuilder(
                inputSchema = ObjectType(properties = LinkedHashMap()),
                expectedSchema = ObjectType(properties = LinkedHashMap()),
                parent = this
            )
        // The child keeps writing into these same ObjectType instances after registration.
        inputSchema.properties[fieldName] = FieldType(child.inputSchema, nullable = nullable)
        expectedSchema.properties[fieldName] = FieldType(child.expectedSchema, nullable = nullable)
        return child
    }

    /**
     * Adds a union field and returns a builder for its options.
     *
     * When [expectedInstead] is given it is stored as the expected field verbatim and the returned
     * builder receives no expected-options list; otherwise the expected field is a union whose
     * options the returned builder populates.
     */
    fun withUnion(
        nullable: Boolean = false,
        nameOverride: String? = null,
        expectedInstead: FieldType? = null
    ): SchemaTestUnionBuilder<T> {
        val fieldName = nameOverride ?: UUID.randomUUID().toString()

        val inputOptions = mutableListOf<AirbyteType>()
        inputSchema.properties[fieldName] = FieldType(UnionType(inputOptions), nullable = nullable)

        val expectedOptions: MutableList<AirbyteType>?
        if (expectedInstead != null) {
            expectedOptions = null
            expectedSchema.properties[fieldName] = expectedInstead
        } else {
            val trackedOptions = mutableListOf<AirbyteType>()
            expectedOptions = trackedOptions
            expectedSchema.properties[fieldName] =
                FieldType(UnionType(trackedOptions), nullable = nullable)
        }

        return SchemaTestUnionBuilder(this, inputOptions, expectedOptions)
    }

    /**
     * Returns to the enclosing builder.
     *
     * @throws IllegalStateException if this builder has no parent.
     */
    fun endRecord(): T = parent ?: throw IllegalStateException("Cannot end record without parent")

    /**
     * Returns the `(input, expected)` schema pair.
     *
     * @throws IllegalStateException if called on a builder that has a parent.
     */
    fun build(): Pair<ObjectType, ObjectType> {
        check(parent == null) { "Cannot build nested schema" }
        return inputSchema to expectedSchema
    }
}

/**
 * Fluent test helper that collects the options of a union field created by
 * [SchemaRecordBuilder.withUnion].
 *
 * @param T parent type of the [SchemaRecordBuilder] that owns this union.
 * @param parent builder returned by [endUnion].
 * @param options list receiving the options of the input union.
 * @param expectedOptions list receiving the options of the expected union, or `null` when the
 * expected field was supplied up front and there is no expected union to populate.
 */
class SchemaTestUnionBuilder<T : SchemaRecordBuilderType>(
    private val parent: SchemaRecordBuilder<T>,
    private val options: MutableList<AirbyteType>,
    private val expectedOptions: MutableList<AirbyteType>?
) : SchemaRecordBuilderType {
    /**
     * Adds [option] to the input union.
     *
     * When expected options are being tracked, [expected] is added to the expected union, or
     * [option] itself if no [expected] is given, so input and expected stay in step (the same way
     * [withRecord] mirrors its record into both).
     *
     * @throws IllegalStateException if [expected] is given but expected options are not tracked.
     */
    fun with(option: AirbyteType, expected: AirbyteType? = null): SchemaTestUnionBuilder<T> {
        // Validate before touching any list so a rejected call leaves the builder unchanged.
        check(expected == null || expectedOptions != null) {
            "Cannot specify expected options for a union built with `expectedInstead`"
        }
        options.add(option)
        expectedOptions?.add(expected ?: option)
        return this
    }

    /**
     * Adds an empty object option to the input union (and to the expected union when tracked) and
     * returns a builder that fills it in. Call `endRecord()` on it to get back to this builder.
     */
    fun withRecord(): SchemaRecordBuilder<SchemaTestUnionBuilder<T>> {
        val inputRecord = ObjectType(properties = LinkedHashMap())
        val outputRecord = ObjectType(properties = LinkedHashMap())
        options.add(inputRecord)
        expectedOptions?.add(outputRecord)
        return SchemaRecordBuilder(
            inputSchema = inputRecord,
            expectedSchema = outputRecord,
            parent = this
        )
    }

    /** Returns to the record builder that owns this union. */
    fun endUnion(): SchemaRecordBuilder<T> = parent
}
Loading

0 comments on commit 284c29c

Please sign in to comment.