Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes ignoring transient fields in zio-schema based ADT codecs #435

Merged
merged 2 commits into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ object AdtCodec {
}
}

override def serialize(value: SerializedEvolutionStep)(implicit context: SerializationContext): Unit =
override def serialize(value: SerializedEvolutionStep)(implicit context: SerializationContext): Unit = {
value match {
case FieldAddedToNewChunk(size) =>
writeVarInt(size, optimizeForPositive = false)
Expand All @@ -409,6 +409,7 @@ object AdtCodec {
case UnknownEvolutionStep =>
writeVarInt(Codes.Unknown, optimizeForPositive = false)
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package io.github.vigoo.desert.shapeless

import io.github.vigoo.desert.Evolution.FieldAdded
import io.github.vigoo.desert._
import zio.test.{Spec, TestEnvironment, ZIOSpecDefault}
import zio.Chunk
import zio.test.{Spec, TestEnvironment, ZIOSpecDefault, assertTrue}

object TransientSpec extends ZIOSpecDefault with SerializationProperties {
case class TypeWithoutCodec(value: Int)
Expand All @@ -24,12 +25,19 @@ object TransientSpec extends ZIOSpecDefault with SerializationProperties {

case class Case3(x: String) extends SumWithTransientCons

@evolutionSteps(
Evolution.FieldAdded("x", 0),
Evolution.FieldRemoved("z")
)
case class Point(x: Int, y: Int, @transientField(None) cachedStr: Option[String])

private implicit val typeRegistry: TypeRegistry = TypeRegistry.empty

implicit val ttCodec: BinaryCodec[TransientTest] = DerivedBinaryCodec.derive
implicit val case1Codec: BinaryCodec[Case1] = DerivedBinaryCodec.derive
implicit val case3Codec: BinaryCodec[Case3] = DerivedBinaryCodec.derive
implicit val swtCodec: BinaryCodec[SumWithTransientCons] = DerivedBinaryCodec.derive
implicit val pointCodec: BinaryCodec[Point] = DerivedBinaryCodec.derive

override def spec: Spec[TestEnvironment, Any] =
suite("Support for transient modifiers")(
Expand All @@ -51,6 +59,12 @@ object TransientSpec extends ZIOSpecDefault with SerializationProperties {
},
test("serializing a transient constructor fails") {
cannotBeSerializedAndReadBack[SumWithTransientCons, SumWithTransientCons](Case2(TypeWithoutCodec(1)))
},
test("expected binary format") {
val point = Point(1, -10, None)
val bytes = serializeToArray(point).map(Chunk.fromArray)
val expected: Chunk[Byte] = Chunk(2, 8, 8, 3, 2, 122, -1, -1, -1, -10, 0, 0, 0, 1)
assertTrue(bytes.toOption.get == expected)
}
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ object DerivedBinaryCodec extends DerivedBinaryCodecVersionSpecific {
fields: => Chunk[Deriver.WrappedF[BinaryCodec, _]],
summoned: => Option[BinaryCodec[A]]
): BinaryCodec[A] = {
val transientFields = getTransientFields(record.fields)
val preparedSerializationCommands =
record.fields
.zip(fields)
.map { case (field, fieldCodec) =>
record.fields.zipWithIndex
.filter { case (field, _) => !transientFields.contains(field.name) }
.map { case (field, idx) =>
val fieldCodec = fields(idx)
(field.name, field.get, () => fieldCodec.unwrap.asInstanceOf[BinaryCodec[Any]])
}
.toList
Expand All @@ -50,7 +52,7 @@ object DerivedBinaryCodec extends DerivedBinaryCodecVersionSpecific {
evolutionSteps = getEvolutionStepsFromAnnotation(record.annotations),
typeName = record.id.name,
constructors = Vector(record.id.name),
transientFields = getTransientFields(record.fields),
transientFields = transientFields,
getSerializationCommands = (value: A) =>
preparedSerializationCommands.map { case (fieldName, getter, codec) =>
AdtCodec.SerializationCommand.WriteField[Any](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ package io.github.vigoo.desert.zioschema
import io.github.vigoo.desert.Evolution.FieldAdded
import io.github.vigoo.desert.{
BinaryCodec,
Evolution,
SerializationProperties,
TypeRegistry,
evolutionSteps,
serializeToArray,
transientConstructor,
transientField
}
import zio.Chunk
import zio.schema.{DeriveSchema, Schema}
import zio.test.{Spec, TestEnvironment, ZIOSpecDefault}
import zio.test.{Spec, TestEnvironment, ZIOSpecDefault, assertTrue}

object TransientSpec extends ZIOSpecDefault with SerializationProperties {
case class TypeWithoutCodec(value: Int)
Expand All @@ -32,13 +35,21 @@ object TransientSpec extends ZIOSpecDefault with SerializationProperties {

case class Case3(x: String) extends SumWithTransientCons

@evolutionSteps(
Evolution.FieldAdded("x", 0),
Evolution.FieldRemoved("z")
)
case class Point(x: Int, y: Int, @transientField(None) cachedStr: Option[String])

private implicit val typeRegistry: TypeRegistry = TypeRegistry.empty

implicit val ttSchema: Schema[TransientTest] = DeriveSchema.gen[TransientTest]
implicit val swtSchema: Schema[SumWithTransientCons] = DeriveSchema.gen[SumWithTransientCons]
implicit val pointSchema: Schema[Point] = DeriveSchema.gen[Point]

implicit val ttCodec: BinaryCodec[TransientTest] = DerivedBinaryCodec.derive[TransientTest]
implicit val swtCodec: BinaryCodec[SumWithTransientCons] = DerivedBinaryCodec.derive[SumWithTransientCons]
implicit val pointCodec: BinaryCodec[Point] = DerivedBinaryCodec.derive[Point]

override def spec: Spec[TestEnvironment, Any] =
suite("Support for transient modifiers")(
Expand All @@ -60,6 +71,12 @@ object TransientSpec extends ZIOSpecDefault with SerializationProperties {
},
test("serializing a transient constructor fails") {
cannotBeSerializedAndReadBack[SumWithTransientCons, SumWithTransientCons](Case2(TypeWithoutCodec(1)))
},
test("expected binary format") {
val point = Point(1, -10, None)
val bytes = serializeToArray(point).map(Chunk.fromArray)
val expected: Chunk[Byte] = Chunk(2, 8, 8, 3, 2, 122, -1, -1, -1, -10, 0, 0, 0, 1)
assertTrue(bytes.toOption.get == expected)
}
)
}
Loading