Skip to content

Commit

Permalink
Add custom type checking for element/2
Browse files Browse the repository at this point in the history
Summary: This would have caught the issue in D50360629, where the usage of `element` was clearly unsafe.

Reviewed By: ilya-klyuchnikov

Differential Revision: D50412145

fbshipit-source-id: a93a679027f7de3a2f469cfcf78fbcd53237d10c
  • Loading branch information
ruippeixotog authored and facebook-github-bot committed Oct 26, 2023
1 parent bb47c23 commit c206550
Show file tree
Hide file tree
Showing 8 changed files with 389 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
package com.whatsapp.eqwalizer.tc

import scala.annotation.tailrec
import com.whatsapp.eqwalizer.ast.Exprs.{AtomLit, Cons, Expr, Lambda, NilLit}
import com.whatsapp.eqwalizer.ast.Exprs.{AtomLit, Cons, Expr, IntLit, Lambda, NilLit}
import com.whatsapp.eqwalizer.ast.Types._
import com.whatsapp.eqwalizer.ast.{Exprs, Pos, RemoteId}
import com.whatsapp.eqwalizer.tc.TcDiagnostics.{ExpectedSubtype, UnboundVar, UnboundRecord}
import com.whatsapp.eqwalizer.tc.TcDiagnostics.{ExpectedSubtype, IndexOutOfBounds, UnboundVar, UnboundRecord}
import com.whatsapp.eqwalizer.ast.CompilerMacro

class ElabApplyCustom(pipelineContext: PipelineContext) {
Expand All @@ -28,6 +28,7 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {

private lazy val custom: Set[RemoteId] =
Set(
RemoteId("erlang", "element", 2),
RemoteId("erlang", "map_get", 2),
RemoteId("file", "open", 2),
RemoteId("lists", "filtermap", 2),
Expand Down Expand Up @@ -296,6 +297,36 @@ class ElabApplyCustom(pipelineContext: PipelineContext) {
(valTy, env1)
}

/*
`-spec element(N :: NumberType, Tup :: TupleType) -> Out`, where `Out` is:
- `Tup[N]` when `N` is an integer literal corresponding to a valid index
- Union of element types of `Tup` when `N` is not a literal
- An error otherwise (index out of bounds or unexpected type)
*/
case RemoteId("erlang", "element", 2) =>
val List(index, tuple) = args
val List(indexTy, tupleTy) = argTys

def validate(): Unit = {
if (!subtype.subType(indexTy, NumberType))
throw ExpectedSubtype(index.pos, index, expected = NumberType, got = indexTy)
if (!subtype.subType(tupleTy, AnyTupleType))
throw ExpectedSubtype(tuple.pos, tuple, expected = AnyTupleType, got = tupleTy)
}
validate()

val elemTy = index match {
case IntLit(Some(n)) =>
narrow.getTupleElement(tupleTy, n) match {
case Right(elemTy) => elemTy
case Left(tupLen) => throw IndexOutOfBounds(callPos, index, n, tupLen)
}
case _ =>
narrow.getAllTupleElements(tupleTy)
}

(elemTy, env1)

case RemoteId("maps", "get", 3) =>
val List(key, map, defaultVal) = args
val List(keyTy, mapTy, defaultValTy) = argTys
Expand Down
98 changes: 98 additions & 0 deletions eqwalizer/src/main/scala/com/whatsapp/eqwalizer/tc/Narrow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,104 @@ class Narrow(pipelineContext: PipelineContext) {
case _ => List()
}

/**
* Given a type (required to be a subtype of `AnyTupleType`) and an index, returns the type of the tuple element at
* the index wrapped in a `Right`. If the index can be possibly out of bounds (in at least one of the options in a
* union) the function returns `Left(tupLen)`, where `tupLen` is the minimum index value for which this operation would
* type check.
*/
def getTupleElement(t: Type, idx: Int): Either[Int, Type] = t match {
case NoneType =>
Right(NoneType)
case DynamicType =>
Right(DynamicType)
case AnyTupleType if pipelineContext.gradualTyping =>
Right(DynamicType)
case AnyTupleType =>
Right(AnyType)
case BoundedDynamicType(t) if subtype.subType(t, AnyTupleType) =>
Right(BoundedDynamicType(getTupleElement(t, idx).getOrElse(NoneType)))
case BoundedDynamicType(t) =>
Right(BoundedDynamicType(NoneType))
case TupleType(elemTys) if idx >= 1 && idx <= elemTys.length =>
Right(elemTys(idx - 1))
case TupleType(elemTys) =>
Left(elemTys.length)
case r: RecordType =>
recordToTuple(r) match {
case Some(tupTy) => getTupleElement(tupTy, idx)
case None if pipelineContext.gradualTyping => Right(DynamicType)
case None => Right(AnyType)
}
case r: RefinedRecordType =>
refinedRecordToTuple(r) match {
case Some(tupTy) => getTupleElement(tupTy, idx)
case None if pipelineContext.gradualTyping => Right(DynamicType)
case None => Right(AnyType)
}
case UnionType(tys) =>
val res = tys.map(getTupleElement(_, idx)).foldLeft[Either[Int, Set[Type]]](Right(Set.empty)) {
case (Right(accTy), Right(elemTy)) => Right(accTy + elemTy)
case (Left(n1), Left(n2)) => Left(n1.min(n2))
case (Left(n1), _) => Left(n1)
case (_, Left(n2)) => Left(n2)
}
res.map { optionTys => UnionType(util.flattenUnions(UnionType(optionTys)).toSet) }
case RemoteType(rid, args) =>
val body = util.getTypeDeclBody(rid, args)
getTupleElement(body, idx)
case _ =>
throw new IllegalStateException()
}

/**
* Given a type (required to be a subtype of `AnyTupleType`), returns the union of all its element types.
*/
def getAllTupleElements(t: Type): Type = t match {
case NoneType =>
NoneType
case DynamicType =>
DynamicType
case AnyTupleType if pipelineContext.gradualTyping =>
DynamicType
case AnyTupleType =>
AnyType
case BoundedDynamicType(t) if subtype.subType(t, AnyTupleType) =>
BoundedDynamicType(getAllTupleElements(t))
case BoundedDynamicType(t) =>
BoundedDynamicType(NoneType)
case TupleType(elemTys) =>
UnionType(elemTys.toSet)
case r: RecordType =>
recordToTuple(r) match {
case Some(tupTy) => getAllTupleElements(tupTy)
case None if pipelineContext.gradualTyping => DynamicType
case None => AnyType
}
case r: RefinedRecordType =>
refinedRecordToTuple(r) match {
case Some(tupTy) => getAllTupleElements(tupTy)
case None if pipelineContext.gradualTyping => DynamicType
case None => AnyType
}
case UnionType(tys) =>
UnionType(util.flattenUnions(UnionType(tys.map(getAllTupleElements))).toSet)
case RemoteType(rid, args) =>
val body = util.getTypeDeclBody(rid, args)
getAllTupleElements(body)
case _ =>
throw new IllegalStateException()
}

private def recordToTuple(r: RecordType): Option[TupleType] =
refinedRecordToTuple(RefinedRecordType(r, Map()))

private def refinedRecordToTuple(r: RefinedRecordType): Option[TupleType] =
util.getRecord(r.recType.module, r.recType.name).map { recDecl =>
val elemTys = AtomLitType(r.recType.name) :: recDecl.fields.map(f => r.fields.getOrElse(f._1, f._2.tp)).toList
TupleType(elemTys)
}

private def adjustShapeMap(t: ShapeMap, keyT: Type, valT: Type): Type =
keyT match {
case AtomLitType(key) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ object TcDiagnostics {
def errorName = "fun_arity_mismatch"
override def erroneousExpr: Option[Expr] = Some(expr)
}
case class IndexOutOfBounds(pos: Pos, expr: Expr, index: Int, tupleArity: Int) extends TypeError {
override val msg: String = s"Tried to access element $index of a tuple with $tupleArity elements"
def errorName = "index_out_of_bounds"
override def erroneousExpr: Option[Expr] = Some(expr)
}
case class NotSupportedLambdaInOverloadedCall(pos: Pos, expr: Expr) extends TypeError {
override val msg: String = s"Lambdas are not allowed as args to overloaded functions"
def errorName = "fun_in_overload_arg"
Expand Down
4 changes: 2 additions & 2 deletions eqwalizer/test_projects/_cli/otp_funs.cli
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ gb_sets 26
proplists 51
maps 149
lists 169
erlang 374
erlang 396
Per app stats:
kernel 21
erts 374
erts 396
stdlib 471
86 changes: 86 additions & 0 deletions eqwalizer/test_projects/check/src/custom.erl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,92 @@
-import(maps, [get/2, get/3]).
-compile([export_all, nowarn_export_all]).

-record(foo, {
a :: ok | error,
b :: number(),
c :: string()
}).

% element/2 - basic examples

-spec element_2_basic_1({atom(), number(), string()}) -> atom().
element_2_basic_1(Tup) ->
element(1, Tup).

-spec element_2_basic_2_neg({atom(), number(), string(), map()}) -> atom().
element_2_basic_2_neg(Tup) ->
element(4, Tup).

-spec element_2_basic_3_neg({atom(), number(), string()}) -> atom().
element_2_basic_3_neg(Tup) ->
element(42, Tup).

% element/2 - union examples

-spec element_2_union_1({atom(), number() | string()} | {number(), atom()}) -> number() | string() | atom().
element_2_union_1(Tup) ->
element(2, Tup).

-spec element_2_union_2_neg({atom(), number() | string()} | {number(), atom()}) -> map().
element_2_union_2_neg(Tup) ->
element(2, Tup).

-spec element_2_union_3_neg({atom(), string()} | list()) -> string().
element_2_union_3_neg(Tup) ->
element(2, Tup).

-spec element_2_union_4_neg({c, d, e, f} | {a, b} | {b, c, d}) -> atom().
element_2_union_4_neg(Tup) ->
element(42, Tup).

% element/2 - dynamic index examples

-spec element_2_dynindex_1_neg(pos_integer(), {atom(), number(), string()}) -> map().
element_2_dynindex_1_neg(N, Tup) ->
element(N, Tup).

-spec element_2_dynindex_2_neg(pos_integer(), {atom(), atom()} | {atom(), atom(), number()}) -> atom().
element_2_dynindex_2_neg(N, Tup) ->
element(N, Tup).

% element/2 - tuple() examples

-spec element_2_anytuple_1_neg(tuple()) -> atom().
element_2_anytuple_1_neg(Tup) ->
element(1, Tup).

-spec element_2_anytuple_2_neg(tuple() | {number(), atom()}) -> atom().
element_2_anytuple_2_neg(Tup) ->
element(1, Tup).

% element/2 - record examples

-spec element_2_record_1(#foo{}) -> foo.
element_2_record_1(Rec) ->
element(1, Rec).

-spec element_2_record_2(#foo{}) -> ok | error.
element_2_record_2(Rec) ->
element(2, Rec).

-spec element_2_record_3(#foo{}) -> ok.
element_2_record_3(Rec) when Rec#foo.a =/= error ->
element(2, Rec).

-spec element_2_record_4_neg(pos_integer(), #foo{}) -> atom().
element_2_record_4_neg(N, Rec) ->
element(N, Rec).

% element/2 - none examples

-spec element_2_none_1(none()) -> number().
element_2_none_1(Tup) ->
element(42, Tup).

-spec element_2_none_2(pos_integer(), none()) -> number().
element_2_none_2(N, Tup) ->
element(N, Tup).

-spec map_get_2_1(
pid(), #{pid() => atom()}
) -> atom().
Expand Down
Loading

0 comments on commit c206550

Please sign in to comment.