Skip to content

Commit

Permalink
Fix bugs in QuoteMatcher::MatchResult::toExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
zeptometer committed Aug 7, 2023
1 parent 6764d5e commit 3cb8b72
Showing 1 changed file with 43 additions and 54 deletions.
97 changes: 43 additions & 54 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.util.optional
import dotty.tools.dotc.ast.TreeTypeMap

/** Matches a quoted tree against a quoted pattern tree.
* A quoted pattern tree may have type and term holes in addition to normal terms.
Expand Down Expand Up @@ -319,9 +320,9 @@ class QuoteMatcher(debug: Boolean) {

val env = summon[Env]
val capturedIds = args.map(getCapturedIdent)
val capturedSymbols = capturedIds.map(_.symbol)
val capturedTargs = unrollHkNestedPairsTypeTree(targs)
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v) && !capturedTargs.map(_.symbol).contains(v))
val capturedSymbols = Set.from(capturedIds.map(_.symbol) ++ capturedTargs.map(_.symbol))
val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), capturedTargs.map(_.tpe), env)
Expand Down Expand Up @@ -581,18 +582,17 @@ class QuoteMatcher(debug: Boolean) {
/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
def freePatternVars(term: Tree)(using Env, Context): Set[Symbol] =
val typeAccumulator = new TypeAccumulator[Set[Symbol]] {
def apply(x: Set[Symbol], tp: Type): Set[Symbol] =
if summon[Env].contains(tp.typeSymbol) then
foldOver(x + tp.typeSymbol, tp)
else
foldOver(x, tp)
def apply(x: Set[Symbol], tp: Type): Set[Symbol] = tp match
case tp: TypeRef if summon[Env].contains(tp.typeSymbol) => foldOver(x + tp.typeSymbol, tp)
case tp: TermRef if summon[Env].contains(tp.termSymbol) => foldOver(x + tp.termSymbol, tp)
case _ => foldOver(x, tp)
}
val treeAccumulator = new TreeAccumulator[Set[Symbol]] {
def apply(x: Set[Symbol], tree: Tree)(using Context): Set[Symbol] =
val tvars = typeAccumulator(Set.empty, tree.tpe)
tree match
case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(x ++ tvars + tree.symbol, tree)
case _ => foldOver(x ++ tvars, tree)
case tree: Ident if summon[Env].contains(tree.symbol) => foldOver(typeAccumulator(x, tree.tpe) + tree.symbol, tree)
case tree: TypeTree => typeAccumulator(x, tree.tpe)
case _ => foldOver(x, tree)
}
treeAccumulator(Set.empty, term)
}
Expand Down Expand Up @@ -625,49 +625,38 @@ class QuoteMatcher(debug: Boolean) {
case MatchResult.ClosedTree(tree) =>
new ExprImpl(tree, spliceScope)
case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env) =>
if typeArgs.isEmpty then
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
* When matching a method call `f(0)` against a HOAS pattern `p(g)` where
* f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold
* `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion.
*/
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)
else
// TODO-18271: This implementation fails Typer.assertPositioned.
// We want to find safe way to generate poly function
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
val names: List[TermName] = argIds.map(_.symbol.name.asTermName)
val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr))
val ptTypeVarSymbols = typeArgs.map(_.typeSymbol)

val methTpe = if typeArgs.isEmpty then
MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
else
val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length)
val bounds = typeArgs map (_ => TypeBounds.empty)
val fromSymbols = typeArgs.map(_.typeSymbol)
val resultTypeExp = (pt: PolyType) => {
val argTypes1 = paramTypes.map(_.subst(fromSymbols, pt.paramRefs))
val resultType1 = mapTypeHoles(patternTpe).subst(fromSymbols, pt.paramRefs)
val argTypes1 = paramTypes.map(_.subst(ptTypeVarSymbols, pt.paramRefs))
val resultType1 = mapTypeHoles(patternTpe).subst(ptTypeVarSymbols, pt.paramRefs)
MethodType(argTypes1, resultType1)
}
val methTpe = PolyType(typeArgs1)(_ => bounds, resultTypeExp)
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val typeArgs = lambdaArgss.head
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.tail.head).toMap
val body = new TreeMap {
PolyType(typeArgs1)(_ => bounds, resultTypeExp)

val meth = newAnonFun(ctx.owner, methTpe)

def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val typeArgsMap = ptTypeVarSymbols.zip(lambdaArgss.head.map(_.tpe)).toMap
val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.tail.head).toMap

val body = new TreeTypeMap(
typeMap = if typeArgs.isEmpty then IdentityTypeMap
else new TypeMap() {
override def apply(tp: Type): Type = tp match {
case tr: TypeRef if tr.prefix.eq(NoPrefix) =>
env.get(tr.symbol).flatMap(typeArgsMap.get).getOrElse(tr)
case tp => mapOver(tp)
}
},
treeMap = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
/*
Expand All @@ -678,13 +667,13 @@ class QuoteMatcher(debug: Boolean) {
case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform))
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}
.transform(tree)
.subst(fromSymbols, typeArgs.map(_.symbol))
TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)
}.transform
).transform(tree)

TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn).withSpan(tree.span)
new ExprImpl(hoasClosure, spliceScope)

private inline def notMatched[T]: optional[T] =
optional.break()
Expand Down

0 comments on commit 3cb8b72

Please sign in to comment.