diff --git a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala index bc50c01..0ddb4d9 100644 --- a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala +++ b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/ElabApplyCustom.scala @@ -16,6 +16,7 @@ import com.whatsapp.eqwalizer.ast.CompilerMacro class ElabApplyCustom(pipelineContext: PipelineContext) { private lazy val elab = pipelineContext.elab private lazy val elabApply = pipelineContext.elabApply + private lazy val elabPat = pipelineContext.elabPat private lazy val check = pipelineContext.check private lazy val subtype = pipelineContext.subtype private lazy val narrow = pipelineContext.narrow @@ -37,6 +38,7 @@ class ElabApplyCustom(pipelineContext: PipelineContext) { RemoteId("lists", "flatten", 2), RemoteId("lists", "keysort", 2), RemoteId("lists", "keystore", 4), + RemoteId("lists", "partition", 2), RemoteId("maps", "filter", 2), RemoteId("maps", "filtermap", 2), RemoteId("maps", "find", 2), @@ -218,6 +220,29 @@ class ElabApplyCustom(pipelineContext: PipelineContext) { val resTy = ListType(subtype.join(inTupleTy, replacementCoercedTy)) (resTy, env1) + case RemoteId("lists", "partition", 2) => + val List(pred, list) = args + val List(predTy, listTy) = argTys + val coercedListTy = coerce(list, listTy, ListType(AnyType)) + val elemTy = narrow.asListType(coercedListTy).get.t + val expFunTy = FunType(Nil, List(elemTy), booleanType) + pred match { + case lambda: Lambda if Predicates.booleanReturnClauses(lambda.clauses) => + check.checkLambda(lambda, expFunTy, env1) + val (trueClause, falseClause) = Predicates.getTrueFalseReturnClauses(lambda.clauses) + val lamEnv = lambda.name.map(name => env.updated(name, expFunTy)).getOrElse(env1) + val List(trueEnv, falseEnv) = occurrence.clausesEnvs(List(trueClause, falseClause), List(elemTy), lamEnv) + val (trueTy, _) = elabPat.elabPat(trueClause.pats.head, elemTy, trueEnv) + val (falseTy, _) = elabPat.elabPat(falseClause.pats.head, elemTy, falseEnv) + (TupleType(List(ListType(trueTy), ListType(falseTy))), env1) + case _ => + if (!subtype.subType(predTy, expFunTy)) { + diagnosticsInfo.add(ExpectedSubtype(pred.pos, pred, expected = expFunTy, got = predTy)) + (TupleType(List(DynamicType, DynamicType)), env1) + } else + (TupleType(List(ListType(elemTy), ListType(elemTy))), env1) + } + case RemoteId("maps", "filter", 2) => val List(funArg, map) = args val List(funArgTy, mapTy) = argTys diff --git a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Predicates.scala b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Predicates.scala index d8139d3..1b63ac8 100644 --- a/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Predicates.scala +++ b/eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Predicates.scala @@ -54,4 +54,24 @@ object Predicates { (clause2, clause1) } } + + def booleanReturnClauses(clauses: List[Clause]): Boolean = + clauses match { + case List(Clause(_, _, Body(List(expr1))), Clause(_, _, Body(List(expr2)))) => + (expr1, expr2) match { + case (AtomLit("true"), AtomLit("false")) => true + case (AtomLit("false"), AtomLit("true")) => true + case (_, _) => false + } + case _ => + false + } + + def getTrueFalseReturnClauses(clauses: List[Clause]): (Clause, Clause) = { + val List(clause1 @ Clause(_, _, Body(List(expr1))), clause2 @ Clause(_, _, Body(List(expr2)))) = clauses + (expr1, expr2) match { + case (AtomLit("true"), AtomLit("false")) => (clause1, clause2) + case (_, _) => (clause2, clause1) + } + } }