Skip to content

Commit

Permalink
match: extract discriminant
Browse files Browse the repository at this point in the history
  • Loading branch information
dark-flames committed Feb 17, 2025
1 parent 6bf52d9 commit aefb0e9
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 70 deletions.
13 changes: 8 additions & 5 deletions base/src/main/java/org/aya/resolve/salt/Desalt.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package org.aya.resolve.salt;

import kala.collection.immutable.ImmutableSeq;
import kala.control.Option;
import org.aya.generic.term.SortKind;
import org.aya.resolve.ResolveInfo;
import org.aya.syntax.concrete.Expr;
Expand Down Expand Up @@ -47,7 +46,9 @@ public final class Desalt implements PosedUnaryOperator<Expr> {
}
case Expr.Match match -> {
return match.update(
match.discriminant().map(e -> e.descent(this)),
match.discriminant().map(d -> new Expr.Match.Discriminant(
d.discr().descent(this), d.asBinding(), d.isElim()
)),
match.clauses().map(clause -> clause.descent(this, pattern)),
match.returns() != null ? match.returns().descent(this) : null
);
Expand Down Expand Up @@ -98,10 +99,12 @@ public final class Desalt implements PosedUnaryOperator<Expr> {
// the var with prime are renamed vars

realBody = new WithPos<>(sourcePos, new Expr.Match(
lamTele.map(x -> new WithPos<>(x.definition(), new Expr.Ref(x))),
lamTele.map(x -> new Expr.Match.Discriminant(
new WithPos<>(x.definition(), new Expr.Ref(x)),
null,
true
)),
ImmutableSeq.of(lam.clause()),
ImmutableSeq.fill(lamTele.size(), Option.none()),
ImmutableSeq.fill(lamTele.size(), true),
null
));
}
Expand Down
10 changes: 6 additions & 4 deletions base/src/main/java/org/aya/resolve/visitor/ExprResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,13 @@ public ExprResolver(@NotNull Context ctx, boolean allowGeneralizing) {
yield letOpen.update(letOpen.body().descent(enter(context)));
}
case Expr.Match match -> {
var discriminant = match.discriminant().map(x -> x.descent(this));
var discriminant = match.discriminant().map(
d -> new Expr.Match.Discriminant(d.discr().descent(this), d.asBinding(), d.isElim())
);
var returnsCtx = ctx;
for (var binding : match.asBindings()) {
if (binding.isDefined()) {
returnsCtx = returnsCtx.bind(binding.get());
for (var discr : match.discriminant()) {
if (discr.asBinding() != null) {
returnsCtx = returnsCtx.bind(discr.asBinding());
}
}
var returns = match.returns() != null ? match.returns().descent(enter(returnsCtx)) : null;
Expand Down
53 changes: 21 additions & 32 deletions base/src/main/java/org/aya/tyck/ExprTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,17 @@ && whnf(type) instanceof DataCall dataCall
element -> inherit(element, elementTy).wellTyped()));
yield new Jdg.Default(new ListTerm(results, recog, dataCall), type);
}
case Expr.Match(var discriminant, var clauses, var asBindings, var elims, var returns) -> {
var wellArgs = discriminant.map(this::synthesize);
case Expr.Match(var discriminant, var clauses, var returns) -> {
var wellArgs = discriminant.map(d -> synthesize(d.discr()));
Term storedTy;
// Type check the type annotation
if (returns != null) {
if (asBindings.isEmpty()) {
unifyTyReported(type, storedTy = ty(returns), returns);
} else {
storedTy = matchReturnTy(wellArgs, asBindings, elims, returns);
unifyTyReported(type, storedTy.instTele(wellArgs.view().map(Jdg::wellTyped)), returns);
}
storedTy = matchReturnTy(discriminant, wellArgs, returns);
unifyTyReported(type, storedTy.instTele(wellArgs.view().map(Jdg::wellTyped)), returns);
} else {
storedTy = type;
}
yield new Jdg.Default(match(discriminant, expr.sourcePos(), clauses, wellArgs, elims, storedTy), type);
yield new Jdg.Default(match(discriminant, expr.sourcePos(), clauses, wellArgs, storedTy), type);
}
case Expr.Let let -> checkLet(let, e -> inherit(e, type));
case Expr.Partial(var element) -> whnf(type) instanceof PartialTyTerm(var r, var s, var A)
Expand All @@ -181,19 +177,16 @@ && whnf(type) instanceof DataCall dataCall
}

private @NotNull Term matchReturnTy(
ImmutableSeq<Jdg> wellArgs, ImmutableSeq<Option<LocalVar>> asBindings, ImmutableSeq<Boolean> elims,
WithPos<Expr> returns
ImmutableSeq<Expr.Match.Discriminant> discriminant,
ImmutableSeq<Jdg> wellArgs, WithPos<Expr> returns
) {
try (var _ = subscope()) {
var tele = MutableList.<LocalVar>create();
wellArgs.forEachWith(asBindings.zip(elims), (discr, t) -> {
var as = t.component1();
var isElim = t.component2();

if (as.isDefined()) {
localCtx().put(as.get(), discr.type());
tele.append(as.get());
} else if (isElim && discr.wellTyped() instanceof FreeTerm(LocalVar name)) {
wellArgs.forEachWith(discriminant, (arg, discr) -> {
if (discr.asBinding() != null) {
localCtx().put(discr.asBinding(), arg.type());
tele.append(discr.asBinding());
} else if (discr.isElim() && arg.wellTyped() instanceof FreeTerm(LocalVar name)) {
tele.append(name);

Check warning on line 190 in base/src/main/java/org/aya/tyck/ExprTycker.java

View check run for this annotation

Codecov / codecov/patch

base/src/main/java/org/aya/tyck/ExprTycker.java#L190

Added line #L190 was not covered by tests
} else {
tele.append(new LocalVar(Constants.ANONYMOUS_PREFIX));
Expand All @@ -205,15 +198,15 @@ && whnf(type) instanceof DataCall dataCall
}

private @NotNull MatchCall match(
ImmutableSeq<WithPos<Expr>> discriminant, @NotNull SourcePos exprPos,
ImmutableSeq<Pattern.Clause> clauses, ImmutableSeq<Jdg> wellArgs, ImmutableSeq<Boolean> elims, Term type
ImmutableSeq<Expr.Match.Discriminant> discriminant, @NotNull SourcePos exprPos,
ImmutableSeq<Pattern.Clause> clauses, ImmutableSeq<Jdg> wellArgs, Term type
) {
var elimVarTele = MutableList.<LocalVar>create();
var paramTele = MutableList.<Param>create();
wellArgs.forEachWith(elims, (arg, elim) -> {
wellArgs.forEachWith(discriminant, (arg, discr) -> {
var paramTy = arg.type().bindTele(elimVarTele.view());

if (elim && arg.wellTyped() instanceof FreeTerm(LocalVar name)) {
if (discr.isElim() && arg.wellTyped() instanceof FreeTerm(LocalVar name)) {
elimVarTele.append(name);
} else {
elimVarTele.append(new LocalVar(Constants.ANONYMOUS_PREFIX));
Expand All @@ -227,7 +220,7 @@ && whnf(type) instanceof DataCall dataCall
paramTele.toSeq(),
new DepTypeTerm.Unpi(ImmutableSeq.empty(), type),
ImmutableSeq.fill(discriminant.size(), i ->
new LocalVar("match" + i, discriminant.get(i).sourcePos(), GenerateKind.Basic.Tyck)),
new LocalVar("match" + i, discriminant.get(i).discr().sourcePos(), GenerateKind.Basic.Tyck)),
ImmutableSeq.empty(), clauses);
var wellClauses = clauseTycker.check(exprPos).wellTyped().matchingsView();

Expand Down Expand Up @@ -493,16 +486,12 @@ case DepTypeTerm(var kind, var param, var body) when kind == DTKind.Sigma -> {

yield new Jdg.Default(new NewTerm(call), call);
}
case Expr.Match(var discriminant, var clauses, var asBindings, var elims, var returns) -> {
var wellArgs = discriminant.map(this::synthesize);
case Expr.Match(var discriminant, var clauses, var returns) -> {
var wellArgs = discriminant.map(d -> synthesize(d.discr()));
if (returns == null) yield fail(expr.data(), new MatchMissingReturnsError(expr));
// Type check the type annotation
Term type;
if (asBindings.isEmpty()) type = ty(returns);
else {
type = matchReturnTy(wellArgs, asBindings, elims, returns);
}
yield new Jdg.Default(match(discriminant, expr.sourcePos(), clauses, wellArgs, elims, type), type);
Term type = matchReturnTy(discriminant, wellArgs, returns);
yield new Jdg.Default(match(discriminant, expr.sourcePos(), clauses, wellArgs, type), type);
}
case Expr.Unresolved _ -> Panic.unreachable();
default -> fail(expr.data(), new NoRuleError(expr, null));
Expand Down
30 changes: 10 additions & 20 deletions producer/src/main/java/org/aya/producer/AyaProducer.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import kala.control.Either;
import kala.control.Option;
import kala.function.BooleanObjBiFunction;
import kala.tuple.Tuple;
import kala.value.MutableValue;
import org.aya.generic.Constants;
import org.aya.generic.Modifier;
Expand Down Expand Up @@ -603,30 +602,21 @@ private record DeclNameOrInfix(@NotNull WithPos<String> name, @Nullable OpDecl.O
var clauses = node.child(CLAUSES);
var bare = clauses.childrenOfType(BARE_CLAUSE).map(this::bareOrBarredClause);
var barred = clauses.childrenOfType(BARRED_CLAUSE).map(this::bareOrBarredClause);
var discrList = node.child(MATCH_DISCR_LIST).child(COMMA_SEP).childrenOfType(MATCH_DISCR).map(d -> {
var elim = d.peekChild(KW_ELIM) != null;
Option<LocalVar> asBinding = d.peekChild(KW_AS) == null ?
Option.none() :
Option.some(LocalVar.from(weakId(d.child(WEAK_ID))));
var discr = expr(d.child(EXPR));

return Tuple.of(
var discrList = node.child(MATCH_DISCR_LIST).child(COMMA_SEP).childrenOfType(MATCH_DISCR)
.map(d -> new Expr.Match.Discriminant(
expr(d.child(EXPR)),
asBinding,
elim
);
}).toSeq();
d.peekChild(KW_AS) != null ?
LocalVar.from(weakId(d.child(WEAK_ID))) : null,
d.peekChild(KW_ELIM) != null
)).toSeq();

if (!discrList.allMatch(d -> !d.component3() || d.component1().data() instanceof Expr.Unresolved)) {
if (!discrList.allMatch(d -> !d.isElim() || d.discr().data() instanceof Expr.Unresolved)) {
reporter.report(new ParseError(pos, "Elimination match must be on variables"));
throw new ParsingInterruptedException();
} else if (!discrList.allMatch(d -> d.component2().isEmpty() || !d.component3())) {
} else if (!discrList.allMatch(d -> d.asBinding() == null || !d.isElim())) {
reporter.report(new ParseError(pos, "Elimination match could not be combined with as-binding"));
throw new ParsingInterruptedException();
}
var discr = discrList.map(Tuple::component1).toSeq();
var asBindings = discrList.map(Tuple::component2).toSeq();
var elims = discrList.map(Tuple::component3).toSeq();
var matchType = node.peekChild(MATCH_TYPE);

WithPos<Expr> returns = null;
Expand All @@ -635,8 +625,8 @@ private record DeclNameOrInfix(@NotNull WithPos<String> name, @Nullable OpDecl.O
if (returnsNode != null) returns = expr(returnsNode);
}
return new WithPos<>(pos, new Expr.Match(
discr,
bare.concat(barred).toSeq(), asBindings, elims,
discrList,
bare.concat(barred).toSeq(),
returns
));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ yield visitCalls(null, fn, (_, l) -> l.toDoc(options), outer,
);
case Expr.New neu -> Doc.sep(KW_NEW, term(Outer.Free, neu.classCall()));
case Expr.Match match -> {
var deltaDoc = match.discriminant().map(x -> term(Outer.Free, x));
var deltaDoc = match.discriminant().map(x -> term(Outer.Free, x.discr()));
var prefix = Doc.sep(KW_MATCH, Doc.commaList(deltaDoc));
var clauseDoc = visitClauses(match.clauses());

Expand Down
19 changes: 11 additions & 8 deletions syntax/src/main/java/org/aya/syntax/concrete/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import kala.collection.immutable.ImmutableSeq;
import kala.collection.mutable.MutableArrayList;
import kala.control.Either;
import kala.control.Option;
import kala.value.MutableValue;
import org.aya.generic.AyaDocile;
import org.aya.generic.Nested;
Expand Down Expand Up @@ -583,28 +582,32 @@ record LetOpen(
}

record Match(
@NotNull ImmutableSeq<WithPos<Expr>> discriminant,
@NotNull ImmutableSeq<Discriminant> discriminant,
@NotNull ImmutableSeq<Pattern.Clause> clauses,
@NotNull ImmutableSeq<Option<LocalVar>> asBindings,
@NotNull ImmutableSeq<Boolean> elims,
@Nullable WithPos<Expr> returns
) implements Expr {
public record Discriminant(
@NotNull WithPos<Expr> discr,
@Nullable LocalVar asBinding,
@NotNull Boolean isElim
) { }

public @NotNull Match update(
@NotNull ImmutableSeq<WithPos<Expr>> discriminant,
@NotNull ImmutableSeq<Discriminant> discriminant,
@NotNull ImmutableSeq<Pattern.Clause> clauses,
@Nullable WithPos<Expr> returns
) {
return this.discriminant.sameElements(discriminant, true)
&& this.clauses.sameElements(clauses, true) && this.returns == returns
? this : new Match(discriminant, clauses, asBindings, elims, returns);
? this : new Match(discriminant, clauses, returns);
}

@Override public @NotNull Expr descent(@NotNull PosedUnaryOperator<@NotNull Expr> f) {
return descent(f, (_, p) -> p);
}

public @NotNull Expr descent(@NotNull PosedUnaryOperator<@NotNull Expr> f, @NotNull PosedUnaryOperator<@NotNull Pattern> g) {
return update(discriminant.map(x -> x.descent(f)),
return update(discriminant.map(d -> new Discriminant(d.discr().descent(f), d.asBinding(), d.isElim())),

Check warning on line 610 in syntax/src/main/java/org/aya/syntax/concrete/Expr.java

View check run for this annotation

Codecov / codecov/patch

syntax/src/main/java/org/aya/syntax/concrete/Expr.java#L610

Added line #L610 was not covered by tests
clauses.map(x -> x.descent(f, g)),
returns != null ? returns.descent(f) : null);
}
Expand All @@ -613,7 +616,7 @@ record Match(
///
/// @see StmtVisitor#visitExpr
@Override public void forEach(@NotNull PosedConsumer<@NotNull Expr> f) {
discriminant.forEach(f::accept);
discriminant.forEach(d -> f.accept(d.discr()));
clauses.forEach(clause -> clause.forEach(f, (_, _) -> { }));
if (returns != null) f.accept(returns);
}
Expand Down

0 comments on commit aefb0e9

Please sign in to comment.