Skip to content

Commit

Permalink
Optimize macro code when handling IgnoreTransientDefaultMarker
Browse files Browse the repository at this point in the history
  • Loading branch information
sebaciv committed Sep 23, 2024
1 parent 37c2a77 commit b19248f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ object IgnoreTransientDefaultMarkerTest {
@transientDefault defaultObj: HasDefaults = HasDefaults(),
)
object NestedHasDefaults extends HasGenCodec[NestedHasDefaults]

final case class HasOptParam(
@transientDefault flag: Boolean = false,
@optionalParam str: Opt[String] = Opt.Empty,
)
object HasOptParam extends HasGenCodec[HasOptParam]
}

class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest {
Expand All @@ -24,16 +30,31 @@ class IgnoreTransientDefaultMarkerTest extends AbstractCodecTest {
result
}

def createInput(raw: Any): Input = new SimpleValueInput(raw)
def createInput(raw: Any): Input =
CustomMarkersInputWrapper(new SimpleValueInput(raw), IgnoreTransientDefaultMarker)

test("case class with default values") {
test("write case class with default values") {
testWrite(HasDefaults(str = "lol"), Map("str" -> "lol", "int" -> 42))
testWrite(HasDefaults(43, "lol"), Map("int" -> 43, "str" -> "lol"))
testWrite(HasDefaults(str = null), Map("str" -> null, "int" -> 42))
testWrite(HasDefaults(str = "dafuq"), Map("str" -> "dafuq", "int" -> 42))
}

test("nested case class with default values") {
test("read case class with default values") {
testRead(Map("str" -> "lol", "int" -> 42), HasDefaults(str = "lol", int = 42))
testRead(Map("str" -> "lol"), HasDefaults(str = "lol", int = 42))
testRead(Map("int" -> 43, "str" -> "lol"), HasDefaults(int = 43, str = "lol"))
testRead(Map("str" -> null, "int" -> 42), HasDefaults(str = null, int = 42))
testRead(Map("str" -> null), HasDefaults(str = null, int = 42))
testRead(Map(), HasDefaults(str = "dafuq", int = 42))
}

test("write case class with opt values") {
testWrite(HasOptParam(str = "lol".opt), Map("flag" -> false, "str" -> "lol"))
testWrite(HasOptParam(), Map("flag" -> false))
}

test("write nested case class with default values") {
testWrite(
value = NestedHasDefaults(
flag = false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,24 +172,48 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with
}
}

def writeField(p: ApplyParam, value: Tree): Tree = {
def doWriteField(p: ApplyParam, value: Tree, transientValue: Option[Tree]): Tree = {
val writeArgs = q"output" :: q"${p.idx}" :: value :: transientValue.toList
val writeTargs = if (isOptimizedPrimitive(p)) Nil else List(p.valueType)
q"writeField[..$writeTargs](..$writeArgs)"
}

def writeFieldNoTransientDefault(p: ApplyParam, value: Tree): Tree = {
val transientValue = p.optionLike.map(ol => q"${ol.reference(Nil)}.none")
doWriteField(p, value, transientValue)
}

def writeFieldTransientDefaultPossible(p: ApplyParam, value: Tree): Tree = {
val transientValue =
if (isTransientDefault(p)) Some(p.defaultValue)
else p.optionLike.map(ol => q"${ol.reference(Nil)}.none")

val writeArgsNoTransient = q"output" :: q"${p.idx}" :: List(value)
val writeArgs = writeArgsNoTransient ::: transientValue.toList
val writeTargs = if (isOptimizedPrimitive(p)) Nil else List(p.valueType)
q"""
if (ignoreTransientDefault)
writeField[..$writeTargs](..$writeArgsNoTransient)
else
writeField[..$writeTargs](..$writeArgs)
"""
doWriteField(p, value, transientValue)
}

def writeField(p: ApplyParam, value: Tree, ignoreTransientDefault: Tree): Tree =
if (isTransientDefault(p))
q"""
if ($ignoreTransientDefault) ${writeFieldNoTransientDefault(p, value)}
else ${writeFieldTransientDefaultPossible(p, value)}
"""
else
writeFieldNoTransientDefault(p, value)

def ignoreTransientDefaultCheck: Tree =
q"val ignoreTransientDefault = output.customEvent($SerializationPkg.IgnoreTransientDefaultMarker, ())"
q"output.customEvent($SerializationPkg.IgnoreTransientDefaultMarker, ())"

// when params size is 1
def writeSingle(p: ApplyParam, value: Tree): Tree =
writeField(p, value, ignoreTransientDefaultCheck)

// when params size is greater than 1
def writeMultiple(value: ApplyParam => Tree): Tree =
if (anyParamHasTransientDefault) {
q"""
val ignoreTransientDefault = $ignoreTransientDefaultCheck
..${params.map(p => writeField(p, value(p), q"ignoreTransientDefault"))}
"""
} else q"..${params.map(p => writeFieldNoTransientDefault(p, value(p)))}"

def writeFields: Tree = params match {
case Nil =>
Expand All @@ -203,37 +227,30 @@ class GenCodecMacros(ctx: blackbox.Context) extends CodecMacroCommons(ctx) with
"""
case List(p: ApplyParam) =>
if (canUseFields)
q"""
$ignoreTransientDefaultCheck
${writeField(p, q"value.${p.sym.name}")}
"""
q"${writeSingle(p, q"value.${p.sym.name}")}"
else
q"""
val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value)
if (unapplyRes.isEmpty) unapplyFailed
else {
$ignoreTransientDefaultCheck
${writeField(p, q"unapplyRes.get")}
}
else ${writeSingle(p, q"unapplyRes.get")}
"""
case _ =>
if (canUseFields)
q"""
$ignoreTransientDefaultCheck
..${params.map(p => writeField(p, q"value.${p.sym.name}"))}
"""
q"${writeMultiple(p => q"value.${p.sym.name}")}"
else
q"""
val unapplyRes = $companion.$unapply[..${dtpe.typeArgs}](value)
if (unapplyRes.isEmpty) unapplyFailed
else {
val t = unapplyRes.get
$ignoreTransientDefaultCheck
..${params.map(p => writeField(p, q"t.${tupleGet(p.idx)}"))}
${writeMultiple(p => q"t.${tupleGet(p.idx)}")}
}
"""
}

def anyParamHasTransientDefault: Boolean =
params.exists(isTransientDefault)

def mayBeTransient(p: ApplyParam): Boolean =
p.optionLike.nonEmpty || isTransientDefault(p)

Expand Down

0 comments on commit b19248f

Please sign in to comment.