From ca8fa3dcd1b10f62c8a13c5ace9423e7d079cce8 Mon Sep 17 00:00:00 2001 From: Zeina Migeed Date: Tue, 21 Jan 2025 17:30:13 -0800 Subject: [PATCH] Add plumbing for send type inference Summary: Here, we add the keys and bindings needed for send type inference. The only thing that remains is: In answers.rs for the `Binding::SendTypeOfYield(_)`, we must lookup the return annotation and extract the send type from there if it exists. The key would be `&KeyAnnotation::ReturnAnnotation(function_name.clone())`. However, using self.get to retrieve the annotation causes panic if the key is not found. It looks like we can consider `std::panic::catch_unwind` but it seems this is unsafe behavior. Should we modify the `get` method to return an optional? This is something to address next. Once we do this, the send type will be correctly inferred. Reviewed By: stroxler Differential Revision: D68468322 fbshipit-source-id: 4464a1159defed48062d5747ac4e63a237ae3a81 --- pyre2/conformance/third_party/conformance.exp | 170 ------------------ .../third_party/conformance.result | 17 +- pyre2/conformance/third_party/results.json | 4 +- pyre2/pyre2/bin/alt/answers.rs | 2 + pyre2/pyre2/bin/alt/expr.rs | 3 +- pyre2/pyre2/bin/binding/binding.rs | 11 ++ pyre2/pyre2/bin/binding/bindings.rs | 8 +- pyre2/pyre2/bin/test/yield.rs | 66 ++++--- 8 files changed, 73 insertions(+), 208 deletions(-) diff --git a/pyre2/conformance/third_party/conformance.exp b/pyre2/conformance/third_party/conformance.exp index 15cb4818e32..c7ba48805cc 100644 --- a/pyre2/conformance/third_party/conformance.exp +++ b/pyre2/conformance/third_party/conformance.exp @@ -1648,16 +1648,6 @@ } ], "annotations_generators.py": [ - { - "code": -2, - "column": 9, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 46, - "name": "PyreError", - "stop_column": 18, - "stop_line": 46 - }, { "code": -2, "column": 16, @@ -1678,106 +1668,6 @@ "stop_column": 16, "stop_line": 57 }, - { - "code": -2, - "column": 9, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 57, - "name": "PyreError", - "stop_column": 16, - "stop_line": 57 - }, - { - "code": -2, - "column": 9, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 66, - "name": "PyreError", - "stop_column": 16, - "stop_line": 66 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 70, - "name": "PyreError", - "stop_column": 14, - "stop_line": 70 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 75, - "name": "PyreError", - "stop_column": 14, - "stop_line": 75 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 79, - "name": "PyreError", - "stop_column": 10, - "stop_line": 79 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 83, - "name": "PyreError", - "stop_column": 18, - "stop_line": 83 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 87, - "name": "PyreError", - "stop_column": 15, - "stop_line": 87 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 92, - "name": "PyreError", - "stop_column": 15, - "stop_line": 92 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 101, - "name": "PyreError", - "stop_column": 12, - "stop_line": 101 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 110, - "name": "PyreError", - "stop_column": 12, - "stop_line": 110 - }, { "code": -2, "column": 5, @@ -1828,16 +1718,6 @@ "stop_column": 19, "stop_line": 119 }, - { - "code": -2, - "column": 16, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 123, - "name": "PyreError", - "stop_column": 21, - "stop_line": 123 - }, { "code": -2, "column": 5, @@ -1858,16 +1738,6 @@ "stop_column": 29, "stop_line": 127 }, - { - "code": -2, - "column": 16, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 131, - "name": "PyreError", - "stop_column": 21, - "stop_line": 131 - }, { "code": -2, "column": 5, @@ -1888,26 +1758,6 @@ "stop_column": 29, "stop_line": 135 }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 140, - "name": "PyreError", - "stop_column": 13, - "stop_line": 140 - }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 145, - "name": "PyreError", - "stop_column": 13, - "stop_line": 145 - }, { "code": -2, "column": 5, @@ -1978,16 +1828,6 @@ "stop_column": 30, "stop_line": 160 }, - { - "code": -2, - "column": 9, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 161, - "name": "PyreError", - "stop_column": 35, - "stop_line": 161 - }, { "code": -2, "column": 5, @@ -2038,16 +1878,6 @@ "stop_column": 30, "stop_line": 179 }, - { - "code": -2, - "column": 5, - "concise_description": "TODO: ExprYield - Answers::expr_infer", - "description": "TODO: ExprYield - Answers::expr_infer", - "line": 187, - "name": "PyreError", - "stop_column": 10, - "stop_line": 187 - }, { "code": -2, "column": 1, diff --git a/pyre2/conformance/third_party/conformance.result b/pyre2/conformance/third_party/conformance.result index fd871c275eb..5a7ff79e7a4 100644 --- a/pyre2/conformance/third_party/conformance.result +++ b/pyre2/conformance/third_party/conformance.result @@ -142,32 +142,25 @@ ], "annotations_generators.py": [ "Line 51: Expected 1 errors", + "Line 57: Expected 1 errors", + "Line 66: Expected 1 errors", + "Line 75: Expected 1 errors", "Line 86: Expected 1 errors", + "Line 87: Expected 1 errors", "Line 91: Expected 1 errors", - "Line 46: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", + "Line 92: Expected 1 errors", "Line 56: Unexpected errors ['EXPECTED None <: C']", - "Line 70: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", - "Line 79: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", - "Line 83: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", - "Line 101: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", - "Line 110: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", "Line 114: Unexpected errors ['EXPECTED None <: Iterator[A]', 'TODO: YieldFrom(ExprYieldFrom - Answers::expr_infer']", - "Line 123: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", "Line 127: Unexpected errors ['EXPECTED None <: Generator[None, int, None]', 'TODO: YieldFrom(ExprYieldFrom - Answers::expr_infer']", - "Line 131: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", - "Line 140: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", - "Line 145: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", "Line 149: Unexpected errors ['TODO: YieldFrom(ExprYieldFrom - Answers::expr_infer']", "Line 150: Unexpected errors ['EXPECTED None <: Generator[int, None, None]', 'TODO: YieldFrom(ExprYieldFrom - Answers::expr_infer']", "Line 154: Unexpected errors [\"Missing argument 'result'\", 'EXPECTED Literal[1] <: float']", "Line 160: Unexpected errors [\"Missing argument 'result'\", 'EXPECTED Literal[1] <: float']", - "Line 161: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", "Line 167: Unexpected errors ['assert_type(Coroutine[Any, Any, AsyncGenerator[str, None]], AsyncGenerator[str, None]) failed']", "Line 168: Unexpected errors ['EXPECTED Coroutine[Unknown, Unknown, AsyncGenerator[str, None]] <: AsyncGenerator[str, None]']", "Line 174: Unexpected errors ['assert_type(Coroutine[Any, Any, AsyncGenerator[str, None]], AsyncGenerator[str, None]) failed']", "Line 175: Unexpected errors ['EXPECTED Coroutine[Unknown, Unknown, AsyncGenerator[str, None]] <: AsyncIterator[str]']", "Line 179: Unexpected errors ['EXPECTED None <: AsyncIterator[int]']", - "Line 187: Unexpected errors ['TODO: ExprYield - Answers::expr_infer']", "Line 190: Unexpected errors ['assert_type(Callable[[], Coroutine[Any, Any, AsyncIterator[int]]], Callable[[], AsyncIterator[int]]) failed']" ], "annotations_methods.py": [ diff --git a/pyre2/conformance/third_party/results.json b/pyre2/conformance/third_party/results.json index c487107d5f2..4ff4702e693 100644 --- a/pyre2/conformance/third_party/results.json +++ b/pyre2/conformance/third_party/results.json @@ -3,7 +3,7 @@ "pass": 9, "fail": 124, "pass_rate": 0.07, - "differences": 1348, + "differences": 1341, "passing": [ "annotations_coroutines.py", "directives_no_type_check.py", @@ -24,7 +24,7 @@ "aliases_typealiastype.py": 25, "aliases_variance.py": 3, "annotations_forward_refs.py": 10, - "annotations_generators.py": 28, + "annotations_generators.py": 21, "annotations_methods.py": 9, "annotations_typeexpr.py": 6, "callables_annotation.py": 26, diff --git a/pyre2/pyre2/bin/alt/answers.rs b/pyre2/pyre2/bin/alt/answers.rs index 247108f28d0..ba5ce7ea8b9 100644 --- a/pyre2/pyre2/bin/alt/answers.rs +++ b/pyre2/pyre2/bin/alt/answers.rs @@ -1076,6 +1076,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .generator(yield_type, Type::any_implicit(), return_type) .to_type() } + // TODO: Zeina, here we must construct a ReturnAnnotation key and look it up. The lookup will panic if key is not found. Figure out how to handle the failure. + Binding::SendTypeOfYield(_) => Type::any_explicit(), Binding::ReturnExpr(ann, e, has_yields) => { let ann = ann.map(|k| self.get_idx(k)); let hint = ann.as_ref().and_then(|x| x.ty.as_ref()); diff --git a/pyre2/pyre2/bin/alt/expr.rs b/pyre2/pyre2/bin/alt/expr.rs index 593ba1bc82f..e9a4e5fd9ed 100644 --- a/pyre2/pyre2/bin/alt/expr.rs +++ b/pyre2/pyre2/bin/alt/expr.rs @@ -1086,7 +1086,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { None => self.error(x.range, "Expression is not awaitable".to_owned()), } } - Expr::Yield(x) => self.error_todo("Answers::expr_infer", x), + Expr::Yield(x) => self.get(&Key::SendTypeOfYield(x.range)).arc_clone(), + Expr::YieldFrom(_) => self.error_todo("Answers::expr_infer", x), Expr::Compare(x) => { let _ty = self.expr_infer(&x.left); diff --git a/pyre2/pyre2/bin/binding/binding.rs b/pyre2/pyre2/bin/binding/binding.rs index c6c8dd08142..a4ebd9757de 100644 --- a/pyre2/pyre2/bin/binding/binding.rs +++ b/pyre2/pyre2/bin/binding/binding.rs @@ -101,6 +101,8 @@ pub enum Key { DecoratorApplication(TextRange), /// I am the self type for a particular class. SelfType(ShortIdentifier), + /// The send type of a yield expression. + SendTypeOfYield(TextRange), /// The type at a specific return point. ReturnExpression(ShortIdentifier, TextRange), /// The type yielded inside of a specific yield expression inside a function. @@ -139,6 +141,7 @@ impl Ranged for Key { Self::Definition(x) => x.range(), Self::DecoratorApplication(r) => r.range(), Self::SelfType(x) => x.range(), + Self::SendTypeOfYield(x) => x.range(), Self::ReturnExpression(_, r) => *r, Self::YieldTypeOfYield(_, r) => *r, Self::YieldTypeOfGenerator(x) => x.range(), @@ -160,6 +163,9 @@ impl DisplayWith for Key { Self::Definition(x) => write!(f, "{} {:?}", ctx.display(x), x.range()), Self::DecoratorApplication(r) => write!(f, "decorator {:?}", r), Self::SelfType(x) => write!(f, "self {} {:?}", ctx.display(x), x.range()), + Self::SendTypeOfYield(x) => { + write!(f, "send type of yield {} {:?}", ctx.display(x), x.range()) + } Self::Usage(x) => write!(f, "use {} {:?}", ctx.display(x), x.range()), Self::Anon(r) => write!(f, "anon {r:?}"), Self::Expect(r) => write!(f, "expect {r:?}"), @@ -336,6 +342,8 @@ pub enum Binding { /// An expression returned from a function. /// The `bool` is whether the function has `yield` within it. ReturnExpr(Option>, Expr, bool), + /// An expression returned from a function. + SendTypeOfYield(ShortIdentifier), /// A decorator application: the Key is the entity being decorated. DecoratorApplication(Box, Idx), /// A grouping of both the yield expression types and the return type. @@ -449,6 +457,9 @@ impl DisplayWith for Binding { iterable.display_with(ctx) ) } + self::Binding::SendTypeOfYield(x) => { + write!(f, "send type of yield {} {:?}", m.display(x), x.range()) + } Self::IterableValue(None, x) => write!(f, "iter {}", m.display(x)), Self::IterableValue(Some(k), x) => { write!(f, "iter {}: {}", ctx.display(*k), m.display(x)) diff --git a/pyre2/pyre2/bin/binding/bindings.rs b/pyre2/pyre2/bin/binding/bindings.rs index b35548ed85d..ddd3f723db5 100644 --- a/pyre2/pyre2/bin/binding/bindings.rs +++ b/pyre2/pyre2/bin/binding/bindings.rs @@ -1233,9 +1233,15 @@ impl<'a> BindingsBuilder<'a> { let key = self.table.insert( Key::YieldTypeOfYield(ShortIdentifier::new(&func_name), x.range()), // collect the value of the yield expression. - Binding::Expr(None, yield_expr(x)), + Binding::Expr(None, yield_expr(x.clone())), ); yield_expr_keys.insert(key); + + self.table.insert( + Key::SendTypeOfYield(x.range()), + // collect the value of the yield expression. + Binding::SendTypeOfYield(ShortIdentifier::new(&func_name)), + ); } let yield_type = Binding::phi(yield_expr_keys); self.table.insert( diff --git a/pyre2/pyre2/bin/test/yield.rs b/pyre2/pyre2/bin/test/yield.rs index 7cdaaa88dea..77c27f5a687 100644 --- a/pyre2/pyre2/bin/test/yield.rs +++ b/pyre2/pyre2/bin/test/yield.rs @@ -18,13 +18,13 @@ TODO zeina: 1- We need a generator type; 2- next keyword currently unsupported from typing import assert_type, Generator, Literal, Any, reveal_type def yielding(): - yield 1 # E: TODO: ExprYield - Answers::expr_infer + yield 1 f = yielding() next_f = next(f) # E: Could not find name `next` reveal_type(next_f) # E: revealed type: Error -reveal_type(f) # E: None +reveal_type(f) # E: revealed type: Generator[Literal[1], Unknown, None] "#, ); @@ -40,11 +40,11 @@ It should be Generator[Literal[1, 2], Any, Literal['done']] or Generator[int, An from typing import reveal_type def gen_with_return(): - yield 1 # E: TODO: ExprYield - Answers::expr_infer - yield 2 # E: TODO: ExprYield - Answers::expr_infer + yield 1 + yield 2 return "done" -reveal_type(gen_with_return()) # E: Literal['done'] +reveal_type(gen_with_return()) # E: Generator[Literal[1, 2], Unknown, Literal['done']] "#, ); @@ -59,7 +59,7 @@ TODO zeina: we should correctly determine the send() type based on the signature from typing import Generator, reveal_type def accumulate(x: int) -> Generator[int, int, None]: - yield x # E: TODO: ExprYield - Answers::expr_infer + yield x gen = accumulate(10) reveal_type(gen) # E: revealed type: Generator[int, int, None] @@ -68,6 +68,28 @@ gen.send(5) "#, ); +testcase_with_bug!( + r#" +TODO zeina: we should correctly determine the send() type based on the signature of the generator. Additionally, we should correctly handle the return type of the generator. + "#, + test_generator_send_inference, + r#" + +from typing import Generator, reveal_type + +class Yield: pass +class Send: pass +class Return: pass + +def my_generator(n: int) -> Generator[Yield, Send, Return]: + s = yield Yield() + + reveal_type(s) # E: revealed type: Any + return Return() + +"#, +); + testcase_with_bug!( "TODOs", test_yield_with_iterator, @@ -75,9 +97,9 @@ testcase_with_bug!( from typing import Iterator, reveal_type def gen_numbers() -> Iterator[int]: - yield 1 # E: TODO: ExprYield - Answers::expr_infer - yield 2 # E: TODO: ExprYield - Answers::expr_infer - yield 3 # E: TODO: ExprYield - Answers::expr_infer + yield 1 + yield 2 + yield 3 reveal_type(gen_numbers()) # E: revealed type: Iterator[int] @@ -95,12 +117,12 @@ and Type of "another_generator()" should be "Generator[Literal[2], Any, None]" from typing import Generator, reveal_type def nested_generator(): - yield 1 # E: TODO: ExprYield - Answers::expr_infer + yield 1 yield from another_generator() # E: TODO: YieldFrom(ExprYieldFrom - Answers::expr_infer - yield 3 # E: TODO: ExprYield - Answers::expr_infer + yield 3 def another_generator(): - yield 2 # E: TODO: ExprYield - Answers::expr_infer + yield 2 reveal_type(nested_generator()) # E: revealed type: Generator[Literal[1, 3], Unknown, None] reveal_type(another_generator()) # E: revealed type: Generator[Literal[2], Unknown, None] @@ -116,7 +138,7 @@ from typing import Generator, reveal_type def f(value) -> Generator[int, None, None]: while True: - yield value # E: TODO: ExprYield - Answers::expr_infer + yield value reveal_type(f(3)) # E: revealed type: Generator[int, None, None] @@ -133,7 +155,7 @@ T = TypeVar('T') def f(value: T) -> Generator[T, None, None]: while True: - yield value # E: TODO: ExprYield - Answers::expr_infer + yield value reveal_type(f(3)) # E: revealed type: Generator[int, None, None] @@ -147,9 +169,9 @@ testcase_with_bug!( from typing import AsyncGenerator, reveal_type # E: Could not import `AsyncGenerator` from `typing` async def async_count_up_to() -> AsyncGenerator[int, None]: - yield 2 # E: TODO: ExprYield - Answers::expr_infer + yield 2 -reveal_type(async_count_up_to()) # E: Coroutine[Unknown, Unknown, Error] +reveal_type(async_count_up_to()) # E: revealed type: Coroutine[Unknown, Unknown, Error] "#, ); @@ -161,25 +183,25 @@ testcase_with_bug!( from typing import reveal_type async def async_count_up_to(): - yield 2 # E: TODO: ExprYield - Answers::expr_infer + yield 2 -reveal_type(async_count_up_to()) # E: Coroutine[Unknown, Unknown, Generator[Literal[2], Unknown, None]] +reveal_type(async_count_up_to()) # E: revealed type: Coroutine[Unknown, Unknown, Generator[Literal[2], Unknown, None]] "#, ); testcase_with_bug!( - "TODO zeina: infer send type.", + "TODO zeina: We are incorrectly inferring generators that return generators.", test_inferring_generators_that_return_generators, r#" -from typing import Generator, assert_type, reveal_type +from typing import Generator, assert_type def generator() -> Generator[int, None, None]: ... def generator2(x: int): - yield x # E: TODO: ExprYield - Answers::expr_infer + yield x return generator() -assert_type(generator2(1), Generator[int, None, Generator[int, None, None]]) # E: Generator[int, Any, Generator[int, None, None]] +assert_type(generator2(1), Generator[int, Any, Generator[int, None, None]]) "#, );