From 282e4ad11cfcb254109c413cd18e12f2d7e94381 Mon Sep 17 00:00:00 2001 From: Yuito Murase Date: Mon, 31 Jul 2023 16:50:31 +0900 Subject: [PATCH] WIP: Pattern match against hoas pattern wtih type vars --- .../quoted/runtime/impl/QuoteMatcher.scala | 124 ++++++++++++++---- 1 file changed, 97 insertions(+), 27 deletions(-) diff --git a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala index 3ef1c854064e..54e07baa1bd1 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala @@ -288,7 +288,42 @@ class QuoteMatcher(debug: Boolean) { val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v)) withEnv(captureEnv) { scrutinee match - case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), env) + case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), Nil, env) + case _ => notMatched + } + + /* Higher order term hole */ + // Matches an open term and wraps it into a lambda that provides the free variables + case Apply(TypeApply(Ident(_), List(TypeTree(), targs)), SeqLiteral(args, _) :: Nil) + if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHoleWithTypes) => + + /* Some of method symbols in arguments of higher-order term hole are eta-expanded. + * e.g. + * g: (Int) => Int + * => { + * def $anonfun(y: Int): Int = g(y) + * closure($anonfun) + * } + * + * f: (using Int) => Int + * => f(using x) + * This function restores the symbol of the original method from + * the eta-expanded function. + */ + def getCapturedIdent(arg: Tree)(using Context): Ident = + arg match + case id: Ident => id + case Apply(fun, _) => getCapturedIdent(fun) + case Block((ddef: DefDef) :: _, _: Closure) => getCapturedIdent(ddef.rhs) + case Typed(expr, _) => getCapturedIdent(expr) + + val env = summon[Env] + val capturedIds = args.map(getCapturedIdent) + val capturedSymbols = capturedIds.map(_.symbol) + val captureEnv = env.filter((k, v) => !capturedSymbols.contains(v)) + withEnv(captureEnv) { + scrutinee match + case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, capturedIds, args.map(_.tpe), targs, env) case _ => notMatched } @@ -558,9 +593,10 @@ class QuoteMatcher(debug: Boolean) { * @param patternTpe Type of the pattern hole (from the pattern) * @param argIds Identifiers of HOAS arguments (from the pattern) * @param argTypes Eta-expanded types of HOAS arguments (from the pattern) + * @param typeArgs type arguments from the pattern * @param env Mapping between scrutinee and pattern variables */ - case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env) + case OpenTree(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env) /** Return the expression that was extracted from a hole. * @@ -573,29 +609,63 @@ class QuoteMatcher(debug: Boolean) { def toExpr(mapTypeHoles: Type => Type, spliceScope: Scope)(using Context): Expr[Any] = this match case MatchResult.ClosedTree(tree) => new ExprImpl(tree, spliceScope) - case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env) => - val names: List[TermName] = argIds.map(_.symbol.name.asTermName) - val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) - val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) - val meth = newAnonFun(ctx.owner, methTpe) - def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { - val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap - val body = new TreeMap { - override def transform(tree: Tree)(using Context): Tree = - tree match - /* - * When matching a method call `f(0)` against a HOAS pattern `p(g)` where - * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold - * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion. - */ - case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform)) - case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) - case tree => super.transform(tree) - }.transform(tree) - TreeOps(body).changeNonLocalOwners(meth) - } - val hoasClosure = Closure(meth, bodyFn) - new ExprImpl(hoasClosure, spliceScope) + case MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env) => + if typeArgs.isEmpty then + val names: List[TermName] = argIds.map(_.symbol.name.asTermName) + val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) + val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) + val meth = newAnonFun(ctx.owner, methTpe) + def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { + val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap + val body = new TreeMap { + override def transform(tree: Tree)(using Context): Tree = + tree match + /* + * When matching a method call `f(0)` against a HOAS pattern `p(g)` where + * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold + * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion. + */ + case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform)) + case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) + case tree => super.transform(tree) + }.transform(tree) + TreeOps(body).changeNonLocalOwners(meth) + } + val hoasClosure = Closure(meth, bodyFn) + new ExprImpl(hoasClosure, spliceScope) + else + val names: List[TermName] = argIds.map(_.symbol.name.asTermName) + val paramTypes = argTypes.map(tpe => mapTypeHoles(tpe.widenTermRefExpr)) + + val typeArgs1 = PolyType.syntheticParamNames(typeArgs.length) + val bounds = typeArgs map (_ => TypeBounds.empty) + val resultTypeExp = (pt: PolyType) => { + val fromSymbols = typeArgs.map(_.typeSymbol) + val argTypes1 = argTypes.map(_.subst(fromSymbols, pt.paramRefs)) + val resultType1 = mapTypeHoles(patternTpe).subst(fromSymbols, pt.paramRefs) + MethodType(argTypes1, resultType1) + } + val methTpe = PolyType(typeArgs1)(_ => bounds, resultTypeExp) + val meth = newAnonFun(ctx.owner, methTpe) + // TODO-18271 + def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { + val argsMap = argIds.view.map(_.symbol).zip(lambdaArgss.head).toMap + val body = new TreeMap { + override def transform(tree: Tree)(using Context): Tree = + tree match + /* + * When matching a method call `f(0)` against a HOAS pattern `p(g)` where + * f has a method type `(x: Int): Int` and `f` maps to `g`, `p` should hold + * `g.apply(0)` because the type of `g` is `Int => Int` due to eta expansion. + */ + case Apply(fun, args) if env.contains(tree.symbol) => transform(fun).select(nme.apply).appliedToArgs(args.map(transform)) + case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) + case tree => super.transform(tree) + }.transform(tree) + TreeOps(body).changeNonLocalOwners(meth) + } + val hoasClosure = Closure(meth, bodyFn) + new ExprImpl(hoasClosure, spliceScope) private inline def notMatched[T]: optional[T] = optional.break() @@ -606,8 +676,8 @@ class QuoteMatcher(debug: Boolean) { private inline def matched(tree: Tree)(using Context): MatchingExprs = Seq(MatchResult.ClosedTree(tree)) - private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], env: Env)(using Context): MatchingExprs = - Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, env)) + private def matchedOpen(tree: Tree, patternTpe: Type, argIds: List[Tree], argTypes: List[Type], typeArgs: List[Type], env: Env)(using Context): MatchingExprs = + Seq(MatchResult.OpenTree(tree, patternTpe, argIds, argTypes, typeArgs, env)) extension (self: MatchingExprs) /** Concatenates the contents of two successful matchings */