From 3e01a76584168494fa866a4b02f791e4bba97f08 Mon Sep 17 00:00:00 2001 From: Chayim Refael Friedman Date: Tue, 3 Dec 2024 20:46:45 +0200 Subject: [PATCH] Support `AsyncFnX` traits Only in calls, because to support them in bounds we need support from Chalk. However we don't yet report error from bounds anyway, so this is less severe. The returned future is shown in its name within inlay hints instead of as a nicer `impl Future`, but that can wait for another PR. --- crates/hir-def/src/lang_item.rs | 3 + crates/hir-ty/src/infer/expr.rs | 6 +- crates/hir-ty/src/infer/unify.rs | 114 +++++++++-------- crates/hir-ty/src/mir/lower.rs | 12 +- crates/hir-ty/src/tests/traits.rs | 50 ++++++++ crates/hir-ty/src/traits.rs | 26 +++- .../src/handlers/expected_function.rs | 21 ++++ crates/intern/src/symbol/symbols.rs | 8 ++ crates/test-utils/src/minicore.rs | 119 +++++++++++++++++- 9 files changed, 291 insertions(+), 68 deletions(-) diff --git a/crates/hir-def/src/lang_item.rs b/crates/hir-def/src/lang_item.rs index 166c965d14c6..0629d87e5444 100644 --- a/crates/hir-def/src/lang_item.rs +++ b/crates/hir-def/src/lang_item.rs @@ -376,6 +376,9 @@ language_item_table! { Fn, sym::fn_, fn_trait, Target::Trait, GenericRequirement::Exact(1); FnMut, sym::fn_mut, fn_mut_trait, Target::Trait, GenericRequirement::Exact(1); FnOnce, sym::fn_once, fn_once_trait, Target::Trait, GenericRequirement::Exact(1); + AsyncFn, sym::async_fn, async_fn_trait, Target::Trait, GenericRequirement::Exact(1); + AsyncFnMut, sym::async_fn_mut, async_fn_mut_trait, Target::Trait, GenericRequirement::Exact(1); + AsyncFnOnce, sym::async_fn_once, async_fn_once_trait, Target::Trait, GenericRequirement::Exact(1); FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None; diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 32b4ea2f28ba..c21ff19c45dc 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -1287,8 +1287,8 @@ impl InferenceContext<'_> { tgt_expr: ExprId, ) { match fn_x { - FnTrait::FnOnce => (), - FnTrait::FnMut => { + FnTrait::FnOnce | FnTrait::AsyncFnOnce => (), + FnTrait::FnMut | FnTrait::AsyncFnMut => { if let TyKind::Ref(Mutability::Mut, lt, inner) = derefed_callee.kind(Interner) { if adjustments .last() @@ -1312,7 +1312,7 @@ impl InferenceContext<'_> { )); } } - FnTrait::Fn => { + FnTrait::Fn | FnTrait::AsyncFn => { if !matches!(derefed_callee.kind(Interner), TyKind::Ref(Mutability::Not, _, _)) { adjustments.push(Adjustment::borrow( Mutability::Not, diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index e4881d752013..a83e58bc65be 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -794,69 +794,75 @@ impl<'a> InferenceTable<'a> { ty: &Ty, num_args: usize, ) -> Option<(FnTrait, Vec, Ty)> { - let krate = self.trait_env.krate; - let fn_once_trait = FnTrait::FnOnce.get_id(self.db, krate)?; - let trait_data = self.db.trait_data(fn_once_trait); - let output_assoc_type = - trait_data.associated_type_by_name(&Name::new_symbol_root(sym::Output.clone()))?; - - let mut arg_tys = Vec::with_capacity(num_args); - let arg_ty = TyBuilder::tuple(num_args) - .fill(|it| { - let arg = match it { - ParamKind::Type => self.new_type_var(), - ParamKind::Lifetime => unreachable!("Tuple with lifetime parameter"), - ParamKind::Const(_) => unreachable!("Tuple with const parameter"), - }; - arg_tys.push(arg.clone()); - arg.cast(Interner) - }) - .build(); - - let b = TyBuilder::trait_ref(self.db, fn_once_trait); - if b.remaining() != 2 { - return None; - } - let mut trait_ref = b.push(ty.clone()).push(arg_ty).build(); + for (fn_trait_name, output_assoc_name, subtraits) in [ + (FnTrait::FnOnce, sym::Output.clone(), &[FnTrait::Fn, FnTrait::FnMut][..]), + (FnTrait::AsyncFnMut, sym::CallRefFuture.clone(), &[FnTrait::AsyncFn]), + (FnTrait::AsyncFnOnce, sym::CallOnceFuture.clone(), &[]), + ] { + let krate = self.trait_env.krate; + let fn_trait = fn_trait_name.get_id(self.db, krate)?; + let trait_data = self.db.trait_data(fn_trait); + let output_assoc_type = + trait_data.associated_type_by_name(&Name::new_symbol_root(output_assoc_name))?; + + let mut arg_tys = Vec::with_capacity(num_args); + let arg_ty = TyBuilder::tuple(num_args) + .fill(|it| { + let arg = match it { + ParamKind::Type => self.new_type_var(), + ParamKind::Lifetime => unreachable!("Tuple with lifetime parameter"), + ParamKind::Const(_) => unreachable!("Tuple with const parameter"), + }; + arg_tys.push(arg.clone()); + arg.cast(Interner) + }) + .build(); + + let b = TyBuilder::trait_ref(self.db, fn_trait); + if b.remaining() != 2 { + return None; + } + let mut trait_ref = b.push(ty.clone()).push(arg_ty).build(); - let projection = { - TyBuilder::assoc_type_projection( + let projection = TyBuilder::assoc_type_projection( self.db, output_assoc_type, Some(trait_ref.substitution.clone()), ) - .build() - }; + .fill_with_unknown() + .build(); - let trait_env = self.trait_env.env.clone(); - let obligation = InEnvironment { - goal: trait_ref.clone().cast(Interner), - environment: trait_env.clone(), - }; - let canonical = self.canonicalize(obligation.clone()); - if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some() { - self.register_obligation(obligation.goal); - let return_ty = self.normalize_projection_ty(projection); - for fn_x in [FnTrait::Fn, FnTrait::FnMut, FnTrait::FnOnce] { - let fn_x_trait = fn_x.get_id(self.db, krate)?; - trait_ref.trait_id = to_chalk_trait_id(fn_x_trait); - let obligation: chalk_ir::InEnvironment> = InEnvironment { - goal: trait_ref.clone().cast(Interner), - environment: trait_env.clone(), - }; - let canonical = self.canonicalize(obligation.clone()); - if self - .db - .trait_solve(krate, self.trait_env.block, canonical.cast(Interner)) - .is_some() - { - return Some((fn_x, arg_tys, return_ty)); + let trait_env = self.trait_env.env.clone(); + let obligation = InEnvironment { + goal: trait_ref.clone().cast(Interner), + environment: trait_env.clone(), + }; + let canonical = self.canonicalize(obligation.clone()); + if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some() + { + self.register_obligation(obligation.goal); + let return_ty = self.normalize_projection_ty(projection); + for &fn_x in subtraits { + let fn_x_trait = fn_x.get_id(self.db, krate)?; + trait_ref.trait_id = to_chalk_trait_id(fn_x_trait); + let obligation: chalk_ir::InEnvironment> = + InEnvironment { + goal: trait_ref.clone().cast(Interner), + environment: trait_env.clone(), + }; + let canonical = self.canonicalize(obligation.clone()); + if self + .db + .trait_solve(krate, self.trait_env.block, canonical.cast(Interner)) + .is_some() + { + return Some((fn_x, arg_tys, return_ty)); + } } + return Some((fn_trait_name, arg_tys, return_ty)); } - unreachable!("It should at least implement FnOnce at this point"); - } else { - None } + None } pub(super) fn insert_type_vars(&mut self, ty: T) -> T diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index c4e064005106..1d1044df6e96 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -2023,11 +2023,11 @@ pub fn mir_body_for_closure_query( ctx.result.locals.alloc(Local { ty: infer[*root].clone() }); let closure_local = ctx.result.locals.alloc(Local { ty: match kind { - FnTrait::FnOnce => infer[expr].clone(), - FnTrait::FnMut => { + FnTrait::FnOnce | FnTrait::AsyncFnOnce => infer[expr].clone(), + FnTrait::FnMut | FnTrait::AsyncFnMut => { TyKind::Ref(Mutability::Mut, error_lifetime(), infer[expr].clone()).intern(Interner) } - FnTrait::Fn => { + FnTrait::Fn | FnTrait::AsyncFn => { TyKind::Ref(Mutability::Not, error_lifetime(), infer[expr].clone()).intern(Interner) } }, @@ -2055,8 +2055,10 @@ pub fn mir_body_for_closure_query( let mut err = None; let closure_local = ctx.result.locals.iter().nth(1).unwrap().0; let closure_projection = match kind { - FnTrait::FnOnce => vec![], - FnTrait::FnMut | FnTrait::Fn => vec![ProjectionElem::Deref], + FnTrait::FnOnce | FnTrait::AsyncFnOnce => vec![], + FnTrait::FnMut | FnTrait::Fn | FnTrait::AsyncFnMut | FnTrait::AsyncFn => { + vec![ProjectionElem::Deref] + } }; ctx.result.walk_places(|p, store| { if let Some(it) = upvar_map.get(&p.local) { diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index 9b982a124e7b..2bd16fdd6086 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -4790,3 +4790,53 @@ fn allowed3(baz: impl Baz>) {} "#]], ) } + +#[test] +fn async_fn_traits() { + check_infer( + r#" +//- minicore: async_fn +async fn foo i32>(a: T) { + let fut1 = a(0); + fut1.await; +} +async fn bar i32>(mut b: T) { + let fut2 = b(0); + fut2.await; +} +async fn baz i32>(c: T) { + let fut3 = c(0); + fut3.await; +} + "#, + expect![[r#" + 37..38 'a': T + 43..83 '{ ...ait; }': () + 43..83 '{ ...ait; }': impl Future + 53..57 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 60..61 'a': T + 60..64 'a(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 62..63 '0': u32 + 70..74 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 70..80 'fut1.await': i32 + 124..129 'mut b': T + 134..174 '{ ...ait; }': () + 134..174 '{ ...ait; }': impl Future + 144..148 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 151..152 'b': T + 151..155 'b(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 153..154 '0': u32 + 161..165 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)> + 161..171 'fut2.await': i32 + 216..217 'c': T + 222..262 '{ ...ait; }': () + 222..262 '{ ...ait; }': impl Future + 232..236 'fut3': AsyncFnOnce::CallOnceFuture + 239..240 'c': T + 239..243 'c(0)': AsyncFnOnce::CallOnceFuture + 241..242 '0': u32 + 249..253 'fut3': AsyncFnOnce::CallOnceFuture + 249..259 'fut3.await': i32 + "#]], + ); +} diff --git a/crates/hir-ty/src/traits.rs b/crates/hir-ty/src/traits.rs index 51ccd4ef293f..8cb7dbf60f37 100644 --- a/crates/hir-ty/src/traits.rs +++ b/crates/hir-ty/src/traits.rs @@ -220,6 +220,10 @@ pub enum FnTrait { FnOnce, FnMut, Fn, + + AsyncFnOnce, + AsyncFnMut, + AsyncFn, } impl fmt::Display for FnTrait { @@ -228,6 +232,9 @@ impl fmt::Display for FnTrait { FnTrait::FnOnce => write!(f, "FnOnce"), FnTrait::FnMut => write!(f, "FnMut"), FnTrait::Fn => write!(f, "Fn"), + FnTrait::AsyncFnOnce => write!(f, "AsyncFnOnce"), + FnTrait::AsyncFnMut => write!(f, "AsyncFnMut"), + FnTrait::AsyncFn => write!(f, "AsyncFn"), } } } @@ -238,6 +245,9 @@ impl FnTrait { FnTrait::FnOnce => "call_once", FnTrait::FnMut => "call_mut", FnTrait::Fn => "call", + FnTrait::AsyncFnOnce => "async_call_once", + FnTrait::AsyncFnMut => "async_call_mut", + FnTrait::AsyncFn => "async_call", } } @@ -246,6 +256,9 @@ impl FnTrait { FnTrait::FnOnce => LangItem::FnOnce, FnTrait::FnMut => LangItem::FnMut, FnTrait::Fn => LangItem::Fn, + FnTrait::AsyncFnOnce => LangItem::AsyncFnOnce, + FnTrait::AsyncFnMut => LangItem::AsyncFnMut, + FnTrait::AsyncFn => LangItem::AsyncFn, } } @@ -254,15 +267,19 @@ impl FnTrait { LangItem::FnOnce => Some(FnTrait::FnOnce), LangItem::FnMut => Some(FnTrait::FnMut), LangItem::Fn => Some(FnTrait::Fn), + LangItem::AsyncFnOnce => Some(FnTrait::AsyncFnOnce), + LangItem::AsyncFnMut => Some(FnTrait::AsyncFnMut), + LangItem::AsyncFn => Some(FnTrait::AsyncFn), _ => None, } } pub const fn to_chalk_ir(self) -> rust_ir::ClosureKind { + // Chalk doesn't support async fn traits. match self { - FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce, - FnTrait::FnMut => rust_ir::ClosureKind::FnMut, - FnTrait::Fn => rust_ir::ClosureKind::Fn, + FnTrait::AsyncFnOnce | FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce, + FnTrait::AsyncFnMut | FnTrait::FnMut => rust_ir::ClosureKind::FnMut, + FnTrait::AsyncFn | FnTrait::Fn => rust_ir::ClosureKind::Fn, } } @@ -271,6 +288,9 @@ impl FnTrait { FnTrait::FnOnce => Name::new_symbol_root(sym::call_once.clone()), FnTrait::FnMut => Name::new_symbol_root(sym::call_mut.clone()), FnTrait::Fn => Name::new_symbol_root(sym::call.clone()), + FnTrait::AsyncFnOnce => Name::new_symbol_root(sym::async_call_once.clone()), + FnTrait::AsyncFnMut => Name::new_symbol_root(sym::async_call_mut.clone()), + FnTrait::AsyncFn => Name::new_symbol_root(sym::async_call.clone()), } } diff --git a/crates/ide-diagnostics/src/handlers/expected_function.rs b/crates/ide-diagnostics/src/handlers/expected_function.rs index 02299197b125..e3a1e12e0296 100644 --- a/crates/ide-diagnostics/src/handlers/expected_function.rs +++ b/crates/ide-diagnostics/src/handlers/expected_function.rs @@ -37,4 +37,25 @@ fn foo() { "#, ); } + + #[test] + fn no_error_for_async_fn_traits() { + check_diagnostics( + r#" +//- minicore: async_fn +async fn f(it: impl AsyncFn(u32) -> i32) { + let fut = it(0); + let _: i32 = fut.await; +} +async fn g(mut it: impl AsyncFnMut(u32) -> i32) { + let fut = it(0); + let _: i32 = fut.await; +} +async fn h(it: impl AsyncFnOnce(u32) -> i32) { + let fut = it(0); + let _: i32 = fut.await; +} + "#, + ); + } } diff --git a/crates/intern/src/symbol/symbols.rs b/crates/intern/src/symbol/symbols.rs index 865518fe941e..ee96eff33097 100644 --- a/crates/intern/src/symbol/symbols.rs +++ b/crates/intern/src/symbol/symbols.rs @@ -150,6 +150,9 @@ define_symbols! { C, call_mut, call_once, + async_call_once, + async_call_mut, + async_call, call, cdecl, Center, @@ -221,6 +224,9 @@ define_symbols! { fn_mut, fn_once_output, fn_once, + async_fn_once, + async_fn_mut, + async_fn, fn_ptr_addr, fn_ptr_trait, format_alignment, @@ -334,6 +340,8 @@ define_symbols! { Option, Ord, Output, + CallRefFuture, + CallOnceFuture, owned_box, packed, panic_2015, diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs index 07767d5ae9f6..6ee577d034db 100644 --- a/crates/test-utils/src/minicore.rs +++ b/crates/test-utils/src/minicore.rs @@ -12,6 +12,7 @@ //! asm: //! assert: //! as_ref: sized +//! async_fn: fn, tuple, future //! bool_impl: option, fn //! builtin_impls: //! cell: copy, drop @@ -29,7 +30,7 @@ //! eq: sized //! error: fmt //! fmt: option, result, transmute, coerce_unsized, copy, clone, derive -//! fn: +//! fn: tuple //! from: sized //! future: pin //! coroutine: pin @@ -60,6 +61,7 @@ //! sync: sized //! transmute: //! try: infallible +//! tuple: //! unpin: sized //! unsize: sized //! todo: panic @@ -138,10 +140,10 @@ pub mod marker { } // endregion:copy - // region:fn + // region:tuple #[lang = "tuple_trait"] pub trait Tuple {} - // endregion:fn + // endregion:tuple // region:phantom_data #[lang = "phantom_data"] @@ -682,6 +684,116 @@ pub mod ops { } pub use self::function::{Fn, FnMut, FnOnce}; // endregion:fn + + // region:async_fn + mod async_function { + use crate::{future::Future, marker::Tuple}; + + #[lang = "async_fn"] + #[fundamental] + pub trait AsyncFn: AsyncFnMut { + extern "rust-call" fn async_call(&self, args: Args) -> Self::CallRefFuture<'_>; + } + + #[lang = "async_fn_mut"] + #[fundamental] + pub trait AsyncFnMut: AsyncFnOnce { + #[lang = "call_ref_future"] + type CallRefFuture<'a>: Future + where + Self: 'a; + extern "rust-call" fn async_call_mut(&mut self, args: Args) -> Self::CallRefFuture<'_>; + } + + #[lang = "async_fn_once"] + #[fundamental] + pub trait AsyncFnOnce { + #[lang = "async_fn_once_output"] + type Output; + #[lang = "call_once_future"] + type CallOnceFuture: Future; + extern "rust-call" fn async_call_once(self, args: Args) -> Self::CallOnceFuture; + } + + mod impls { + use super::{AsyncFn, AsyncFnMut, AsyncFnOnce}; + use crate::marker::Tuple; + + impl AsyncFn for &F + where + F: AsyncFn, + { + extern "rust-call" fn async_call(&self, args: A) -> Self::CallRefFuture<'_> { + F::async_call(*self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl AsyncFnMut for &F + where + F: AsyncFn, + { + type CallRefFuture<'a> + = F::CallRefFuture<'a> + where + Self: 'a; + + extern "rust-call" fn async_call_mut( + &mut self, + args: A, + ) -> Self::CallRefFuture<'_> { + F::async_call(*self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl<'a, A: Tuple, F: ?Sized> AsyncFnOnce for &'a F + where + F: AsyncFn, + { + type Output = F::Output; + type CallOnceFuture = F::CallRefFuture<'a>; + + extern "rust-call" fn async_call_once(self, args: A) -> Self::CallOnceFuture { + F::async_call(self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl AsyncFnMut for &mut F + where + F: AsyncFnMut, + { + type CallRefFuture<'a> + = F::CallRefFuture<'a> + where + Self: 'a; + + extern "rust-call" fn async_call_mut( + &mut self, + args: A, + ) -> Self::CallRefFuture<'_> { + F::async_call_mut(*self, args) + } + } + + #[unstable(feature = "async_fn_traits", issue = "none")] + impl<'a, A: Tuple, F: ?Sized> AsyncFnOnce for &'a mut F + where + F: AsyncFnMut, + { + type Output = F::Output; + type CallOnceFuture = F::CallRefFuture<'a>; + + extern "rust-call" fn async_call_once(self, args: A) -> Self::CallOnceFuture { + F::async_call_mut(self, args) + } + } + } + } + pub use self::async_function::{AsyncFn, AsyncFnMut, AsyncFnOnce}; + // endregion:async_fn + // region:try mod try_ { use crate::convert::Infallible; @@ -1684,6 +1796,7 @@ pub mod prelude { marker::Sync, // :sync mem::drop, // :drop ops::Drop, // :drop + ops::{AsyncFn, AsyncFnMut, AsyncFnOnce}, // :async_fn ops::{Fn, FnMut, FnOnce}, // :fn option::Option::{self, None, Some}, // :option panic, // :panic