From 23641700043eefa331ff6ea2b1120ac811f68676 Mon Sep 17 00:00:00 2001 From: Julien Buret Date: Mon, 11 Nov 2019 12:15:06 +0100 Subject: [PATCH] fixes #154 register custom serializers for kotlinx.serialization --- .../kotlin/KMongoSerializationRepository.kt | 55 ++++++++++- .../src/main/kotlin/SerializationCodec.kt | 5 +- .../src/test/kotlin/SerializationCodecTest.kt | 91 ++++++++++++++++++- 3 files changed, 141 insertions(+), 10 deletions(-) diff --git a/kmongo-serialization-mapping/src/main/kotlin/KMongoSerializationRepository.kt b/kmongo-serialization-mapping/src/main/kotlin/KMongoSerializationRepository.kt index 0435f481..4274cb98 100644 --- a/kmongo-serialization-mapping/src/main/kotlin/KMongoSerializationRepository.kt +++ b/kmongo-serialization-mapping/src/main/kotlin/KMongoSerializationRepository.kt @@ -49,9 +49,28 @@ import java.util.Calendar import java.util.Date import java.util.GregorianCalendar import java.util.Locale +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.CopyOnWriteArraySet import kotlin.reflect.KClass import kotlin.reflect.KProperty +@PublishedApi +internal val customSerializersMap: MutableMap, KSerializer<*>> = ConcurrentHashMap() +private val customModules = CopyOnWriteArraySet() + +/** + * Add a custom [SerialModule] to KMongo kotlinx.serialization mapping. + */ +fun registerModule(module: SerialModule) { + customModules.add(module) +} + +/** + * Add a custom serializer to KMongo kotlinx.serialization mapping + */ +inline fun registerSerializer(serializer: KSerializer) { + customSerializersMap[T::class] = serializer +} /** * @@ -97,7 +116,8 @@ internal object KMongoSerializationRepository { if (it.isEmpty()) StringSerializer else getSerializer(it.first()) } as KSerializer ) - else -> null + else -> module.getContextual(obj.javaClass.kotlin) + ?: module.getPolymorphic(obj.javaClass.kotlin, obj) } } @@ -107,15 +127,42 @@ internal object KMongoSerializationRepository { if (obj == null) { JsonNullSerializer as? KSerializer ?: error("no serializer for null") } else { - - //TODO don't known yet how to do this without reflection (serializersMap[obj.javaClass.kotlin] ?: getBaseSerializer(obj) ?: obj.javaClass.kotlin.serializer()) as? KSerializer ?: error("no serializer for $obj of class ${obj.javaClass.kotlin}") } - val module: SerialModule = SerializersModule { + @Suppress("UNCHECKED_CAST") + @ImplicitReflectionSerializer + fun getSerializer(kClass: KClass): KSerializer = + (serializersMap[kClass] + ?: module.getContextual(kClass) + ?: kClass.serializer()) as? KSerializer + ?: error("no serializer for $kClass of class $kClass") + + @Volatile + private var baseModule: SerialModule = SerializersModule { include(serializersModuleOf(serializersMap)) + include(serializersModuleOf(customSerializersMap)) + customModules.forEach { include(it) } } + @Volatile + private var customModulesSize: Int = customModules.size + @Volatile + private var customSerializersSize: Int = customSerializersMap.size + + val module: SerialModule + get() { + if (customSerializersSize != customSerializersMap.size || customModulesSize != customModules.size) { + customSerializersSize = customSerializersMap.size + customModulesSize = customModules.size + baseModule = SerializersModule { + include(serializersModuleOf(serializersMap)) + include(serializersModuleOf(customSerializersMap)) + customModules.forEach { include(it) } + } + } + return baseModule + } } \ No newline at end of file diff --git a/kmongo-serialization-mapping/src/main/kotlin/SerializationCodec.kt b/kmongo-serialization-mapping/src/main/kotlin/SerializationCodec.kt index 22cc40f1..c7e61454 100644 --- a/kmongo-serialization-mapping/src/main/kotlin/SerializationCodec.kt +++ b/kmongo-serialization-mapping/src/main/kotlin/SerializationCodec.kt @@ -22,7 +22,6 @@ import com.github.jershell.kbson.Configuration import kotlinx.serialization.ImplicitReflectionSerializer import kotlinx.serialization.decode import kotlinx.serialization.encode -import kotlinx.serialization.serializer import org.bson.BsonDouble import org.bson.BsonInt32 import org.bson.BsonInt64 @@ -61,7 +60,9 @@ internal class SerializationCodec(val clazz: KClass) : CollectibleCo @ImplicitReflectionSerializer override fun decode(reader: BsonReader, decoderContext: DecoderContext): T { - return BsonDocumentDecoder(reader, module, Configuration()).decode(clazz.serializer()) + return BsonDocumentDecoder(reader, module, Configuration()).decode( + KMongoSerializationRepository.getSerializer(clazz) + ) } override fun getDocumentId(document: T): BsonValue { diff --git a/kmongo-serialization-mapping/src/test/kotlin/SerializationCodecTest.kt b/kmongo-serialization-mapping/src/test/kotlin/SerializationCodecTest.kt index 9b979d5a..2bb3a31d 100644 --- a/kmongo-serialization-mapping/src/test/kotlin/SerializationCodecTest.kt +++ b/kmongo-serialization-mapping/src/test/kotlin/SerializationCodecTest.kt @@ -1,8 +1,17 @@ package org.litote.kmongo.serialization +import kotlinx.serialization.CompositeDecoder +import kotlinx.serialization.CompositeEncoder import kotlinx.serialization.ContextualSerialization +import kotlinx.serialization.Decoder +import kotlinx.serialization.Encoder import kotlinx.serialization.ImplicitReflectionSerializer +import kotlinx.serialization.KSerializer import kotlinx.serialization.Serializable +import kotlinx.serialization.Serializer +import kotlinx.serialization.internal.SerialClassDescImpl +import kotlinx.serialization.internal.StringDescriptor +import kotlinx.serialization.modules.SerializersModule import org.bson.BsonDocument import org.bson.BsonDocumentReader import org.bson.BsonDocumentWriter @@ -14,6 +23,7 @@ import org.litote.kmongo.Id import org.litote.kmongo.model.Friend import org.litote.kmongo.newId import kotlin.test.assertEquals +import kotlin.test.assertFalse /** * @@ -29,7 +39,6 @@ class SerializationCodecTest { val writer = BsonDocumentWriter(document) codec.encode(writer, friend, EncoderContext.builder().build()) - println(document) val newFriend = codec.decode(BsonDocumentReader(document), DecoderContext.builder().build()) assertEquals(friend, newFriend) @@ -47,7 +56,6 @@ class SerializationCodecTest { val writer = BsonDocumentWriter(document) codec.encode(writer, id, EncoderContext.builder().build()) - println(document) val newFriend = codec.decode(BsonDocumentReader(document), DecoderContext.builder().build()) assertEquals(id, newFriend) @@ -65,7 +73,6 @@ class SerializationCodecTest { val writer = BsonDocumentWriter(document) codec.encode(writer, idList, EncoderContext.builder().build()) - println(document) val newFriend = codec.decode(BsonDocumentReader(document), DecoderContext.builder().build()) assertEquals(idList, newFriend) @@ -86,9 +93,85 @@ class SerializationCodecTest { val writer = BsonDocumentWriter(document) codec.encode(writer, idList, EncoderContext.builder().build()) - println(document) val newFriend = codec.decode(BsonDocumentReader(document), DecoderContext.builder().build()) assertEquals(idList, newFriend) } + + data class Custom(val s: String, val b: Boolean = false) + + @Serializer(forClass = Custom::class) + class CustomSerializer : KSerializer { + + object CustomClassDesc : SerialClassDescImpl("Custom") { + init { + addElement("s") + pushDescriptor(StringDescriptor) + } + } + + override fun deserialize(decoder: Decoder): Custom { + decoder as CompositeDecoder + decoder.beginStructure(CustomClassDesc) + val c = Custom(decoder.decodeStringElement(CustomClassDesc, 0)) + decoder.endStructure(CustomClassDesc) + return c + } + + override fun serialize(encoder: Encoder, obj: Custom) { + encoder as CompositeEncoder + encoder.beginStructure(CustomClassDesc) + encoder.encodeStringElement(CustomClassDesc, 0, obj.s) + encoder.endStructure(CustomClassDesc) + } + } + + @ImplicitReflectionSerializer + @Test + fun `encode and decode with custom serializer`() { + registerSerializer(CustomSerializer()) + val c = Custom("a", true) + val codec = SerializationCodec(Custom::class) + val document = BsonDocument() + val writer = BsonDocumentWriter(document) + codec.encode(writer, c, EncoderContext.builder().build()) + + val newC = codec.decode(BsonDocumentReader(document), DecoderContext.builder().build()) + + assertEquals(c.s, newC.s) + assertFalse(newC.b) + } + + interface Message + + @Serializable + data class StringMessage(val message: String) : Message + + @Serializable + data class IntMessage(val number: Int) : Message + + @Serializable + data class Container(val m:Message) + + @ImplicitReflectionSerializer + @Test + fun `encode and decode with custom polymorphic serializer`() { + registerModule( + SerializersModule { + polymorphic(Message::class) { + StringMessage::class with StringMessage.serializer() + IntMessage::class with IntMessage.serializer() + } + }) + val c = Container(StringMessage("a")) + val codec = SerializationCodec(Container::class) + val document = BsonDocument() + val writer = BsonDocumentWriter(document) + codec.encode(writer, c, EncoderContext.builder().build()) + + val newC = codec.decode(BsonDocumentReader(document), DecoderContext.builder().build()) + + assertEquals(c, newC) + } + } \ No newline at end of file