From 3cb8b72f708d3d79ae7c018b1669a5e81441179e Mon Sep 17 00:00:00 2001 From: Yuito Murase Date: Mon, 7 Aug 2023 21:46:27 +0900 Subject: [PATCH] Fix bugs in QuoteMatcher::MatchResult::toExpr --- .../quoted/runtime/impl/QuoteMatcher.scala | 97 ++++++++----------- 1 file changed, 43 insertions(+), 54 deletions(-) diff --git a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala index 26838d8ee695..f028e0286a73 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala @@ -11,6 +11,7 @@ import dotty.tools.dotc.core.Types.* import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.Symbols.* import dotty.tools.dotc.util.optional +import dotty.tools.dotc.ast.TreeTypeMap /** Matches a quoted tree against a quoted pattern tree. * A quoted pattern tree may have type and term holes in addition to normal terms. @@ -319,9 +320,9 @@ class QuoteMatcher(debug: Boolean) { val env = summon[Env] val capturedIds = args.map(getCapturedIdent) - val capturedSymbols = capturedIds.map(_.symbol) val capturedTargs = unrollHkNestedPairsTypeTree(targs) - val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v) && !capturedTargs.map(_.symbol).contains(v)) + val capturedSymbols = Set.from(capturedIds.map(_.symbol) ++ capturedTargs.map(_.symbol)) + val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v)) withEnv(captureEnv) { scrutinee match case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), capturedTargs.map(_.tpe), env) @@ -581,18 +582,17 @@ class QuoteMatcher(debug: Boolean) { /** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */ def freePatternVars(term: Tree)(using Env, Context): Set[Symbol] = val typeAccumulator = new TypeAccumulator[Set[Symbol]] { - def apply(x: Set[Symbol], tp: Type): Set[Symbol] = - if summon[Env].contains(tp.typeSymbol) then - foldOver(x + tp.typeSymbol, tp) - else - foldOver(x, tp) + def apply(x: Set[Symbol], tp: Type): Set[Symbol] = tp match + case tp: TypeRef if summon[Env].contains(tp.typeSymbol) => foldOver(x + tp.typeSymbol, tp) + case tp: TermRef if summon[Env].contains(tp.termSymbol) => foldOver(x + tp.termSymbol, tp) + case _ => foldOver(x, tp) } val treeAccumulator = new TreeAccumulator[Set[Symbol]] { def apply(x: Set[Symbol], tree: Tree)(using Context): Set[Symbol] = - val tvars = typeAccumulator(Set.empty, tree.tpe) tree match - case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(x ++ tvars + tree.symbol, tree) - case _ => foldOver(x ++ tvars, tree) + case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(typeAccumulator(x, tree.tpe) + tree.symbol, tree) + case tree: TypeTree => typeAccumulator(x, tree.tpe) + case _ => foldOver(x, tree) } treeAccumulator(Set.empty, term) } @@ -625,49 +625,38 @@ class QuoteMatcher(debug: Boolean) { case MatchResult.ClosedTree(tree) => new ExprImpl(tree, spliceScope) case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env) => - if typeArgs.isEmpty then - val names: List[TermName] = argIds.map(_.symbol.name.asTermName) - val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) - val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) - val meth = newAnonFun(ctx.owner, methTpe) - def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { - val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap - val body = new TreeMap { - override def transform(tree: Tree)(using Context): Tree = - tree match - /* - * When matching a method call `f(0)` against a HOAS pattern `p(g)` where - * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold - * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion. - */ - case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform)) - case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) - case tree => super.transform(tree) - }.transform(tree) - TreeOps(body).changeNonLocalOwners(meth) - } - val hoasClosure = Closure(meth, bodyFn) - new ExprImpl(hoasClosure, spliceScope) - else - // TODO-18271: This implementation fails Typer.assertPositioned. - // We want to find safe way to generate poly function - val names: List[TermName] = argIds.map(_.symbol.name.asTermName) - val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) + val names: List[TermName] = argIds.map(_.symbol.name.asTermName) + val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) + val ptTypeVarSymbols = typeArgs.map(_.typeSymbol) + val methTpe = if typeArgs.isEmpty then + MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) + else val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length) val bounds = typeArgs map (_ => TypeBounds.empty) - val fromSymbols = typeArgs.map(_.typeSymbol) val resultTypeExp = (pt: PolyType) => { - val argTypes1 = paramTypes.map(_.subst(fromSymbols, pt.paramRefs)) - val resultType1 = mapTypeHoles(patternTpe).subst(fromSymbols, pt.paramRefs) + val argTypes1 = paramTypes.map(_.subst(ptTypeVarSymbols, pt.paramRefs)) + val resultType1 = mapTypeHoles(patternTpe).subst(ptTypeVarSymbols, pt.paramRefs) MethodType(argTypes1, resultType1) } - val methTpe = PolyType(typeArgs1)(_ => bounds, resultTypeExp) - val meth = newAnonFun(ctx.owner, methTpe) - def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { - val typeArgs = lambdaArgss.head - val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.tail.head).toMap - val body = new TreeMap { + PolyType(typeArgs1)(_ => bounds, resultTypeExp) + + val meth = newAnonFun(ctx.owner, methTpe) + + def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { + val typeArgsMap = ptTypeVarSymbols.zip(lambdaArgss.head.map(_.tpe)).toMap + val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.tail.head).toMap + + val body = new TreeTypeMap( + typeMap = if typeArgs.isEmpty then IdentityTypeMap + else new TypeMap() { + override def apply(tp: Type): Type = tp match { + case tr: TypeRef if tr.prefix.eq(NoPrefix) => + env.get(tr.symbol).flatMap(typeArgsMap.get).getOrElse(tr) + case tp => mapOver(tp) + } + }, + treeMap = new TreeMap { override def transform(tree: Tree)(using Context): Tree = tree match /* @@ -678,13 +667,13 @@ class QuoteMatcher(debug: Boolean) { case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform)) case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) case tree => super.transform(tree) - } - .transform(tree) - .subst(fromSymbols, typeArgs.map(_.symbol)) - TreeOps(body).changeNonLocalOwners(meth) - } - val hoasClosure = Closure(meth, bodyFn) - new ExprImpl(hoasClosure, spliceScope) + }.transform + ).transform(tree) + + TreeOps(body).changeNonLocalOwners(meth) + } + val hoasClosure = Closure(meth, bodyFn).withSpan(tree.span) + new ExprImpl(hoasClosure, spliceScope) private inline def notMatched[T]: optional[T] = optional.break()