diff --git a/core/src/test/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarkerTest.scala b/core/src/test/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarkerTest.scala index 8032a8b0a..30fe2c0b2 100644 --- a/core/src/test/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarkerTest.scala +++ b/core/src/test/scala/com/avsystem/commons/serialization/IgnoreTransientDefaultMarkerTest.scala @@ -11,6 +11,12 @@ object IgnoreTransientDefaultMarkerTest { @transientDefault defaultObj: HasDefaults = HasDefaults(), ) object NestedHasDefaults extends HasGenCodec[NestedHasDefaults] + + final case class HasOptParam( + @transientDefault flag: Boolean = false, + @optionalParam str: Opt[String] = Opt.Empty, + ) + object HasOptParam extends HasGenCodec[HasOptParam] } class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest { @@ -24,16 +30,31 @@ class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest { result } - def createInput(raw: Any): Input = new SimpleValueInput(raw) + def createInput(raw: Any): Input = + CustomMarkersInputWrapper(new SimpleValueInput(raw), IgnoreTransientDefaultMarker) - test("case class with default values") { + test("write case class with default values") { testWrite(HasDefaults(str = "lol"), Map("str" -> "lol", "int" -> 42)) testWrite(HasDefaults(43, "lol"), Map("int" -> 43, "str" -> "lol")) testWrite(HasDefaults(str = null), Map("str" -> null, "int" -> 42)) testWrite(HasDefaults(str = "dafuq"), Map("str" -> "dafuq", "int" -> 42)) } - test("nested case class with default values") { + test("read case class with default values") { + testRead(Map("str" -> "lol", "int" -> 42), HasDefaults(str = "lol", int = 42)) + testRead(Map("str" -> "lol"), HasDefaults(str = "lol", int = 42)) + testRead(Map("int" -> 43, "str" -> "lol"), HasDefaults(int = 43, str = "lol")) + testRead(Map("str" -> null, "int" -> 42), HasDefaults(str = null, int = 42)) + testRead(Map("str" -> null), HasDefaults(str = null, int = 42)) + testRead(Map(), HasDefaults(str = "dafuq", int = 42)) + } + + test("write case class with opt values") { + testWrite(HasOptParam(str = "lol".opt), Map("flag" -> false, "str" -> "lol")) + testWrite(HasOptParam(), Map("flag" -> false)) + } + + test("write nested case class with default values") { testWrite( value = NestedHasDefaults( flag = false, diff --git a/macros/src/main/scala/com/avsystem/commons/macros/serialization/GenCodecMacros.scala b/macros/src/main/scala/com/avsystem/commons/macros/serialization/GenCodecMacros.scala index a186d3aff..534a58703 100644 --- a/macros/src/main/scala/com/avsystem/commons/macros/serialization/GenCodecMacros.scala +++ b/macros/src/main/scala/com/avsystem/commons/macros/serialization/GenCodecMacros.scala @@ -172,24 +172,48 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with } } - def writeField(p: ApplyParam, value: Tree): Tree = { + def doWriteField(p: ApplyParam, value: Tree, transientValue: Option[Tree]): Tree = { + val writeArgs = q"output" :: q"${p.idx}" :: value :: transientValue.toList + val writeTargs = if (isOptimizedPrimitive(p)) Nil else List(p.valueType) + q"writeField[..$writeTargs](..$writeArgs)" + } + + def writeFieldNoTransientDefault(p: ApplyParam, value: Tree): Tree = { + val transientValue = p.optionLike.map(ol => q"${ol.reference(Nil)}.none") + doWriteField(p, value, transientValue) + } + + def writeFieldTransientDefaultPossible(p: ApplyParam, value: Tree): Tree = { val transientValue = if (isTransientDefault(p)) Some(p.defaultValue) else p.optionLike.map(ol => q"${ol.reference(Nil)}.none") - - val writeArgsNoTransient = q"output" :: q"${p.idx}" :: List(value) - val writeArgs = writeArgsNoTransient ::: transientValue.toList - val writeTargs = if (isOptimizedPrimitive(p)) Nil else List(p.valueType) - q""" - if (ignoreTransientDefault) - writeField[..$writeTargs](..$writeArgsNoTransient) - else - writeField[..$writeTargs](..$writeArgs) - """ + doWriteField(p, value, transientValue) } + def writeField(p: ApplyParam, value: Tree, ignoreTransientDefault: Tree): Tree = + if (isTransientDefault(p)) + q""" + if ($ignoreTransientDefault) ${writeFieldNoTransientDefault(p, value)} + else ${writeFieldTransientDefaultPossible(p, value)} + """ + else + writeFieldNoTransientDefault(p, value) + def ignoreTransientDefaultCheck: Tree = - q"val ignoreTransientDefault = output.customEvent($SerializationPkg.IgnoreTransientDefaultMarker, ())" + q"output.customEvent($SerializationPkg.IgnoreTransientDefaultMarker, ())" + + // when params size is 1 + def writeSingle(p: ApplyParam, value: Tree): Tree = + writeField(p, value, ignoreTransientDefaultCheck) + + // when params size is greater than 1 + def writeMultiple(value: ApplyParam => Tree): Tree = + if (anyParamHasTransientDefault) { + q""" + val ignoreTransientDefault = $ignoreTransientDefaultCheck + ..${params.map(p => writeField(p, value(p), q"ignoreTransientDefault"))} + """ + } else q"..${params.map(p => writeFieldNoTransientDefault(p, value(p)))}" def writeFields: Tree = params match { case Nil => @@ -203,37 +227,30 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with """ case List(p: ApplyParam) => if (canUseFields) - q""" - $ignoreTransientDefaultCheck - ${writeField(p, q"value.${p.sym.name}")} - """ + q"${writeSingle(p, q"value.${p.sym.name}")}" else q""" val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value) if (unapplyRes.isEmpty) unapplyFailed - else { - $ignoreTransientDefaultCheck - ${writeField(p, q"unapplyRes.get")} - } + else ${writeSingle(p, q"unapplyRes.get")} """ case _ => if (canUseFields) - q""" - $ignoreTransientDefaultCheck - ..${params.map(p => writeField(p, q"value.${p.sym.name}"))} - """ + q"${writeMultiple(p => q"value.${p.sym.name}")}" else q""" val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value) if (unapplyRes.isEmpty) unapplyFailed else { val t = unapplyRes.get - $ignoreTransientDefaultCheck - ..${params.map(p => writeField(p, q"t.${tupleGet(p.idx)}"))} + ${writeMultiple(p => q"t.${tupleGet(p.idx)}")} } """ } + def anyParamHasTransientDefault: Boolean = + params.exists(isTransientDefault) + def mayBeTransient(p: ApplyParam): Boolean = p.optionLike.nonEmpty || isTransientDefault(p)