Skip to content

Commit

Permalink
Use infix extension to indicate right associativity with natural order
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasstucki committed Dec 6, 2023
1 parent e598bef commit 71aa636
Show file tree
Hide file tree
Showing 22 changed files with 133 additions and 130 deletions.
20 changes: 12 additions & 8 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -996,11 +996,18 @@ object desugar {

def badRightAssoc(problem: String) =
report.error(em"right-associative extension method $problem", mdef.srcPos)
() // extParamss ++ mdef.paramss
extParamss ++ mdef.paramss

rightParam match
case ValDefs(vparam :: Nil) =>
if !vparam.mods.is(Given) then
if vparam.mods.is(Given) then
badRightAssoc("cannot start with using clause")
else if mdef.mods.is(Infix) then
// New encoding:
// we keep the extension method as is and rely on the swap of arguments at call site
extParamss ++ mdef.paramss
else
// Old encoding:
// we merge the extension parameters with the method parameters,
// swapping the operator arguments:
// e.g.
Expand All @@ -1010,16 +1017,13 @@ object desugar {
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
//
// If you change the names of the clauses below, also change them in right-associative-extension-methods.md
// val (leftTyParamsAndLeadingUsing, leftParamAndTrailingUsing) = extParamss.span(isUsingOrTypeParamClause)
() // leftTyParamsAndLeadingUsing ::: rightTyParams ::: rightParam :: leftParamAndTrailingUsing ::: paramss1
else
badRightAssoc("cannot start with using clause")
val (leftTyParamsAndLeadingUsing, leftParamAndTrailingUsing) = extParamss.span(isUsingOrTypeParamClause)
leftTyParamsAndLeadingUsing ::: rightTyParams ::: rightParam :: leftParamAndTrailingUsing ::: paramss1
case _ =>
badRightAssoc("must start with a single parameter")
case _ =>
// no value parameters, so not an infix operator.
() // extParamss ++ mdef.paramss
extParamss ++ mdef.paramss
extParamss ++ mdef.paramss
else
extParamss ++ mdef.paramss
).withMods(mdef.mods | ExtensionMethod)
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/Positioned.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import util.{SourceFile, SourcePosition, SrcPos}
import core.Contexts.*
import core.Decorators.*
import core.NameOps.*
import core.Flags.{JavaDefined, ExtensionMethod}
import core.Flags.{JavaDefined, ExtensionMethod, Infix}
import core.StdNames.nme
import ast.Trees.mods
import annotation.constructorOnly
Expand Down Expand Up @@ -215,8 +215,8 @@ abstract class Positioned(implicit @constructorOnly src: SourceFile) extends Src
check(tree.trailingParamss)
case tree: DefDef if tree.mods.is(ExtensionMethod) =>
tree.paramss match
// case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName =>
// // omit check for right-associatiove extension methods; their parameters were swapped
case vparams1 :: vparams2 :: rest if tree.name.isRightAssocOperatorName && !tree.mods.is(Infix) =>
// omit check for right-associatiove extension methods; their parameters were swapped
case _ =>
check(tree.paramss)
check(tree.tpt)
Expand Down
58 changes: 29 additions & 29 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -933,35 +933,35 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
val coreSig =
if isExtension then
val paramss =
// if tree.name.isRightAssocOperatorName then
// // If you change the names of the clauses below, also change them in right-associative-extension-methods.md
// // we have the following encoding of tree.paramss:
// // (leftTyParams ++ leadingUsing
// // ++ rightTyParams ++ rightParam
// // ++ leftParam ++ trailingUsing ++ rest)
// // e.g.
// // extension [A](using B)(c: C)(using D)
// // def %:[E](f: F)(g: G)(using H): Res = ???
// // will have the following values:
// // - leftTyParams = List(`[A]`)
// // - leadingUsing = List(`(using B)`)
// // - rightTyParams = List(`[E]`)
// // - rightParam = List(`(f: F)`)
// // - leftParam = List(`(c: C)`)
// // - trailingUsing = List(`(using D)`)
// // - rest = List(`(g: G)`, `(using H)`)
// // we need to swap (rightTyParams ++ rightParam) with (leftParam ++ trailingUsing)
// val (leftTyParams, rest1) = tree.paramss.span(isTypeParamClause)
// val (leadingUsing, rest2) = rest1.span(isUsingClause)
// val (rightTyParams, rest3) = rest2.span(isTypeParamClause)
// val (rightParam, rest4) = rest3.splitAt(1)
// val (leftParam, rest5) = rest4.splitAt(1)
// val (trailingUsing, rest6) = rest5.span(isUsingClause)
// if leftParam.nonEmpty then
// leftTyParams ::: leadingUsing ::: leftParam ::: trailingUsing ::: rightTyParams ::: rightParam ::: rest6
// else
// tree.paramss // it wasn't a binary operator, after all.
// else
if tree.name.isRightAssocOperatorName && !tree.mods.is(Infix) && !tree.symbol.is(Infix) then
// If you change the names of the clauses below, also change them in right-associative-extension-methods.md
// we have the following encoding of tree.paramss:
// (leftTyParams ++ leadingUsing
// ++ rightTyParams ++ rightParam
// ++ leftParam ++ trailingUsing ++ rest)
// e.g.
// extension [A](using B)(c: C)(using D)
// def %:[E](f: F)(g: G)(using H): Res = ???
// will have the following values:
// - leftTyParams = List(`[A]`)
// - leadingUsing = List(`(using B)`)
// - rightTyParams = List(`[E]`)
// - rightParam = List(`(f: F)`)
// - leftParam = List(`(c: C)`)
// - trailingUsing = List(`(using D)`)
// - rest = List(`(g: G)`, `(using H)`)
// we need to swap (rightTyParams ++ rightParam) with (leftParam ++ trailingUsing)
val (leftTyParams, rest1) = tree.paramss.span(isTypeParamClause)
val (leadingUsing, rest2) = rest1.span(isUsingClause)
val (rightTyParams, rest3) = rest2.span(isTypeParamClause)
val (rightParam, rest4) = rest3.splitAt(1)
val (leftParam, rest5) = rest4.splitAt(1)
val (trailingUsing, rest6) = rest5.span(isUsingClause)
if leftParam.nonEmpty then
leftTyParams ::: leadingUsing ::: leftParam ::: trailingUsing ::: rightTyParams ::: rightParam ::: rest6
else
tree.paramss // it wasn't a binary operator, after all.
else
tree.paramss
val trailingParamss = paramss
.dropWhile(isUsingOrTypeParamClause)
Expand Down
12 changes: 8 additions & 4 deletions library/src/scala/IArray.scala
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,14 @@ object IArray:
def zipAll[T1 >: T, U](that: Iterable[U], thisElem: T1, thatElem: U): IArray[(T1, U)] = genericArrayOps(arr).zipAll(that, thisElem, thatElem)
def zipWithIndex: IArray[(T, Int)] = genericArrayOps(arr).zipWithIndex

extension [T, U >: T: ClassTag](arr: IArray[U])
def ++:(prefix: IterableOnce[T]): IArray[U] = genericArrayOps(arr).prependedAll(prefix)
def ++:(prefix: IArray[T]): IArray[U] = genericArrayOps(arr).prependedAll(prefix)
def +:(x: T): IArray[U] = genericArrayOps(arr).prepended(x)
extension [T, U >: T: ClassTag](prefix: IterableOnce[T])
def ++:(arr: IArray[U]): IArray[U] = genericArrayOps(arr).prependedAll(prefix)

extension [T, U >: T: ClassTag](prefix: IArray[T])
def ++:(arr: IArray[U]): IArray[U] = genericArrayOps(arr).prependedAll(prefix)

extension [T, U >: T: ClassTag](x: T)
def +:(arr: IArray[U]): IArray[U] = genericArrayOps(arr).prepended(x)

// For backwards compatibility with code compiled without -Yexplicit-nulls
private inline def mapNull[A, B](a: A, inline f: B): B =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ class HoverTypeSuite extends BaseHoverSuite:
|class C
|
|object Foo:
| extension [R](using A)(res: R)(using B)
| def %:[T](main: T)(using C): R = ???
| extension [T](using A)(main: T)(using B)
| def %:[R](res: R)(using C): R = ???
| given A with {}
| given B with {}
| given C with {}
Expand All @@ -162,7 +162,7 @@ class HoverTypeSuite extends BaseHoverSuite:
|end Foo
|""".stripMargin,
"""|Int
|extension [R](using A)(using B)(res: R) def %:[T](main: T)(using C): R""".stripMargin.hover
|extension [T](using A)(main: T) def %:[R](res: R)(using B)(using C): R""".stripMargin.hover
)

@Test def `using` =
Expand Down
4 changes: 2 additions & 2 deletions tests/neg-custom-args/captures/lazylists-exceptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ final class LazyCons[+T](val x: T, val xs: () => LazyList[T]^) extends LazyList[
def tail: LazyList[T]^{this} = xs()
end LazyCons

extension [A](xs1: => LazyList[A]^)
def #:(x: A): LazyList[A]^{xs1} =
extension [A](x: A)
def #:(xs1: => LazyList[A]^): LazyList[A]^{xs1} =
LazyCons(x, () => xs1)

def tabulate[A](n: Int)(gen: Int => A): LazyList[A]^{gen} =
Expand Down
4 changes: 2 additions & 2 deletions tests/neg/i13075.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ object Implementing_Tuples:
type *:[H, T <: Tup] = ConsTup[H, T] // for type matching
type EmptyTup = EmptyTup.type // for type matching

extension [T <: Tup](tail: T)
def *:[H](head: H) = ConsTup(head, tail)
extension [H](head: H)
def *:[T <: Tup](tail: T) = ConsTup(head, tail)

type Fold[T <: Tup, Seed, F[_,_]] = T match
case EmptyTup => Seed
Expand Down
3 changes: 1 addition & 2 deletions tests/neg/i9562.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@ object Unrelated:
def h1: Int = foo // error
def h2: Int = h1 + 1 // OK
def h3: Int = g // error
extension (x: Int)
def ++:(f: Foo): Int = f.h1 + x // OK
def ++: (x: Int): Int = h1 + x // OK
4 changes: 2 additions & 2 deletions tests/pos-custom-args/captures/lazylists-exceptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ extension [A](xs: LzyList[A]^)
if n == 0 then xs else xs.tail.drop(n - 1)
end extension

extension [A](xs1: => LzyList[A]^)
def #:(x: A): LzyList[A]^{xs1} =
extension [A](x: A)
def #:(xs1: => LzyList[A]^): LzyList[A]^{xs1} =
LzyCons(x, () => xs1)

def lazyCons[A](x: A, xs1: => LzyList[A]^): LzyList[A]^{xs1} =
Expand Down
4 changes: 2 additions & 2 deletions tests/pos-custom-args/captures/logger.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ final class LazyCons[+T](val x: T, val xs: () => LazyList[T]^) extends LazyList[
def tail: LazyList[T]^{this} = xs()
end LazyCons

extension [A](xs1: => LazyList[A]^)
def #::(x: A): LazyList[A]^{xs1} =
extension [A](x: A)
def #::(xs1: => LazyList[A]^): LazyList[A]^{xs1} =
LazyCons(x, () => xs1)

extension [A](xs: LazyList[A]^)
Expand Down
7 changes: 4 additions & 3 deletions tests/pos-custom-args/captures/strictlists.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ extension [A](xs: StrictList[A])
def concat(ys: StrictList[A]): StrictList[A] =
if xs.isEmpty then ys
else xs.head #: xs.tail.concat(ys)

def #:(x: A): StrictList[A] =
StrictCons(x, xs)
end extension

extension [A](x: A)
def #:(xs1: StrictList[A]): StrictList[A] =
StrictCons(x, xs1)

def tabulate[A](n: Int)(gen: Int => A) =
def recur(i: Int): StrictList[A] =
if i == n then StrictNil
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/i19197.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
extension (tuple: Tuple)
def **:[T >: tuple.type <: Tuple, H](x: H): H *: T = ???
infix def **:[T >: tuple.type <: Tuple, H](x: H): H *: T = ???

def test1: (Int, String, Char) = 1 **: ("a", 'b')
def test2: (Int, String, Char) = ("a", 'b').**:(1)
Expand Down
4 changes: 1 addition & 3 deletions tests/pos/i9562.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,4 @@ object Unrelated:
extension (f: Foo)
def h1: Int = 0
def h2: Int = h1 + 1 // OK

extension (x: Int)
def ++: (f: Foo): Int = f.h2 + x // OK
def ++: (x: Int): Int = h2 + x // OK
2 changes: 1 addition & 1 deletion tests/pos/reference/extension-methods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ object ExtMethods:
assert(circle.circumference == circumference(circle))

extension (x: String) def < (y: String) = x.compareTo(y) < 0
extension [Elem](xs: Seq[Elem]) def #: (x: Elem) = x +: xs
extension [Elem](x: Elem) def #: (xs: Seq[Elem]) = x +: xs
extension (x: Number) infix def min (y: Number) = x

assert("a" < "bb")
Expand Down
5 changes: 1 addition & 4 deletions tests/run/errorhandling/Result.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,12 @@ object Result:
case (Err(e), Ok(_)) => Err(e :: Nil)
case (Err(e1), Err(e2)) => Err(e1 :: e2 :: Nil)

end extension

extension [U <: Tuple, E](other: Result[U, List[E]])
/** Validate both `r` and `other`; return a tuple of successes or a list of failures.
* Unlike with `zip`, the right hand side `other` must be a `Result` returning a `Tuple`,
* and the left hand side is added to it. See `Result.empty` for a convenient
* right unit of chains of `*:`s.
*/
def *: [T](r: Result[T, E]): Result[T *: U, List[E]] = (r, other) match
def *: [U <: Tuple](other: Result[U, List[E]]): Result[T *: U, List[E]] = (r, other) match
case (Ok(x), Ok(ys)) => Ok(x *: ys)
case (Ok(_), es: Err[?]) => es
case (Err(e), Ok(_)) => Err(e :: Nil)
Expand Down
4 changes: 2 additions & 2 deletions tests/run/export-in-extension.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ object O:
export cm.*
def succ: Int = x + 1
def succ2: Int = succ + 1
def ::: (y: Int) = y - x
def ::: (y: Int) = x - y

object O2:
import O.C
Expand All @@ -20,7 +20,7 @@ object O2:
export cm.{bar, baz, bam, ::}
def succ: Int = x + 1
def succ2: Int = succ + 1
def ::: (y: Int) = y - x
def ::: (y: Int) = x - y

@main def Test =
import O.*
Expand Down
8 changes: 4 additions & 4 deletions tests/run/i11583.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ class Env:
// */
// def &&:[T <: ctx.Term](trm: T)(ext: env.Extra): (ctx.Type, T, env.Extra) = (tpe, trm, ext)

extension [Ctx <: Context, T <: Boolean](using ctx: Ctx)(trm: T)(using env: Env)
def :#:(tpe: String)(ext: env.Extra): (String, T, env.Extra) = (tpe, trm, ext)
extension [Ctx <: Context](using ctx: Ctx)(tpe: String)(using env: Env)
def :#:[T <: Boolean](trm: T)(ext: env.Extra): (String, T, env.Extra) = (tpe, trm, ext)

extension [T <: Tuple](t: T)
def :*:[A](a: A): A *: T = a *: t
extension [A](a: A)
def :*:[T <: Tuple](t: T): A *: T = a *: t

@main def Test =

Expand Down
2 changes: 1 addition & 1 deletion tests/run/i9530.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ trait Scope:
extension (using s: Scope)(expr: s.Expr)
def show = expr.toString
def eval = s.value(expr)
def *: (other: s.Expr) = s.combine(other, expr)
def *: (other: s.Expr) = s.combine(expr, other)

def f(using s: Scope)(x: s.Expr): (String, s.Value) =
(x.show, x.eval)
Expand Down
2 changes: 1 addition & 1 deletion tests/run/instances.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ object Test extends App {
extension [T](xs: List[List[T]])
def flattened = xs.foldLeft[List[T]](Nil)(_ ++ _)

extension [T](xs: Seq[T]) def :: (x: T) = x +: xs
extension [T](x: T) def :: (xs: Seq[T]) = x +: xs

val ss: Seq[Int] = List(1, 2, 3)
val ss1 = 0 :: ss
Expand Down
Loading

0 comments on commit 71aa636

Please sign in to comment.