diff --git a/build.sbt b/build.sbt index 756c9061..ac6548d1 100644 --- a/build.sbt +++ b/build.sbt @@ -5,7 +5,7 @@ import sbtcrossproject.CrossPlugin.autoImport.{CrossType, crossProject} lazy val scala_2_11Version = "2.11.12" lazy val scala_2_12Version = "2.12.16" lazy val scala_2_13Version = "2.13.8" -lazy val scala_3Version = "3.2.1-RC1-bin-20220705-9bb3108-NIGHTLY" // Fix not yet available in RC or stable version: https://github.com/lampepfl/dotty/issues/12498 +lazy val scala_3Version = "3.2.1-RC1" lazy val scalaVersionsAll = Seq(scala_2_11Version, scala_2_12Version, scala_2_13Version, scala_3Version) lazy val theScalaVersion = scala_2_12Version @@ -629,6 +629,13 @@ lazy val compilerSettings = Seq( case Some((2, _)) => base ++ Seq("-Xlint") case _ => base } + }, + Test / scalacOptions ++= { + if (scalaBinaryVersion.value == "3") { + Seq("-Yretain-trees") + } else { + Seq.empty + } } ) diff --git a/enumeratum-core/src/main/scala-2/enumeratum/EnumCompat.scala b/enumeratum-core/src/main/scala-2/enumeratum/EnumCompat.scala index 1fd7952f..7e6aeb02 100644 --- a/enumeratum-core/src/main/scala-2/enumeratum/EnumCompat.scala +++ b/enumeratum-core/src/main/scala-2/enumeratum/EnumCompat.scala @@ -7,10 +7,18 @@ private[enumeratum] trait EnumCompat[A <: EnumEntry] { _: Enum[A] => /** Returns a Seq of [[A]] objects that the macro was able to find. * * You will want to use this in some way to implement your [[values]] method. In fact, if you - * aren't using this method...why are you even bothering with this lib? + * aren't using this method... why are you even bothering with this lib? */ protected def findValues: IndexedSeq[A] = macro EnumMacros.findValuesImpl[A] + + /** The sequence of values for your [[Enum]]. You will typically want to implement this in your + * extending class as a `val` so that `withName` and friends are as efficient as possible. + * + * Feel free to implement this however you'd like (including messing around with ordering, etc) + * if that fits your needs better. + */ + def values: IndexedSeq[A] } private[enumeratum] trait EnumCompanion { diff --git a/enumeratum-core/src/main/scala-2/enumeratum/values/ValueEnumCompat.scala b/enumeratum-core/src/main/scala-2/enumeratum/values/ValueEnumCompat.scala index 3e4f1f0a..29a4f49d 100644 --- a/enumeratum-core/src/main/scala-2/enumeratum/values/ValueEnumCompat.scala +++ b/enumeratum-core/src/main/scala-2/enumeratum/values/ValueEnumCompat.scala @@ -2,7 +2,7 @@ package enumeratum.values import scala.language.experimental.macros -import _root_.enumeratum.{Enum, EnumMacros, ValueEnumMacros} +import _root_.enumeratum.{EnumMacros, ValueEnumMacros} private[enumeratum] trait IntEnumCompanion { diff --git a/enumeratum-core/src/main/scala-3/enumeratum/EnumCompat.scala b/enumeratum-core/src/main/scala-3/enumeratum/EnumCompat.scala index 73aae7a1..ce6e7f04 100644 --- a/enumeratum-core/src/main/scala-3/enumeratum/EnumCompat.scala +++ b/enumeratum-core/src/main/scala-3/enumeratum/EnumCompat.scala @@ -1,7 +1,21 @@ package enumeratum private[enumeratum] trait EnumCompat[A <: EnumEntry] { _enum: Enum[A] => + + /** Returns a Seq of [[A]] objects that the macro was able to find. + * + * You will want to use this in some way to implement your [[values]] method. In fact, if you + * aren't using this method... why are you even bothering with this lib? + */ inline def findValues: IndexedSeq[A] = ${ EnumMacros.findValuesImpl[A] } + + /** The sequence of values for your [[Enum]]. You will typically want to implement this in your + * extending class as a `val` so that `withName` and friends are as efficient as possible. + * + * Feel free to implement this however you'd like (including messing around with ordering, etc) + * if that fits your needs better. + */ + def values: IndexedSeq[A] } private[enumeratum] trait EnumCompanion { diff --git a/enumeratum-core/src/main/scala/enumeratum/Enum.scala b/enumeratum-core/src/main/scala/enumeratum/Enum.scala index 15c99715..bdff0861 100644 --- a/enumeratum-core/src/main/scala/enumeratum/Enum.scala +++ b/enumeratum-core/src/main/scala/enumeratum/Enum.scala @@ -61,14 +61,6 @@ trait Enum[A <: EnumEntry] extends EnumCompat[A] { */ lazy final val valuesToIndex: Map[A, Int] = values.zipWithIndex.toMap - /** The sequence of values for your [[Enum]]. You will typically want to implement this in your - * extending class as a `val` so that `withName` and friends are as efficient as possible. - * - * Feel free to implement this however you'd like (including messing around with ordering, etc) - * if that fits your needs better. - */ - def values: IndexedSeq[A] - /** Tries to get an [[A]] by the supplied name. The name corresponds to the .name of the case * objects implementing [[A]] * diff --git a/enumeratum-core/src/main/scala/enumeratum/values/ValueEnum.scala b/enumeratum-core/src/main/scala/enumeratum/values/ValueEnum.scala index 210a5afd..0bd5ea68 100644 --- a/enumeratum-core/src/main/scala/enumeratum/values/ValueEnum.scala +++ b/enumeratum-core/src/main/scala/enumeratum/values/ValueEnum.scala @@ -1,7 +1,5 @@ package enumeratum.values -import _root_.enumeratum.Enum - /** Base trait for a Value-based enums. * * Example: diff --git a/enumeratum-core/src/test/scala/enumeratum/values/Compilation.scala b/enumeratum-core/src/test/scala/enumeratum/values/Compilation.scala index b3158a2c..0d33a437 100644 --- a/enumeratum-core/src/test/scala/enumeratum/values/Compilation.scala +++ b/enumeratum-core/src/test/scala/enumeratum/values/Compilation.scala @@ -1,9 +1,10 @@ +/* TODO package enumeratum.values /** Created by Lloyd on 1/4/17. - * - * Copyright 2017 - */ + * + * Copyright 2017 + */ // From https://github.com/lloydmeta/enumeratum/issues/96 sealed abstract class A private (val value: Int) extends IntEnumEntry { @@ -45,3 +46,4 @@ object B extends IntEnum[B] { def identity(str: String) = str } + */ diff --git a/enumeratum-core/src/test/scala/enumeratum/values/CustomEnumPrivateConstructor.scala b/enumeratum-core/src/test/scala/enumeratum/values/CustomEnumPrivateConstructor.scala index e1873c0f..f5c4183d 100644 --- a/enumeratum-core/src/test/scala/enumeratum/values/CustomEnumPrivateConstructor.scala +++ b/enumeratum-core/src/test/scala/enumeratum/values/CustomEnumPrivateConstructor.scala @@ -6,17 +6,21 @@ trait CustomEnumEntry extends IntEnumEntry { val value: Int val name: String } + trait CustomEnum[T <: CustomEnumEntry] extends IntEnum[T] { def apply(name: String): T = values.find(_.name == name).get } + trait CustomEnumComparable[T <: CustomEnumEntry] { this: T => def >=(that: T): Boolean = this.value >= that.value } + sealed abstract class CustomEnumPrivateConstructor private (val value: Int, val name: String) extends CustomEnumEntry with CustomEnumComparable[CustomEnumPrivateConstructor] + object CustomEnumPrivateConstructor extends CustomEnum[CustomEnumPrivateConstructor] { val values = findValues case object A extends CustomEnumPrivateConstructor(10, "a") diff --git a/enumeratum-core/src/test/scala/enumeratum/values/LibraryItem.scala b/enumeratum-core/src/test/scala/enumeratum/values/LibraryItem.scala index fa11a35f..20325875 100644 --- a/enumeratum-core/src/test/scala/enumeratum/values/LibraryItem.scala +++ b/enumeratum-core/src/test/scala/enumeratum/values/LibraryItem.scala @@ -18,7 +18,6 @@ case object LibraryItem extends IntEnum[LibraryItem] { case object CD extends LibraryItem(14, name = "cd") val values = findValues - } case object Newspaper extends LibraryItem(5, "Zeitung") diff --git a/enumeratum-core/src/test/scala/enumeratum/values/NumericPrecision.scala b/enumeratum-core/src/test/scala/enumeratum/values/NumericPrecision.scala new file mode 100644 index 00000000..b1e5d257 --- /dev/null +++ b/enumeratum-core/src/test/scala/enumeratum/values/NumericPrecision.scala @@ -0,0 +1,13 @@ +package enumeratum.values + +sealed abstract class NumericPrecision(val value: String) extends StringEnumEntry with AllowAlias + +object NumericPrecision extends StringEnum[NumericPrecision] { + case object Integer extends NumericPrecision("integer") + case object Int extends NumericPrecision("integer") + + case object Float extends NumericPrecision("float") + case object Double extends NumericPrecision("double") + + val values = findValues +} diff --git a/enumeratum-core/src/test/scala/enumeratum/values/ValueEnumSpec.scala b/enumeratum-core/src/test/scala/enumeratum/values/ValueEnumSpec.scala index 764e11c3..7f7ff968 100644 --- a/enumeratum-core/src/test/scala/enumeratum/values/ValueEnumSpec.scala +++ b/enumeratum-core/src/test/scala/enumeratum/values/ValueEnumSpec.scala @@ -10,27 +10,40 @@ import org.scalatest.matchers.should.Matchers class ValueEnumSpec extends AnyFunSpec with Matchers with ValueEnumHelpers { describe("basic sanity check") { - it("should have the proper values") { LibraryItem.withValue(1) shouldBe LibraryItem.Book LibraryItem.withValue(2) shouldBe LibraryItem.Movie LibraryItem.withValue(10) shouldBe LibraryItem.Magazine LibraryItem.withValue(14) shouldBe LibraryItem.CD } - } testNumericEnum("IntEnum", LibraryItem) testNumericEnum("ShortEnum", Drinks) testNumericEnum("LongEnum", ContentType) + testEnum("StringEnum", OperatingSystem, Seq("windows-phone")) testEnum("CharEnum", Alphabet, Seq('Z')) testEnum("ByteEnum", Bites, Seq(10).map(_.toByte)) - testNumericEnum("when using val members in the body", MovieGenre) + + testNumericEnum("When using val members in the body", MovieGenre) testNumericEnum("LongEnum that is nesting an IntEnum", Animal) testNumericEnum("IntEnum that is nested inside a LongEnum", Animal.Mammalian) testNumericEnum("Custom IntEnum with private constructors", CustomEnumPrivateConstructor) + describe("AllowAlias") { + it("should be supported") { + NumericPrecision.values should contain(NumericPrecision.Integer) + NumericPrecision.values should contain(NumericPrecision.Int) + NumericPrecision.values should contain(NumericPrecision.Float) + NumericPrecision.values should contain(NumericPrecision.Double) + + NumericPrecision.withValue("integer") shouldBe NumericPrecision.Int + NumericPrecision.withValue("float") shouldBe NumericPrecision.Float + NumericPrecision.withValue("double") shouldBe NumericPrecision.Double + } + } + describe("finding companion object") { it("should work for IntEnums") { @@ -83,11 +96,9 @@ class ValueEnumSpec extends AnyFunSpec with Matchers with ValueEnumHelpers { companion shouldBe Bites companion.values should contain(Bites.FourByte) } - } describe("compilation failures") { - describe("problematic values") { it("should fail to compile when values are repeated") { @@ -141,7 +152,6 @@ class ValueEnumSpec extends AnyFunSpec with Matchers with ValueEnumHelpers { } """ shouldNot compile } - } describe("trying to use with improper types") { @@ -198,7 +208,5 @@ class ValueEnumSpec extends AnyFunSpec with Matchers with ValueEnumHelpers { """ shouldNot compile } } - } - } diff --git a/macros/compat/src/main/scala-3/enumeratum/ContextUtils.scala b/macros/compat/src/main/scala-3/enumeratum/ContextUtils.scala deleted file mode 100644 index f03ec571..00000000 --- a/macros/compat/src/main/scala-3/enumeratum/ContextUtils.scala +++ /dev/null @@ -1,36 +0,0 @@ -package enumeratum - -object ContextUtils { - // Constant types - type CTLong = Long - type CTInt = Int - type CTChar = Char - - /* TODO: Remove; - - /** Returns a TermName */ - def termName(c: Context)(name: String): c.universe.TermName = { - c.universe.TermName(name) - } - - /** Returns a companion symbol. */ - def companion(c: Context)(sym: c.Symbol): c.universe.Symbol = sym.companion - - /** Returns a PartialFunction for turning symbols into names */ - def constructorsToParamNamesPF( - c: Context - ): PartialFunction[c.universe.Symbol, List[c.universe.Name]] = { - case m if m.isConstructor => - m.asMethod.paramLists.flatten.map(_.asTerm.name) - } - - /** Returns the reserved constructor name */ - def constructorName(c: Context): c.universe.TermName = { - c.universe.termNames.CONSTRUCTOR - } - - /** Returns a named arg extractor */ - def namedArg(c: Context) = c.universe.NamedArg - - */ -} diff --git a/macros/src/main/scala-2/enumeratum/ValueEnumMacros.scala b/macros/src/main/scala-2/enumeratum/ValueEnumMacros.scala index 346cbaa8..6eb13bef 100644 --- a/macros/src/main/scala-2/enumeratum/ValueEnumMacros.scala +++ b/macros/src/main/scala-2/enumeratum/ValueEnumMacros.scala @@ -113,6 +113,7 @@ object ValueEnumMacros { // Finish by building our Sequence val subclassSymbols = treeWithVals.map(_.tree.symbol) + EnumMacros.buildSeqExpr[ValueEntryType](c)(subclassSymbols) } diff --git a/macros/src/main/scala-3/enumeratum/EnumHelper.scala b/macros/src/main/scala-3/enumeratum/EnumHelper.scala new file mode 100644 index 00000000..ae004ea2 --- /dev/null +++ b/macros/src/main/scala-3/enumeratum/EnumHelper.scala @@ -0,0 +1,92 @@ +package test + +import scala.deriving.Mirror +import scala.quoted.{Expr, Quotes, Type} +import scala.reflect.Enum + +object EnumHelper: + + /** {{{ + * enum Color: + * case Red, Green, Blue + * + * import reactivemongo.api.bson.EnumHelper + * + * val valueOf: String => Option[Color] = EnumHelper.strictValueOf[Color] + * + * assert(valueOf("Red") contains Color.Red) + * assert(valueOf("red").isEmpty) + * }}} + */ + inline def strictValueOf[T](using + Mirror.SumOf[T] + ): String => Option[T] = ${ strictValueOfImpl[T] } + + private def strictValueOfImpl[T](using + Quotes, + Type[T] + ): Expr[String => Option[T]] = enumValueOfImpl(identity, None) + + private def enumValueOfImpl[T]( + labelNaming: String => String, + normalize: Option[Expr[String => String]] + )(using + q: Quotes, + tpe: Type[T] + ): Expr[String => Option[T]] = { + import q.reflect.* + + val tpr = TypeRepr.of(using tpe) + + val compSym = tpr.typeSymbol.companionModule + + if (compSym == Symbol.noSymbol) { + report.errorAndAbort(s"Unresolved type: ${tpr.typeSymbol.fullName}") + } + + val compRef = Ref(compSym) + + val cases = compSym.fieldMembers + .flatMap { fieldSym => + val fieldTerm = compRef.select(fieldSym) + + if (fieldTerm.tpe <:< tpr) { + Seq(fieldSym -> fieldTerm.asExprOf[T]) + } else { + Seq.empty[(Symbol, Expr[T])] + } + } + .zipWithIndex + .map { case ((sym, expr), i) => + val name = sym.name.toLowerCase + val body: Expr[Some[T]] = '{ Some(${ expr }) } + + CaseDef( + Literal(StringConstant(labelNaming(sym.name))), + guard = None, + rhs = body.asTerm + ) + } + + val none = CaseDef( + Wildcard(), + None, + '{ Option.empty[T] }.asTerm + ) + + def mtch(s: Expr[String]): Expr[Option[T]] = { + Match(s.asTerm, cases :+ none).asExprOf[Option[T]] + } + + normalize match { + case Some(nz) => + '{ (s: String) => + val in = ${ nz }(s) + ${ mtch('in) } + } + + case _ => + '{ (s: String) => ${ mtch('s) } } + } + } +end EnumHelper diff --git a/macros/src/main/scala-3/enumeratum/EnumMacros.scala b/macros/src/main/scala-3/enumeratum/EnumMacros.scala index 5e677a01..91568b1b 100644 --- a/macros/src/main/scala-3/enumeratum/EnumMacros.scala +++ b/macros/src/main/scala-3/enumeratum/EnumMacros.scala @@ -58,6 +58,9 @@ object EnumMacros: /** Makes sure that we can work with the given type as an enum: * * Aborts if the type is not sealed. + * + * @tparam T + * the `Enum` type */ private[enumeratum] def validateType[T](using q: Quotes, tpe: Type[T]): q.reflect.TypeRepr = { import q.reflect.* @@ -75,8 +78,8 @@ object EnumMacros: /** Returns a sequence of symbols for objects that implement the given type * - * @tparam the - * `Enum` type + * @tparam T + * the `Enum` type * @param tpr * the representation of type `T` (also specified by `tpe`) */ @@ -85,6 +88,7 @@ object EnumMacros: )(using tpe: Type[T]): List[q.reflect.TypeRepr] = { import q.reflect.* + // TODO: Use SumOf? given quotes: q.type = q @annotation.tailrec @@ -173,6 +177,7 @@ object EnumMacros: tpr.classSymbol .flatMap { cls => + // TODO: cls.typeMembers val types = subclasses(cls.children.map(_.tree), Nil) if (types.isEmpty) None else Some(types) diff --git a/macros/src/main/scala-3/enumeratum/Macros.scala b/macros/src/main/scala-3/enumeratum/Macros.scala new file mode 100644 index 00000000..44654b1b --- /dev/null +++ b/macros/src/main/scala-3/enumeratum/Macros.scala @@ -0,0 +1,31 @@ +package test + +object Macros { + import scala.quoted.{Expr, Quotes, Type} + + inline def show[A]: String = ${ showImpl[A] } + + private def showImpl[A](using tpe: Type[A], q: Quotes): Expr[String] = { + import q.reflect.* + + val repr = TypeRepr.of[A](using tpe) + + val tpeSym = repr.typeSymbol + /* + > _root_.test.Macros.show[_root_.test.Bar.type] + + val res0: String = @scala.annotation.internal.SourceFile("macros/src/main/scala-3/enumeratum/Foo.scala") object Bar extends test.Foo { this: test.Bar.type => + + } + */ + + // val tpeSym = repr.typeSymbol.companionModule + /* + > _root_.test.Macros.show[_root_.test.Bar.type] + + val res0: String = lazy val Bar: test.Bar.type + */ + + Expr(tpeSym.tree.show) + } +} diff --git a/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala b/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala index 422c757e..e6bc2639 100644 --- a/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala +++ b/macros/src/main/scala-3/enumeratum/ValueEnumMacros.scala @@ -2,32 +2,36 @@ package enumeratum import scala.reflect.ClassTag -import scala.quoted.{Expr, Quotes, Type} +import scala.deriving.Mirror +import scala.quoted.{Expr, FromExpr, Quotes, Type} import enumeratum.values.AllowAlias +/** @define valueEntryTypeNote + * Note, requires the ValueEntryType to have a 'value' member that has a literal value. + */ @SuppressWarnings(Array("org.wartremover.warts.StringPlusAny")) object ValueEnumMacros { - /** Finds ValueEntryType-typed objects in scope that have literal value:Int implementations + /** Finds ValueEntryType-typed objects in scope that have literal `value: Int` implementations. * - * Note, requires the ValueEntryType to have a 'value' member that has a literal value + * $valueEntryTypeNote */ def findIntValueEntriesImpl[ValueEntryType: Type](using Quotes ): Expr[IndexedSeq[ValueEntryType]] = - findValueEntriesImpl[ValueEntryType, ContextUtils.CTInt, Int](identity) + findValueEntriesImpl[ValueEntryType, Int] - /** Finds ValueEntryType-typed objects in scope that have literal value:Long implementations + /** Finds ValueEntryType-typed objects in scope that have literal `value: Long` implementations. * - * Note, requires the ValueEntryType to have a 'value' member that has a literal value + * $valueEntryTypeNote */ def findLongValueEntriesImpl[ValueEntryType: Type](using Quotes ): Expr[IndexedSeq[ValueEntryType]] = - findValueEntriesImpl[ValueEntryType, ContextUtils.CTLong, Long](identity) + findValueEntriesImpl[ValueEntryType, Long] - /** Finds ValueEntryType-typed objects in scope that have literal value:Short implementations + /** Finds ValueEntryType-typed objects in scope that have literal `value: Short` implementations. * * Note * @@ -37,11 +41,9 @@ object ValueEnumMacros { def findShortValueEntriesImpl[ValueEntryType: Type](using Quotes ): Expr[IndexedSeq[ValueEntryType]] = - findValueEntriesImpl[ValueEntryType, ContextUtils.CTInt, Short]( - _.toShort - ) // do a transform because there is no such thing as Short literals + findValueEntriesImpl[ValueEntryType, Short] - /** Finds ValueEntryType-typed objects in scope that have literal value:String implementations + /** Finds ValueEntryType-typed objects in scope that have literal `value: String` implementations. * * Note * @@ -50,9 +52,9 @@ object ValueEnumMacros { def findStringValueEntriesImpl[ValueEntryType: Type](using Quotes ): Expr[IndexedSeq[ValueEntryType]] = - findValueEntriesImpl[ValueEntryType, String, String](identity) + findValueEntriesImpl[ValueEntryType, String] - /** Finds ValueEntryType-typed objects in scope that have literal value:Byte implementations + /** Finds ValueEntryType-typed objects in scope that have literal `value: Byte` implementations. * * Note * @@ -61,9 +63,9 @@ object ValueEnumMacros { def findByteValueEntriesImpl[ValueEntryType: Type](using Quotes ): Expr[IndexedSeq[ValueEntryType]] = - findValueEntriesImpl[ValueEntryType, ContextUtils.CTInt, Byte](_.toByte) + findValueEntriesImpl[ValueEntryType, Byte] - /** Finds ValueEntryType-typed objects in scope that have literal value:Char implementations + /** Finds ValueEntryType-typed objects in scope that have literal `value: Char` implementations. * * Note * @@ -72,220 +74,192 @@ object ValueEnumMacros { def findCharValueEntriesImpl[ValueEntryType: Type](using Quotes ): Expr[IndexedSeq[ValueEntryType]] = - findValueEntriesImpl[ValueEntryType, ContextUtils.CTChar, Char](identity) + findValueEntriesImpl[ValueEntryType, Char] + + private given ValueOfFromExpr[T <: Singleton](using Type[T]): FromExpr[ValueOf[T]] with { + def unapply(x: Expr[ValueOf[T]])(using q: Quotes): Option[ValueOf[T]] = { + import q.reflect.* + + x match { + case '{ new ValueOf[T]($v) } => + v.asTerm match { + case id: Ident => { + val cls = Class.forName(id.symbol.fullName + '$') + val moduleField = cls.getFields.find(_.getName == f"MODULE$$") + + moduleField.map { field => + new ValueOf(field.get(null).asInstanceOf[T]) + } + } + + case _ => + None + } + + case _ => + None + } + } + } /** The method that does the heavy lifting. */ - private[this] def findValueEntriesImpl[ - ValueEntryType, - ValueType: ClassTag, - ProcessedValue - ]( - processFoundValues: ValueType => ProcessedValue - )(using tpe: Type[ValueEntryType], q: Quotes): Expr[IndexedSeq[ValueEntryType]] = { - import q.reflect.* + private def findValueEntriesImpl[A, ValueType](using + q: Quotes, + tpe: Type[A], + valueTpe: Type[ValueType] + )(using cls: ClassTag[ValueType]): Expr[IndexedSeq[A]] = { + type TakeHead[Head <: A & Singleton, Tail <: Tuple] = Head *: Tail - val repr = TypeRepr.of[ValueEntryType](using tpe) - val typeSymbol = repr.typeSymbol - - /* - EnumMacros.validateType(c)(typeSymbol) - // Find the trees in the enclosing object that match the given ValueEntryType - val subclassTrees = EnumMacros.enclosedSubClassTrees(c)(typeSymbol) - // Find the parameters for the constructors of ValueEntryType - val valueEntryTypeConstructorsParams = - findConstructorParamsLists[ValueEntryType](c) - // Identify the value:ValueType implementations for each of the trees we found and process them if required - val treeWithVals = findValuesForSubclassTrees[ValueType, ProcessedValue](c)( - valueEntryTypeConstructorsParams, - subclassTrees, - processFoundValues - ) - - if (weakTypeOf[ValueEntryType] <:< c.typeOf[AllowAlias]) { - // Skip the uniqueness check - } else { - // Make sure the processed found value implementations are unique - ensureUnique[ProcessedValue](c)(treeWithVals) + type SumOf[X <: A, T <: Tuple] = Mirror.SumOf[X] { + type MirroredElemTypes = T } - // Finish by building our Sequence - val subclassSymbols = treeWithVals.map(_.tree.symbol) - EnumMacros.buildSeqExpr[ValueEntryType](c)(subclassSymbols) - */ + import q.reflect.* - '{ - IndexedSeq.empty[ValueEntryType] // TODO - } - } + val ctx = q.asInstanceOf[scala.quoted.runtime.impl.QuotesImpl].ctx + + val yRetainTrees = ctx.settings.YretainTrees.valueIn(ctx.settingsState) + + if (!yRetainTrees) { + report.errorAndAbort("""Option -Yretain-trees must be set in scalacOptions. +In SBT settings: - /* - /** Returns a list of TreeWithVal (tree with value of type ProcessedValueType) for the given trees - * and transformation - * - * Will abort compilation if not all the trees provided have a literal value member/constructor - * argument - */ - private[this] def findValuesForSubclassTrees[ValueType: ClassTag, ProcessedValueType](using Quotes)( - valueEntryCTorsParams: List[List[c.universe.Name]], - memberTrees: Seq[c.universe.ModuleDef], - processFoundValues: ValueType => ProcessedValueType - ): Seq[TreeWithVal[c.universe.ModuleDef, ProcessedValueType]] = { - val treeWithValues = toTreeWithMaybeVals[ValueType, ProcessedValueType](c)( - valueEntryCTorsParams, - memberTrees, - processFoundValues - ) - val (hasValueMember, lacksValueMember) = - treeWithValues.partition(_.maybeValue.isDefined) - if (lacksValueMember.nonEmpty) { - val classTag = implicitly[ClassTag[ValueType]] - val lacksValueMemberStr = - lacksValueMember.map(_.tree.symbol).mkString(", ") - c.abort( - c.enclosingPosition, - s""" - |It looks like not all of the members have a literal/constant 'value:${classTag.runtimeClass.getSimpleName}' declaration, namely: $lacksValueMemberStr. - | - |This can happen if: - | - |- The aforementioned members have their `value` supplied by a variable, or otherwise defined as a method - | - |If none of the above apply to your case, it's likely you have discovered an issue with Enumeratum, so please file an issue :) - """.stripMargin - ) + scalacOptions += "-Yretain-trees" +""") } - hasValueMember.collect { case TreeWithMaybeVal(tree, Some(v)) => - TreeWithVal(tree, v) + + val repr = TypeRepr.of[A](using tpe) + val tpeSym = repr.typeSymbol + + val valueRepr = TypeRepr.of[ValueType] + + val ctorParams = tpeSym.primaryConstructor.paramSymss.flatten + + val enumFields = repr.typeSymbol.fieldMembers.flatMap { field => + ctorParams.zipWithIndex.find { case (p, i) => + p.name == field.name && (p.tree match { + case term: Term => + term.tpe <:< valueRepr + + case _ => + false + }) + } + }.toSeq + + val (valueField, valueParamIndex): (Symbol, Int) = { + if (enumFields.size == 1) { + enumFields.headOption + } else { + enumFields.find(_._1.name == "value") + } + }.getOrElse { + Symbol.newVal(tpeSym, "value", valueRepr, Flags.Abstract, Symbol.noSymbol) -> 0 } - } - /** Looks through the given trees and tries to find the proper value declaration/constructor - * argument. - * - * Aborts compilation if the value declaration/constructor is of the wrong type, - */ - private[this] def toTreeWithMaybeVals[ValueType: ClassTag, ProcessedValueType](c: Context)( - valueEntryCTorsParams: List[List[c.universe.Name]], - memberTrees: Seq[c.universe.ModuleDef], - processFoundValues: ValueType => ProcessedValueType - ): Seq[TreeWithMaybeVal[c.universe.ModuleDef, ProcessedValueType]] = { - import c.universe._ - val classTag = implicitly[ClassTag[ValueType]] - val valueTerm = ContextUtils.termName(c)("value") - // go through all the trees - memberTrees.map { declTree => - val directMemberTrees = - declTree.children.flatMap(_.children) // Things that are body-level, no lower - val constructorTrees = { - val immediate = directMemberTrees // for 2.11+ this is enough - val constructorName = ContextUtils.constructorName(c) - // Sadly 2.10 has parent-class constructor calls nested inside a member.. - val method = - directMemberTrees.collect { // for 2.10.x, we need to grab the body-level constructor method's trees - case t @ DefDef(_, `constructorName`, _, _, _, _) => - t.collect { case t => t } - }.flatten - immediate ++ method - }.iterator - - val valuesFromMembers = directMemberTrees.iterator.collect { - case ValDef(_, termName, _, Literal(Constant(i: ValueType))) if termName == valueTerm => - Some(i) + type IsValue[T <: ValueType] = T + + object ConstVal { + @annotation.tailrec + def unapply(tree: Tree): Option[Constant] = tree match { + case NamedArg(nme, v) if (nme == valueField.name) => + unapply(v) + + case ValDef(nme, _, Some(v)) if (nme == valueField.name) => + unapply(v) + + case lit @ Literal(const) if (lit.tpe <:< valueRepr) => + Some(const) + + case _ => + None } + } + + @annotation.tailrec + def collect[T <: Tuple]( + instances: List[Expr[A]], + values: Map[TypeRepr, ValueType] + )(using tupleTpe: Type[T]): Either[String, Expr[List[A]]] = + tupleTpe match { + case '[TakeHead[h, tail]] => { + val htpr = TypeRepr.of[h] + + (for { + vof <- Expr.summon[ValueOf[h]] + constValue <- htpr.typeSymbol.tree match { + case ClassDef(_, _, spr, _, rhs) => { + val fromCtor = spr + .collectFirst { + case Apply(Select(New(id), _), args) if id.tpe <:< repr => + args + } + .flatMap(_.lift(valueParamIndex).collect { case ConstVal(const) => + const + }) + + fromCtor + .orElse(rhs.collectFirst { case ConstVal(v) => v }) + .flatMap { const => + cls.unapply(const.value) + } - val NamedArg = ContextUtils.namedArg(c) - - // Sadly 2.10 has parent-class constructor calls nested inside a member.. - val valuesFromConstructors = constructorTrees.collect { - // The tree has a method call - case Apply(_, args) => { - val valueArguments: List[Option[ValueType]] = - valueEntryCTorsParams.collect { - // Find non-empty constructor param lists - case paramTermNames if paramTermNames.nonEmpty => { - val paramsWithArg = paramTermNames.zip(args) - paramsWithArg.collectFirst { - // found a (paramName, argument) parameter-argument pair where paramName is "value", and argument is a constant with the right type - case (`valueTerm`, Literal(Constant(i: ValueType))) => i - // found a (paramName, argument) parameter-argument pair where paramName is "value", and argument is a constant with the wrong type - case (`valueTerm`, Literal(Constant(i))) => - c.abort( - c.enclosingPosition, - s"${declTree.symbol} has a value with the wrong type: $i:${i.getClass}, instead of ${classTag.runtimeClass}." - ) - /* - * found a (_, NamedArgument(argName, argument)) parameter-named pair where the argument is named "value" and the argument itself is of the right type - */ - case (_, NamedArg(Ident(`valueTerm`), Literal(Constant(i: ValueType)))) => - i - /* - * found a (_, NamedArgument(argName, argument)) parameter-named pair where the argument is named "value" and the argument itself is of the wrong type - */ - case (_, NamedArg(Ident(`valueTerm`), Literal(Constant(i)))) => - c.abort( - c.enclosingPosition, - s"${declTree.symbol} has a value with the wrong type: $i:${i.getClass}, instead of ${classTag.runtimeClass}" - ) - } } + + case _ => + Option.empty[ValueType] } - // We only want the first such constructor argument - valueArguments.collectFirst { case Some(v) => v } + } yield Tuple3(TypeRepr.of[h], '{ ${ vof }.value: A }, constValue)) match { + case Some((tpr, instance, value)) => + collect[tail](instance :: instances, values + (tpr -> value)) + + case None => + report.errorAndAbort( + s"Fails to check value entry ${htpr.show} for enum ${repr.show}" + ) + } } - } - val values = valuesFromMembers ++ valuesFromConstructors - val processedValue = values.collectFirst { case Some(v) => - processFoundValues(v) + case '[EmptyTuple] => { + val allowAlias = repr <:< TypeRepr.of[AllowAlias] + + if (!allowAlias && values.values.toSet.size < values.size) { + val details = values + .map { case (sub, value) => + s"${sub.show} = $value" + } + .mkString(", ") + + Left(s"Values for ${valueField.name} are not discriminated subtypes: ${details}") + } else { + Right(Expr ofList instances.reverse) + } + } } - TreeWithMaybeVal(declTree, processedValue) - } - } - /** Given a type, finds the constructor params lists for it - */ - private[this] def findConstructorParamsLists[ValueEntryType: Type](using Quotes): List[List[c.universe.Name]] = { - val valueEntryTypeTpe = implicitly[Type[ValueEntryType]].tpe - val valueEntryTypeTpeMembers = valueEntryTypeTpe.members - valueEntryTypeTpeMembers.collect(ContextUtils.constructorsToParamNamesPF(c)).toList - } + val result: Either[String, Expr[List[A]]] = + Expr.summon[Mirror.SumOf[A]] match { + case Some(sum) => + sum.asTerm.tpe.asType match { + case '[SumOf[A, t]] => + collect[t](List.empty, Map.empty) - /** Ensures that we have unique values for trees, aborting otherwise with a message indicating - * which trees have the same symbol - */ - private[this] def ensureUnique[A](c: Context)( - treeWithVals: Seq[TreeWithVal[c.universe.ModuleDef, A]] - ): Unit = { - val membersWithValues = treeWithVals.map { treeWithVal => - treeWithVal.tree.symbol -> treeWithVal.value - } - val groupedByValue = membersWithValues.groupBy(_._2).map { case (k, v) => - (k, v.map(_._1)) - } - val (valuesWithOneSymbol, valuesWithMoreThanOneSymbol) = - groupedByValue.partition(_._2.size <= 1) - if (valuesWithOneSymbol.size != membersWithValues.toMap.keys.size) { - val formattedString = valuesWithMoreThanOneSymbol.toSeq.reverse.foldLeft("") { - case (acc, (k, v)) => - acc ++ s"""$k has members [ ${v.mkString(", ")} ]\n """ + case _ => + Left(s"Invalid `Mirror.SumOf[${repr.show}]`") + + } + + case None => + Left(s"Missing `Mirror.SumOf[${repr.show}]`") } - c.abort( - c.enclosingPosition, - s""" - | - | It does not look like you have unique values in your ValueEnum. - | Each of the following values correspond to more than one member: - | - | $formattedString - | Please check to make sure members have unique values. - | """.stripMargin - ) + + result match { + case Left(errorMsg) => + report.errorAndAbort(errorMsg) + + case Right(instances) => + '{ IndexedSeq.empty ++ $instances } } } - - // Helper case classes - private[this] case class TreeWithMaybeVal[CTree, T](tree: CTree, maybeValue: Option[T]) - private[this] case class TreeWithVal[CTree, T](tree: CTree, value: T) - */ }