diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 4bc427ee0687..34ddaacc5378 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -787,6 +787,7 @@ class Definitions { @tu lazy val MirrorClass: ClassSymbol = requiredClass("scala.deriving.Mirror") @tu lazy val Mirror_ProductClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Product") @tu lazy val Mirror_Product_fromProduct: Symbol = Mirror_ProductClass.requiredMethod(nme.fromProduct) + @tu lazy val Mirror_Product_defaultArgument: Symbol = Mirror_ProductClass.requiredMethod(nme.defaultArgument) @tu lazy val Mirror_SumClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Sum") @tu lazy val Mirror_SingletonClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Singleton") @tu lazy val Mirror_SingletonProxyClass: ClassSymbol = requiredClass("scala.deriving.Mirror.SingletonProxy") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 253a45ffd7a8..14ad70d74f5f 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -368,6 +368,7 @@ object StdNames { val LiteralAnnotArg: N = "LiteralAnnotArg" val Matchable: N = "Matchable" val MatchCase: N = "MatchCase" + val MirroredElemHasDefaults: N = "MirroredElemHasDefaults" val MirroredElemTypes: N = "MirroredElemTypes" val MirroredElemLabels: N = "MirroredElemLabels" val MirroredLabel: N = "MirroredLabel" @@ -452,6 +453,7 @@ object StdNames { val create: N = "create" val currentMirror: N = "currentMirror" val curried: N = "curried" + val defaultArgument: N = "defaultArgument" val definitions: N = "definitions" val delayedInit: N = "delayedInit" val delayedInitArg: N = "delayedInit$body" diff --git a/compiler/src/dotty/tools/dotc/core/SymUtils.scala b/compiler/src/dotty/tools/dotc/core/SymUtils.scala index 65634241b790..cdf132e32590 100644 --- a/compiler/src/dotty/tools/dotc/core/SymUtils.scala +++ b/compiler/src/dotty/tools/dotc/core/SymUtils.scala @@ -17,6 +17,7 @@ import Annotations.Annotation import Phases.* import ast.tpd.Literal import transform.Mixin +import dotty.tools.tasty.TastyVersion import dotty.tools.dotc.transform.sjs.JSSymUtils.sjsNeedsField @@ -115,6 +116,13 @@ class SymUtils: def isGenericProduct(using Context): Boolean = whyNotGenericProduct.isEmpty + /** Is a case class for which mirrors support access to default arguments. + * see sbt-test/scala3-compat/defaultArgument-mirrors-3.3 for why this is needed + */ + def mirrorSupportsDefaultArguments(using Context): Boolean = + !(self.is(JavaDefined) || self.is(Scala2x)) && self.isClass && self.tastyInfo.forall: + case TastyInfo(TastyVersion(major, minor, exp), _) => major == 28 && minor >= 4 + /** Is this an old style implicit conversion? * @param directOnly only consider explicitly written methods * @param forImplicitClassOnly only consider methods generated from implicit classes diff --git a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala index 6d2aedb9b47b..8c276dc68f29 100644 --- a/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala +++ b/compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala @@ -9,7 +9,7 @@ import Decorators.* import NameOps.* import Annotations.Annotation import typer.ProtoTypes.constrained -import ast.untpd +import ast.{tpd, untpd} import util.Property import util.Spans.Span @@ -547,6 +547,30 @@ class SyntheticMembers(thisPhase: DenotTransformer) { New(classRefApplied, elems) end fromProductBody + def defaultArgumentBody(caseClass: Symbol, index: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree = + val companionTree: Tree = + val companion: Symbol = caseClass.companionModule + val prefix: Type = optInfo.fold(NoPrefix)(_.pre) + ref(TermRef(prefix, companion.asTerm)) + + def defaultArgumentGetter(idx: Int): Tree = + val getterName = NameKinds.DefaultGetterName(nme.CONSTRUCTOR, idx) + val getterDenot = companionTree.tpe.member(getterName) + companionTree.select(TermRef(companionTree.tpe, getterName, getterDenot)) + + val withDefaultCases = for + (acc, idx) <- caseClass.caseAccessors.zipWithIndex if acc.is(HasDefault) + body = Typed(defaultArgumentGetter(idx), TypeTree(defn.AnyType)) // so match tree does try to find union of case types + yield CaseDef(Literal(Constant(idx)), EmptyTree, body) + + val withoutDefaultCase = + val stringIndex = Apply(Select(index, nme.toString_), Nil) + val nsee = tpd.resolveConstructor(defn.NoSuchElementExceptionType, List(stringIndex)) + CaseDef(Underscore(defn.IntType), EmptyTree, Throw(nsee)) + + Match(index, withDefaultCases :+ withoutDefaultCase) + end defaultArgumentBody + /** For an enum T: * * def ordinal(x: MirroredMonoType) = x.ordinal @@ -616,6 +640,12 @@ class SyntheticMembers(thisPhase: DenotTransformer) { synthesizeDef(meth, vrefss => body(cls, vrefss.head.head)) } } + def overrideMethod(name: TermName, info: Type, cls: Symbol, body: (Symbol, Tree) => Context ?=> Tree, isExperimental: Boolean = false): Unit = { + val meth = newSymbol(clazz, name, Synthetic | Method | Override, info, coord = clazz.coord) + if isExperimental then meth.addAnnotation(defn.ExperimentalAnnot) + meth.enteredAfter(thisPhase) + newBody = newBody :+ synthesizeDef(meth, vrefss => body(cls, vrefss.head.head)) + } val linked = clazz.linkedClass lazy val monoType = { val existing = clazz.info.member(tpnme.MirroredMonoType).symbol @@ -633,6 +663,9 @@ class SyntheticMembers(thisPhase: DenotTransformer) { addParent(defn.Mirror_ProductClass.typeRef) addMethod(nme.fromProduct, MethodType(defn.ProductClass.typeRef :: Nil, monoType.typeRef), cls, fromProductBody(_, _, optInfo).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed + if cls.primaryConstructor.hasDefaultParams && cls.mirrorSupportsDefaultArguments then + overrideMethod(nme.defaultArgument, MethodType(defn.IntType :: Nil, defn.AnyType), cls, + defaultArgumentBody(_, _, optInfo), isExperimental = true) } def makeSumMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfSum]) = { addParent(defn.Mirror_SumClass.typeRef) diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index c94724faf4d4..0725e07ff267 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -409,25 +409,31 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): def makeProductMirror(pre: Type, cls: Symbol, tps: Option[List[Type]]): TreeWithErrors = val accessors = cls.caseAccessors - val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString))) - val typeElems = tps.getOrElse(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)) - val nestedPairs = TypeOps.nestedPairs(typeElems) - val (monoType, elemsType) = mirroredType match + val Seq(elemLabels, elemHasDefaults, elemTypes1) = + val supportsDefaults = cls.mirrorSupportsDefaultArguments + Seq( + accessors.map(acc => ConstantType(Constant(acc.name.toString))), + accessors.map(acc => ConstantType(Constant(supportsDefaults && acc.is(HasDefault)))), + tps.getOrElse(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr)) + ).map(TypeOps.nestedPairs) + val (monoType, elemTypes) = mirroredType match case mirroredType: HKTypeLambda => - (mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs)) + (mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = elemTypes1)) case _ => - (mirroredType, nestedPairs) - val elemsLabels = TypeOps.nestedPairs(elemLabels) - checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span) - checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span) + (mirroredType, elemTypes1) + + checkRefinement(formal, tpnme.MirroredElemTypes, elemTypes, span) + checkRefinement(formal, tpnme.MirroredElemLabels, elemLabels, span) + checkRefinement(formal, tpnme.MirroredElemHasDefaults, elemHasDefaults, span) val mirrorType = formal.constrained_& { mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name) - .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType)) - .refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels)) + .refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemTypes)) + .refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemLabels)) + .refinedWith(tpnme.MirroredElemHasDefaults, TypeAlias(elemHasDefaults)) } val mirrorRef = if cls.useCompanionAsProductMirror then companionPath(mirroredType, span) - else if defn.isTupleClass(cls) then newTupleMirror(typeElems.size) // TODO: cls == defn.PairClass when > 22 + else if defn.isTupleClass(cls) then newTupleMirror(accessors.size) // TODO: cls == defn.PairClass when > 22 else anonymousMirror(monoType, MirrorImpl.OfProduct(pre), span) withNoErrors(mirrorRef.cast(mirrorType).withSpan(span)) end makeProductMirror diff --git a/library/src/scala/deriving/Mirror.scala b/library/src/scala/deriving/Mirror.scala index 57453a516567..b54786ad4208 100644 --- a/library/src/scala/deriving/Mirror.scala +++ b/library/src/scala/deriving/Mirror.scala @@ -1,5 +1,8 @@ package scala.deriving +import java.util.NoSuchElementException +import scala.annotation.experimental + /** Mirrors allows typelevel access to enums, case classes and objects, and their sealed parents. */ sealed trait Mirror { @@ -27,6 +30,14 @@ object Mirror { /** Create a new instance of type `T` with elements taken from product `p`. */ def fromProduct(p: scala.Product): MirroredMonoType + + /** Whether each product element has a default value */ + @experimental type MirroredElemHasDefaults <: Tuple + + /** The default argument of the product argument at given `index` */ + @experimental def defaultArgument(index: Int): Any = + throw NoSuchElementException(String.valueOf(index)) + } trait Singleton extends Product { diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index 696fbeec8a39..f3b3b28e3d3c 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -28,6 +28,7 @@ object MiMaFilters { val LibraryForward: Map[String, Seq[ProblemFilter]] = Map( // Additions that require a new minor version of the library Build.previousDottyVersion -> Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]("scala.compiletime.testing.Error.defaultArgument"), ), // Additions since last LTS @@ -62,6 +63,7 @@ object MiMaFilters { ), ) val TastyCore: Seq[ProblemFilter] = Seq( + ProblemFilters.exclude[DirectMissingMethodProblem]("dotty.tools.tasty.TastyVersion.defaultArgument"), ) val Interfaces: Seq[ProblemFilter] = Seq( ) diff --git a/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/app/Main.scala b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/app/Main.scala new file mode 100644 index 000000000000..f7edaeb34fb8 --- /dev/null +++ b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/app/Main.scala @@ -0,0 +1,58 @@ +import scala.deriving.Mirror + +package lib { + + case class NewFoo(x: Int = 1, y: Int) + + object NewMirrors { + val mNewFoo = summon[Mirror.Of[NewFoo]] + + val mOldFoo = summon[Mirror.Of[OldFoo]] + val mOldBar = summon[Mirror.Of[OldBar]] + } +} + +package app { + import lib.* + + object Main { + + // defaultArgument implementation did not throw NoSuchElementException + def foundDefaultArgument(m: Mirror.Product): Boolean = try { + m.defaultArgument(0) + true + } catch { + case _: NoSuchElementException => false + } + + def main(args: Array[String]): Unit = { + + // NewFoo: normal case with support for default arguments + + assert(NewMirrors.mNewFoo.defaultArgument(0) == 1) + summon[NewMirrors.mNewFoo.MirroredElemHasDefaults =:= (true, false)] + + // OldFoo: does not override the defaultArgument implementation + + assert(!foundDefaultArgument(NewMirrors.mOldFoo)) // Expected: since mirror of old case class + summon[NewMirrors.mOldFoo.MirroredElemHasDefaults =:= (false, false)] // Necessary: to be consistent with defaultArgument implementation + + assert(!foundDefaultArgument(OldMirrors.mOldFoo)) // Expected: since mirror of old case class + summon[scala.util.NotGiven[OldMirrors.mOldFoo.MirroredElemHasDefaults <:< (Boolean, Boolean)]] // reference to old mirror doesn't have any refinement + summon[OldMirrors.mOldFoo.MirroredElemHasDefaults <:< Tuple] // but does inherit type member from Mirror trait + + // OldBar: is anon mirror so could implement defaultArgument + // but we manually keep behaviour consistent with other mirrors of old case classes + + assert(NewMirrors.mOldBar ne lib.OldBar) + assert(!foundDefaultArgument(NewMirrors.mOldBar)) + summon[NewMirrors.mOldBar.MirroredElemHasDefaults =:= (false, false)] // Ok: should be consistent with above + + assert(OldMirrors.mOldBar ne lib.OldBar) + assert(!foundDefaultArgument(OldMirrors.mOldBar)) + summon[scala.util.NotGiven[OldMirrors.mOldBar.MirroredElemHasDefaults <:< (Boolean, Boolean)]] + summon[OldMirrors.mOldBar.MirroredElemHasDefaults <:< Tuple] + + } + } +} diff --git a/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/build.sbt b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/build.sbt new file mode 100644 index 000000000000..189f9e7e707a --- /dev/null +++ b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/build.sbt @@ -0,0 +1,7 @@ +lazy val lib = project.in(file("lib")) + .settings( + scalaVersion := "3.3.0" + ) + +lazy val app = project.in(file("app")) + .dependsOn(lib) diff --git a/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/lib/Foo.scala b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/lib/Foo.scala new file mode 100644 index 000000000000..ff331eb26668 --- /dev/null +++ b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/lib/Foo.scala @@ -0,0 +1,13 @@ +package lib + +import deriving.Mirror + +case class OldFoo(x: Int = 1, y: Int) + +case class OldBar(x: Int = 1, y: Int) +case object OldBar + +object OldMirrors { + val mOldFoo = summon[Mirror.ProductOf[OldFoo]] + val mOldBar = summon[Mirror.ProductOf[OldBar]] +} diff --git a/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/project/DottyInjectedPlugin.scala b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/project/DottyInjectedPlugin.scala new file mode 100644 index 000000000000..fb946c4b8c61 --- /dev/null +++ b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/project/DottyInjectedPlugin.scala @@ -0,0 +1,11 @@ +import sbt._ +import Keys._ + +object DottyInjectedPlugin extends AutoPlugin { + override def requires = plugins.JvmPlugin + override def trigger = allRequirements + + override val projectSettings = Seq( + scalaVersion := sys.props("plugin.scalaVersion") + ) +} diff --git a/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/test b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/test new file mode 100644 index 000000000000..63092ffa4a03 --- /dev/null +++ b/sbt-test/scala3-compat/defaultArgument-mirrors-3.3/test @@ -0,0 +1 @@ +> app/run diff --git a/sbt-test/source-dependencies/mirror-product/MyProduct.scala b/sbt-test/source-dependencies/mirror-product/MyProduct.scala index acad1358f62b..ecc868c8d850 100644 --- a/sbt-test/source-dependencies/mirror-product/MyProduct.scala +++ b/sbt-test/source-dependencies/mirror-product/MyProduct.scala @@ -1 +1,3 @@ case class MyProduct(x: Int) +case class WillGetDefault(x: Int) +case class WillChangeDefault(x: Int = 1) diff --git a/sbt-test/source-dependencies/mirror-product/Test.scala b/sbt-test/source-dependencies/mirror-product/Test.scala index e53d7b999517..826b0d181abc 100644 --- a/sbt-test/source-dependencies/mirror-product/Test.scala +++ b/sbt-test/source-dependencies/mirror-product/Test.scala @@ -8,3 +8,5 @@ transparent inline def foo[T](using m: Mirror.Of[T]): Int = @main def Test = assert(foo[MyProduct] == 2) + assert(summon[Mirror.Of[WillGetDefault]].defaultArgument(0) == 1) + assert(summon[Mirror.Of[WillChangeDefault]].defaultArgument(0) == 2) diff --git a/sbt-test/source-dependencies/mirror-product/changes/MyProduct.scala b/sbt-test/source-dependencies/mirror-product/changes/MyProduct.scala index 87e5af62bd7e..c98c40bd279c 100644 --- a/sbt-test/source-dependencies/mirror-product/changes/MyProduct.scala +++ b/sbt-test/source-dependencies/mirror-product/changes/MyProduct.scala @@ -1 +1,3 @@ case class MyProduct(x: Int, y: String) +case class WillGetDefault(x: Int = 1) +case class WillChangeDefault(x: Int = 2) diff --git a/tests/run-macros/i7987.check b/tests/run-macros/i7987.check index 85a185c1d5c7..80e6e372c833 100644 --- a/tests/run-macros/i7987.check +++ b/tests/run-macros/i7987.check @@ -4,4 +4,5 @@ scala.deriving.Mirror.Product { type MirroredLabel >: "Some" <: "Some" type MirroredElemTypes >: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] <: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] type MirroredElemLabels >: scala.*:["value", scala.Tuple$package.EmptyTuple] <: scala.*:["value", scala.Tuple$package.EmptyTuple] + type MirroredElemHasDefaults >: scala.*:[false, scala.Tuple$package.EmptyTuple] <: scala.*:[false, scala.Tuple$package.EmptyTuple] } diff --git a/tests/run-macros/mirror-defaultArgument/MirrorOps.scala b/tests/run-macros/mirror-defaultArgument/MirrorOps.scala new file mode 100644 index 000000000000..75f00aff8f63 --- /dev/null +++ b/tests/run-macros/mirror-defaultArgument/MirrorOps.scala @@ -0,0 +1,25 @@ +import scala.deriving._ +import scala.annotation.experimental +import scala.quoted._ + +object MirrorOps: + + inline def overridesDefaultArgument[T]: Boolean = ${ overridesDefaultArgumentImpl[T] } + + def overridesDefaultArgumentImpl[T](using Quotes, Type[T]): Expr[Boolean] = + import quotes.reflect.* + val cls = TypeRepr.of[T].classSymbol.get + val companion = cls.companionModule.moduleClass + val methods = companion.declaredMethods + + val experAnnotType = Symbol.requiredClass("scala.annotation.experimental").typeRef + + Expr { + methods.exists { m => + m.name == "defaultArgument" && + m.flags.is(Flags.Synthetic) && + m.annotations.exists(_.tpe <:< experAnnotType) + } + } + +end MirrorOps diff --git a/tests/run-macros/mirror-defaultArgument/test.scala b/tests/run-macros/mirror-defaultArgument/test.scala new file mode 100644 index 000000000000..da2d29b27b20 --- /dev/null +++ b/tests/run-macros/mirror-defaultArgument/test.scala @@ -0,0 +1,13 @@ +import scala.deriving._ +import scala.annotation.experimental +import scala.quoted._ + +import MirrorOps.* + +object Test extends App: + + case class WithDefault(x: Int, y: Int = 1) + assert(overridesDefaultArgument[WithDefault]) + + case class WithoutDefault(x: Int) + assert(!overridesDefaultArgument[WithoutDefault]) diff --git a/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala b/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala index 12ea8eb26c47..e939494108bc 100644 --- a/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala +++ b/tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala @@ -96,6 +96,10 @@ val experimentalDefinitionInLibrary = Set( "scala.Tuple$.Reverse", // can be stabilized in 3.5 "scala.Tuple$.ReverseOnto", // can be stabilized in 3.5 "scala.runtime.Tuples$.reverse", // can be stabilized in 3.5 + + // New APIs: Mirror support for default arguments + "scala.deriving.Mirror$.Product.MirroredElemHasDefaults", + "scala.deriving.Mirror$.Product.defaultArgument", ) diff --git a/tests/run/mirror-defaultArgument.scala b/tests/run/mirror-defaultArgument.scala new file mode 100644 index 000000000000..eaff19094128 --- /dev/null +++ b/tests/run/mirror-defaultArgument.scala @@ -0,0 +1,53 @@ +import scala.deriving._ +import scala.annotation.experimental + +object Test extends App: + + case class WithDefault(x: Int, y: Int = 1) + val m = summon[Mirror.Of[WithDefault]] + assert(m.defaultArgument(1) == 1) + try + m.defaultArgument(0) + throw IllegalStateException("There should be no default argument") + catch + case ex: NoSuchElementException => assert(ex.getMessage == "0") // Ok + + + case class WithCompanion(s: String = "hello") + case object WithCompanion // => mirrors must be anonymous + + val m2 = summon[Mirror.Of[WithCompanion]] + assert(m2 ne WithCompanion) + assert(m2.defaultArgument(0) == "hello") + + + class Outer(val i: Int) { + + case class Inner(x: Int, y: Int = i + 1) + case object Inner + + val m3 = summon[Mirror.Of[Inner]] + assert(m3.defaultArgument(1) == i + 1) + + def localTest(d: Double): Unit = { + case class Local(x: Int = i, y: Double = d, z: Double = i + d) + case object Local + + val m4 = summon[Mirror.Of[Local]] + assert(m4.defaultArgument(0) == i) + assert(m4.defaultArgument(1) == d) + assert(m4.defaultArgument(2) == i + d) + } + + } + + val outer = Outer(3) + val m5 = summon[Mirror.Of[outer.Inner]] + assert(m5.defaultArgument(1) == 3 + 1) + outer.localTest(9d) + + + // new defaultArgument match tree should be able to unify different default value types + case class Foo[T](x: Int = 0, y: String = "hi") + +end Test diff --git a/tests/run/typeclass-derivation-defaultArgument.scala b/tests/run/typeclass-derivation-defaultArgument.scala new file mode 100644 index 000000000000..e2648c6cad89 --- /dev/null +++ b/tests/run/typeclass-derivation-defaultArgument.scala @@ -0,0 +1,101 @@ +import scala.deriving.Mirror as M +import scala.deriving.* +import scala.Tuple.* +import scala.compiletime.* +import scala.compiletime.ops.int.S + +trait Migration[-From, +To]: + def apply(x: From): To + +object Migration: + + extension [From](x: From) + def migrateTo[To](using m: Migration[From, To]): To = m(x) + + given[T]: Migration[T, T] with + override def apply(x: T): T = x + + type IndexOf[Elems <: Tuple, X] <: Int = Elems match { + case (X *: elems) => 0 + case (_ *: elems) => S[IndexOf[elems, X]] + case EmptyTuple => Nothing + } + + inline def migrateElem[F,T, ToIdx <: Int](from: M.ProductOf[F], to: M.ProductOf[T])(x: Product): Any = + + type Label = Elem[to.MirroredElemLabels, ToIdx] + type FromIdx = IndexOf[from.MirroredElemLabels, Label] + inline constValueOpt[FromIdx] match + + case Some(fromIdx) => + type FromType = Elem[from.MirroredElemTypes, FromIdx] + type ToType = Elem[to.MirroredElemTypes, ToIdx] + summonFrom { case _: Migration[FromType, ToType] => + x.productElement(fromIdx).asInstanceOf[FromType].migrateTo[ToType] + } + + case None => + type HasDefault = Elem[to.MirroredElemHasDefaults, ToIdx] + inline erasedValue[HasDefault] match + case _: true => to.defaultArgument(constValue[ToIdx]) + case _: false => compiletime.error("An element has no equivalent or default") + + + inline def migrateElems[F,T, ToIdx <: Int](from: M.ProductOf[F], to: M.ProductOf[T])(x: Product): Seq[Any] = + inline erasedValue[ToIdx] match + case _: Tuple.Size[to.MirroredElemLabels] => Seq() + case _ => migrateElem[F,T,ToIdx](from, to)(x) +: migrateElems[F,T,S[ToIdx]](from, to)(x) + + inline def migrateProduct[F,T](from: M.ProductOf[F], to: M.ProductOf[T]) + (x: Product): T = + val elems = migrateElems[F, T, 0](from, to)(x) + to.fromProduct(new Product: + def canEqual(that: Any): Boolean = false + def productArity: Int = elems.length + def productElement(n: Int): Any = elems(n) + ) + + inline def migration[F,T](using from: M.Of[F], to: M.Of[T]): Migration[F,T] = (x: F) => + inline from match + case fromP: M.ProductOf[F] => inline to match + case toP: M.ProductOf[T] => migrateProduct[F, T](fromP, toP)(x.asInstanceOf[Product]) + case _: M.SumOf[T] => compiletime.error("Cannot migrate sums") + case _: M.SumOf[F] => compiletime.error("Cannot migrate sums") + +end Migration + + +import Migration.* +object Test extends App: + + case class A1(x: Int) + case class A2(x: Int) + given Migration[A1, A2] = migration + assert(A1(2).migrateTo[A2] == A2(2)) + + case class B1(x: Int, y: String) + case class B2(y: String, x: Int) + given Migration[B1, B2] = migration + assert(B1(5, "hi").migrateTo[B2] == B2("hi", 5)) + + case class C1(x: A1) + case class C2(x: A2) + given Migration[C1, C2] = migration + assert(C1(A1(0)).migrateTo[C2] == C2(A2(0))) + + case class D1(x: Double) + case class D2(b: Boolean = true, x: Double) + given Migration[D1, D2] = migration + assert(D1(9).migrateTo[D2] == D2(true, 9)) + + case class E1(x: D1, y: D1) + case class E2(y: D2, s: String = "hi", x: D2) + given Migration[E1, E2] = migration + assert(E1(D1(1), D1(2)).migrateTo[E2] == E2(D2(true, 2), "hi", D2(true, 1))) + + // should only use default when needed + case class F1(x: Int) + case class F2(x: Int = 3) + given Migration[F1, F2] = migration + assert(F1(7).migrateTo[F2] == F2(7)) +