diff --git a/shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala b/shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala index e7514c28..95db97a5 100644 --- a/shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala +++ b/shared/src/main/scala-3/org/scalamock/clazz/MockMaker.scala @@ -21,7 +21,6 @@ package org.scalamock.clazz import org.scalamock.context.MockContext - import scala.quoted.* import scala.reflect.Selectable @@ -42,8 +41,13 @@ private[clazz] object MockMaker: def asParent(tree: TypeTree): TypeTree | Term = val constructorFieldsFilledWithNulls: List[List[Term]] = tree.tpe.dealias.typeSymbol.primaryConstructor.paramSymss - .filter(_.exists(!_.isType)) - .map(_.map(_.typeRef.asType match { case '[t] => '{ null.asInstanceOf[t] }.asTerm })) + .filterNot(_.exists(_.isType)) + .map(_.map(_.info.widen match { + case t@AppliedType(inner, applied) => + Select.unique('{null}.asTerm, "asInstanceOf").appliedToTypes(List(inner.appliedTo(tpe.typeArgs))) + case other => + Select.unique('{null}.asTerm, "asInstanceOf").appliedToTypes(List(other)) + })) if constructorFieldsFilledWithNulls.forall(_.isEmpty) then tree @@ -51,7 +55,9 @@ private[clazz] object MockMaker: Select( New(TypeIdent(tree.tpe.typeSymbol)), tree.tpe.typeSymbol.primaryConstructor - ).appliedToArgss(constructorFieldsFilledWithNulls) + ).appliedToTypes(tree.tpe.typeArgs) + .appliedToArgss(constructorFieldsFilledWithNulls) + val parents = @@ -91,7 +97,7 @@ private[clazz] object MockMaker: Symbol.newVal( parent = classSymbol, name = definition.symbol.name, - tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol), + tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol), flags = Flags.Override, privateWithin = Symbol.noSymbol ) @@ -99,7 +105,7 @@ private[clazz] object MockMaker: Symbol.newMethod( parent = classSymbol, name = definition.symbol.name, - tpe = definition.tpeWithSubstitutedPathDependentFor(classSymbol), + tpe = definition.tpeWithSubstitutedInnerTypesFor(classSymbol), flags = Flags.Override, privateWithin = Symbol.noSymbol ) @@ -177,7 +183,7 @@ private[clazz] object MockMaker: "asInstanceOf" ), definition.tpe - .resolveParamRefs(definition.resTypeWithPathDependentOverrideFor(classSymbol), args) + .resolveParamRefs(definition.resTypeWithInnerTypesOverrideFor(classSymbol), args) .asType match { case '[t] => List(TypeTree.of[t]) } ) ) diff --git a/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala b/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala index a966a499..60bfb4dd 100644 --- a/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala +++ b/shared/src/main/scala-3/org/scalamock/clazz/Utils.scala @@ -3,22 +3,22 @@ package org.scalamock.clazz import scala.quoted.* import org.scalamock.context.MockContext -import scala.annotation.tailrec +import scala.annotation.{experimental, tailrec} private[clazz] class Utils(using val quotes: Quotes): import quotes.reflect.* extension (tpe: TypeRepr) - def collectPathDependent(ownerSymbol: Symbol): List[TypeRepr] = + def collectInnerTypes(ownerSymbol: Symbol): List[TypeRepr] = def loop(currentTpe: TypeRepr, names: List[String]): List[TypeRepr] = currentTpe match - case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectPathDependent(ownerSymbol)) + case AppliedType(inner, appliedTypes) => loop(inner, names) ++ appliedTypes.flatMap(_.collectInnerTypes(ownerSymbol)) case TypeRef(inner, name) if name == ownerSymbol.name && names.nonEmpty => List(tpe) case TypeRef(inner, name) => loop(inner, name :: names) case _ => Nil loop(tpe, Nil) - def pathDependentOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr = + def innerTypeOverride(ownerSymbol: Symbol, newOwnerSymbol: Symbol, applyTypes: Boolean): TypeRepr = @tailrec def loop(currentTpe: TypeRepr, names: List[(String, List[TypeRepr])], appliedTypes: List[TypeRepr]): TypeRepr = currentTpe match @@ -53,55 +53,80 @@ private[clazz] class Utils(using val quotes: Quotes): case _ => tpe + @experimental def resolveParamRefs(resType: TypeRepr, methodArgs: List[List[Tree]]) = - def loop(baseBindings: TypeRepr, typeRepr: TypeRepr): TypeRepr = - typeRepr match - case pr@ParamRef(bindings, idx) if bindings == baseBindings => - methodArgs.head(idx).asInstanceOf[TypeTree].tpe + tpe match + case baseBindings: PolyType => + def loop(typeRepr: TypeRepr): TypeRepr = + typeRepr match + case pr@ParamRef(bindings, idx) if bindings == baseBindings => + methodArgs.head(idx).asInstanceOf[TypeTree].tpe - case AppliedType(tycon, args) => - AppliedType(tycon, args.map(arg => loop(baseBindings, arg))) + case AppliedType(tycon, args) => + AppliedType(loop(tycon), args.map(arg => loop(arg))) - case other => other + case ff @ TypeRef(ref @ ParamRef(bindings, idx), name) => + def getIndex(bindings: TypeRepr): Int = + @tailrec + def loop(bindings: TypeRepr, idx: Int): Int = + bindings match + case MethodType(_, _, method: MethodType) => loop(method, idx + 1) + case _ => idx - tpe match - case pt: PolyType => loop(pt, resType) - case _ => resType + loop(bindings, 1) + + val maxIndex = methodArgs.length + val parameterListIdx = maxIndex - getIndex(bindings) + + TypeSelect(methodArgs(parameterListIdx)(idx).asInstanceOf[Term], name).tpe + + case other => other + + loop(resType) + case _ => + resType - def collectTypes: List[TypeRepr] = - def loop(currentTpe: TypeRepr, params: List[TypeRepr]): List[TypeRepr] = + def collectTypes: (List[TypeRepr], TypeRepr) = + @tailrec + def loop(currentTpe: TypeRepr, argTypesAcc: List[List[TypeRepr]], resType: TypeRepr): (List[TypeRepr], TypeRepr) = currentTpe match - case PolyType(_, _, res) => loop(res, Nil) - case MethodType(_, argTypes, res) => argTypes ++ loop(res, params) - case other => List(other) - loop(tpe, Nil) + case PolyType(_, _, res) => loop(res, List.empty[TypeRepr] :: argTypesAcc, resType) + case MethodType(_, argTypes, res) => loop(res, argTypes :: argTypesAcc, resType) + case other => (argTypesAcc.reverse.flatten, other) + loop(tpe, Nil, TypeRepr.of[Nothing]) case class MockableDefinition(idx: Int, symbol: Symbol, ownerTpe: TypeRepr): val mockValName = s"mock$$${symbol.name}$$$idx" val tpe = ownerTpe.memberType(symbol) - private val rawTypes = tpe.widen.collectTypes + private val (rawTypes, rawResType) = tpe.widen.collectTypes val parameterTypes = prepareTypesFor(ownerTpe.typeSymbol).map(_.tpe).init - def resTypeWithPathDependentOverrideFor(classSymbol: Symbol): TypeRepr = - val pd = rawTypes.last.collectPathDependent(ownerTpe.typeSymbol) - val pdUpdated = pd.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false)) - rawTypes.last.substituteTypes(pd.map(_.typeSymbol), pdUpdated) + def resTypeWithInnerTypesOverrideFor(classSymbol: Symbol): TypeRepr = + updatePathDependent(rawResType, List(rawResType), classSymbol) + + def tpeWithSubstitutedInnerTypesFor(classSymbol: Symbol): TypeRepr = + updatePathDependent(tpe, rawResType :: rawTypes, classSymbol) - def tpeWithSubstitutedPathDependentFor(classSymbol: Symbol): TypeRepr = - val pathDependentTypes = rawTypes.flatMap(_.collectPathDependent(ownerTpe.typeSymbol)) - val pdUpdated = pathDependentTypes.map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false)) - tpe.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated) + private def updatePathDependent(where: TypeRepr, types: List[TypeRepr], classSymbol: Symbol): TypeRepr = + val pathDependentTypes = types.flatMap(_.collectInnerTypes(ownerTpe.typeSymbol)) + val pdUpdated = pathDependentTypes.map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = false)) + where.substituteTypes(pathDependentTypes.map(_.typeSymbol), pdUpdated) - def prepareTypesFor(classSymbol: Symbol) = rawTypes - .map(_.pathDependentOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true)) + def prepareTypesFor(classSymbol: Symbol) = (rawTypes :+ rawResType) + .map(_.innerTypeOverride(ownerTpe.typeSymbol, classSymbol, applyTypes = true)) .map { typeRepr => val adjusted = typeRepr.widen.mapParamRefWithWildcard match case TypeBounds(lower, upper) => upper case AppliedType(TypeRef(_, ""), elemTyps) => TypeRepr.typeConstructorOf(classOf[Seq[_]]).appliedTo(elemTyps) - case other => other + case TypeRef(_: ParamRef, _) => + TypeRepr.of[Any] + case AppliedType(TypeRef(_: ParamRef, _), _) => + TypeRepr.of[Any] + case other => + other adjusted.asType match case '[t] => TypeTree.of[t] } @@ -128,10 +153,11 @@ private[clazz] class Utils(using val quotes: Quotes): def apply(tpe: TypeRepr): List[MockableDefinition] = val methods = (tpe.typeSymbol.methodMembers.toSet -- TypeRepr.of[Object].typeSymbol.methodMembers).toList - .filter(sym => !sym.flags.is(Flags.Private) && !sym.flags.is(Flags.Final) && !sym.flags.is(Flags.Mutable)) - .filterNot(sym => tpe.memberType(sym) match - case defaultParam @ ByNameType(AnnotatedType(_, Apply(Select(New(Inferred()), ""), Nil))) => true - case _ => false + .filter(sym => + !sym.flags.is(Flags.Private) && + !sym.flags.is(Flags.Final) && + !sym.flags.is(Flags.Mutable) && + !sym.name.contains("$default$") ) .zipWithIndex .map((sym, idx) => MockableDefinition(idx, sym, tpe)) diff --git a/shared/src/test/scala-3/com/paulbutcher/test/ClassWithContextBoundSpec.scala b/shared/src/test/scala-3/com/paulbutcher/test/ClassWithContextBoundSpec.scala new file mode 100644 index 00000000..67066c60 --- /dev/null +++ b/shared/src/test/scala-3/com/paulbutcher/test/ClassWithContextBoundSpec.scala @@ -0,0 +1,37 @@ +package com.paulbutcher.test + +import org.scalamock.scalatest.MockFactory +import org.scalatest.funspec.AnyFunSpec + +import scala.reflect.ClassTag + +class ClassWithContextBoundSpec extends AnyFunSpec with MockFactory { + + it("compile without args") { + class ContextBounded[T: ClassTag] { + def method(x: Int): Unit = () + } + + val m = mock[ContextBounded[String]] + + } + + it("compile with args") { + class ContextBounded[T: ClassTag](x: Int) { + def method(x: Int): Unit = () + } + + val m = mock[ContextBounded[String]] + + } + + it("compile with provided explicitly type class") { + class ContextBounded[T](x: ClassTag[T]) { + def method(x: Int): Unit = () + } + + val m = mock[ContextBounded[String]] + + } + +} diff --git a/shared/src/test/scala-3/com/paulbutcher/test/PathDependentParamSpec.scala b/shared/src/test/scala-3/com/paulbutcher/test/PathDependentParamSpec.scala new file mode 100644 index 00000000..3524ef43 --- /dev/null +++ b/shared/src/test/scala-3/com/paulbutcher/test/PathDependentParamSpec.scala @@ -0,0 +1,107 @@ +package com.paulbutcher.test + +import org.scalamock.matchers.Matchers +import org.scalamock.scalatest.MockFactory +import org.scalatest.funspec.AnyFunSpec + +class PathDependentParamSpec extends AnyFunSpec with Matchers with MockFactory { + + trait Command { + type Answer + type AnswerConstructor[A] + } + + case class IntCommand() extends Command { + override type Answer = Int + override type AnswerConstructor[A] = Option[A] + } + + val cmd = IntCommand() + + trait PathDependent { + + def call0[T <: Command](cmd: T): cmd.Answer + + def call1[T <: Command](x: Int)(cmd: T): cmd.Answer + + def call2[T <: Command](y: String)(cmd: T)(x: Int): cmd.Answer + + def call3[T <: Command](cmd: T)(y: String)(x: Int): cmd.Answer + + def call4[T <: Command](cmd: T): Option[cmd.Answer] + + def call5[T <: Command](cmd: T)(x: cmd.Answer): Unit + + def call6[T <: Command](cmd: T): cmd.AnswerConstructor[Int] + + def call7[T <: Command](cmd: T)(x: cmd.AnswerConstructor[String])(y: cmd.Answer): Unit + } + + + it("path dependent in return type") { + val pathDependent = mock[PathDependent] + + (pathDependent.call0[IntCommand] _).expects(cmd).returns(5) + + assert(pathDependent.call0(cmd) == 5) + } + + it("path dependent in return type and parameter in last parameter list") { + val pathDependent = mock[PathDependent] + + (pathDependent.call1(_: Int)(_: IntCommand)).expects(5, cmd).returns(5) + + assert(pathDependent.call1(5)(cmd) == 5) + } + + it("path dependent in return type and parameter in middle parameter list ") { + val pathDependent = mock[PathDependent] + + (pathDependent.call2(_: String)(_: IntCommand)(_: Int)).expects("5", cmd, 5).returns(5) + + assert(pathDependent.call2("5")(cmd)(5) == 5) + } + + it("path dependent in return type and parameter in first parameter list ") { + val pathDependent = mock[PathDependent] + + (pathDependent.call3(_: IntCommand)(_: String)(_: Int)).expects(cmd, "5", 5).returns(5) + + assert(pathDependent.call3(cmd)("5")(5) == 5) + } + + it("path dependent in tycon return type") { + val pathDependent = mock[PathDependent] + + (pathDependent.call4[IntCommand] _).expects(cmd).returns(Some(5)) + + assert(pathDependent.call4(cmd) == Some(5)) + } + + it("path dependent in parameter list") { + val pathDependent = mock[PathDependent] + + (pathDependent.call5(_: IntCommand)(_: Int)).expects(cmd, 5).returns(()) + + assert(pathDependent.call5(cmd)(5) == ()) + } + + it("path dependent tycon in return type") { + val pathDependent = mock[PathDependent] + + (pathDependent.call6[IntCommand] _).expects(cmd).returns(Some(5)) + + assert(pathDependent.call6(cmd) == Some(5)) + } + + it("path dependent tycon in parameter list") { + val pathDependent = mock[PathDependent] + + (pathDependent.call7[IntCommand](_: IntCommand)(_: Option[String])(_: Int)) + .expects(cmd, Some("5"), 6) + .returns(()) + + assert(pathDependent.call7(cmd)(Some("5"))(6) == ()) + } + +} diff --git a/shared/src/test/scala/com/paulbutcher/test/mock/MethodsWithDefaultParamsTest.scala b/shared/src/test/scala/com/paulbutcher/test/mock/MethodsWithDefaultParamsTest.scala index 67d6e479..b9693513 100644 --- a/shared/src/test/scala/com/paulbutcher/test/mock/MethodsWithDefaultParamsTest.scala +++ b/shared/src/test/scala/com/paulbutcher/test/mock/MethodsWithDefaultParamsTest.scala @@ -34,6 +34,8 @@ class MethodsWithDefaultParamsTest extends IsolatedSpec { trait TraitHavingMethodsWithDefaultParams { def withAllDefaultParams(a: String = "default", b: CaseClass = CaseClass(42)): String + + def withDefaultParamAndTypeParam[T](a: String = "default", b: Int = 5): T } behavior of "Mocks" @@ -84,5 +86,13 @@ class MethodsWithDefaultParamsTest extends IsolatedSpec { m.withAllDefaultParams("other", CaseClass(99)) } + they should "mock trait methods with type param and default parameters" in { + val m = mock[TraitHavingMethodsWithDefaultParams] + + (m.withDefaultParamAndTypeParam[Int] _).expects("default", 5).returns(5) + + m.withDefaultParamAndTypeParam[Int]("default", 5) shouldBe 5 + } + override def newInstance = new MethodsWithDefaultParamsTest } diff --git a/shared/src/test/scala/org/scalamock/test/scalatest/AsyncSyncMixinTest.scala b/shared/src/test/scala/org/scalamock/test/scalatest/AsyncSyncMixinTest.scala index 87d8030a..d9c3ae02 100644 --- a/shared/src/test/scala/org/scalamock/test/scalatest/AsyncSyncMixinTest.scala +++ b/shared/src/test/scala/org/scalamock/test/scalatest/AsyncSyncMixinTest.scala @@ -20,13 +20,12 @@ package org.scalamock.test.scalatest -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest._ +import org.scalatest.flatspec.{AnyFlatSpec, AsyncFlatSpec} +import org.scalamock.scalatest.{MockFactory, AsyncMockFactory} /** * Tests for issue #371 */ -@Ignore class AsyncSyncMixinTest extends AnyFlatSpec { "MockFactory" should "be mixed only with Any*Spec and not Async*Spec traits" in {