Skip to content

Commit b8c93f1

Browse files
Coroutine closures implement regular Fn traits, when possible
1 parent 08af64e commit b8c93f1

File tree

5 files changed

+142
-18
lines changed

5 files changed

+142
-18
lines changed

compiler/rustc_hir_typeck/src/closure.rs

+12-5
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,18 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
5656
// It's always helpful for inference if we know the kind of
5757
// closure sooner rather than later, so first examine the expected
5858
// type, and see if can glean a closure kind from there.
59-
let (expected_sig, expected_kind) = match expected.to_option(self) {
60-
Some(ty) => {
61-
self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty))
62-
}
63-
None => (None, None),
59+
let (expected_sig, expected_kind) = match closure.kind {
60+
hir::ClosureKind::Closure => match expected.to_option(self) {
61+
Some(ty) => {
62+
self.deduce_closure_signature(self.try_structurally_resolve_type(expr_span, ty))
63+
}
64+
None => (None, None),
65+
},
66+
// We don't want to deduce a signature from `Fn` bounds for coroutines
67+
// or coroutine-closures, because the former does not implement `Fn`
68+
// ever, and the latter's signature doesn't correspond to the coroutine
69+
// type that it returns.
70+
hir::ClosureKind::Coroutine(_) | hir::ClosureKind::CoroutineClosure(_) => (None, None),
6471
};
6572

6673
let ClosureSignatures { bound_sig, mut liberated_sig } =

compiler/rustc_trait_selection/src/traits/project.rs

+70-4
Original file line numberDiff line numberDiff line change
@@ -2074,7 +2074,9 @@ fn confirm_select_candidate<'cx, 'tcx>(
20742074
} else if lang_items.async_iterator_trait() == Some(trait_def_id) {
20752075
confirm_async_iterator_candidate(selcx, obligation, data)
20762076
} else if selcx.tcx().fn_trait_kind_from_def_id(trait_def_id).is_some() {
2077-
if obligation.predicate.self_ty().is_closure() {
2077+
if obligation.predicate.self_ty().is_closure()
2078+
|| obligation.predicate.self_ty().is_coroutine_closure()
2079+
{
20782080
confirm_closure_candidate(selcx, obligation, data)
20792081
} else {
20802082
confirm_fn_pointer_candidate(selcx, obligation, data)
@@ -2386,11 +2388,75 @@ fn confirm_closure_candidate<'cx, 'tcx>(
23862388
obligation: &ProjectionTyObligation<'tcx>,
23872389
nested: Vec<PredicateObligation<'tcx>>,
23882390
) -> Progress<'tcx> {
2391+
let tcx = selcx.tcx();
23892392
let self_ty = selcx.infcx.shallow_resolve(obligation.predicate.self_ty());
2390-
let ty::Closure(_, args) = self_ty.kind() else {
2391-
unreachable!("expected closure self type for closure candidate, found {self_ty}")
2393+
let closure_sig = match *self_ty.kind() {
2394+
ty::Closure(_, args) => args.as_closure().sig(),
2395+
2396+
// Construct a "normal" `FnOnce` signature for coroutine-closure. This is
2397+
// basically duplicated with the `AsyncFnOnce::CallOnce` confirmation, but
2398+
// I didn't see a good way to unify those.
2399+
ty::CoroutineClosure(def_id, args) => {
2400+
let args = args.as_coroutine_closure();
2401+
let kind_ty = args.kind_ty();
2402+
args.coroutine_closure_sig().map_bound(|sig| {
2403+
// If we know the kind and upvars, use that directly.
2404+
// Otherwise, defer to `AsyncFnKindHelper::Upvars` to delay
2405+
// the projection, like the `AsyncFn*` traits do.
2406+
let output_ty = if let Some(_) = kind_ty.to_opt_closure_kind() {
2407+
sig.to_coroutine_given_kind_and_upvars(
2408+
tcx,
2409+
args.parent_args(),
2410+
tcx.coroutine_for_closure(def_id),
2411+
ty::ClosureKind::FnOnce,
2412+
tcx.lifetimes.re_static,
2413+
args.tupled_upvars_ty(),
2414+
args.coroutine_captures_by_ref_ty(),
2415+
)
2416+
} else {
2417+
let async_fn_kind_trait_def_id =
2418+
tcx.require_lang_item(LangItem::AsyncFnKindHelper, None);
2419+
let upvars_projection_def_id = tcx
2420+
.associated_items(async_fn_kind_trait_def_id)
2421+
.filter_by_name_unhygienic(sym::Upvars)
2422+
.next()
2423+
.unwrap()
2424+
.def_id;
2425+
let tupled_upvars_ty = Ty::new_projection(
2426+
tcx,
2427+
upvars_projection_def_id,
2428+
[
2429+
ty::GenericArg::from(kind_ty),
2430+
Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce).into(),
2431+
tcx.lifetimes.re_static.into(),
2432+
sig.tupled_inputs_ty.into(),
2433+
args.tupled_upvars_ty().into(),
2434+
args.coroutine_captures_by_ref_ty().into(),
2435+
],
2436+
);
2437+
sig.to_coroutine(
2438+
tcx,
2439+
args.parent_args(),
2440+
Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
2441+
tcx.coroutine_for_closure(def_id),
2442+
tupled_upvars_ty,
2443+
)
2444+
};
2445+
tcx.mk_fn_sig(
2446+
[sig.tupled_inputs_ty],
2447+
output_ty,
2448+
sig.c_variadic,
2449+
sig.unsafety,
2450+
sig.abi,
2451+
)
2452+
})
2453+
}
2454+
2455+
_ => {
2456+
unreachable!("expected closure self type for closure candidate, found {self_ty}");
2457+
}
23922458
};
2393-
let closure_sig = args.as_closure().sig();
2459+
23942460
let Normalized { value: closure_sig, obligations } = normalize_with_depth(
23952461
selcx,
23962462
obligation.param_env,

compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs

+25
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,31 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
332332
}
333333
}
334334
}
335+
ty::CoroutineClosure(def_id, args) => {
336+
let is_const = self.tcx().is_const_fn_raw(def_id);
337+
match self.infcx.closure_kind(self_ty) {
338+
Some(closure_kind) => {
339+
let no_borrows = self
340+
.infcx
341+
.shallow_resolve(args.as_coroutine_closure().tupled_upvars_ty())
342+
.tuple_fields()
343+
.is_empty();
344+
if no_borrows && closure_kind.extends(kind) {
345+
candidates.vec.push(ClosureCandidate { is_const });
346+
} else if kind == ty::ClosureKind::FnOnce {
347+
candidates.vec.push(ClosureCandidate { is_const });
348+
}
349+
}
350+
None => {
351+
if kind == ty::ClosureKind::FnOnce {
352+
candidates.vec.push(ClosureCandidate { is_const });
353+
} else {
354+
// This stays ambiguous until kind+upvars are determined.
355+
candidates.ambiguous = true;
356+
}
357+
}
358+
}
359+
}
335360
ty::Infer(ty::TyVar(_)) => {
336361
debug!("assemble_unboxed_closure_candidates: ambiguous self-type");
337362
candidates.ambiguous = true;

compiler/rustc_trait_selection/src/traits/select/confirmation.rs

+17-9
Original file line numberDiff line numberDiff line change
@@ -865,17 +865,25 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
865865
// touch bound regions, they just capture the in-scope
866866
// type/region parameters.
867867
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
868-
let ty::Closure(closure_def_id, args) = *self_ty.kind() else {
869-
bug!("closure candidate for non-closure {:?}", obligation);
868+
let trait_ref = match *self_ty.kind() {
869+
ty::Closure(_, args) => {
870+
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
871+
}
872+
ty::CoroutineClosure(_, args) => {
873+
args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
874+
ty::TraitRef::new(
875+
self.tcx(),
876+
obligation.predicate.def_id(),
877+
[self_ty, sig.tupled_inputs_ty],
878+
)
879+
})
880+
}
881+
_ => {
882+
bug!("closure candidate for non-closure {:?}", obligation);
883+
}
870884
};
871885

872-
let trait_ref =
873-
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_);
874-
let nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
875-
876-
debug!(?closure_def_id, ?trait_ref, ?nested, "confirm closure candidate obligations");
877-
878-
Ok(nested)
886+
self.confirm_poly_trait_refs(obligation, trait_ref)
879887
}
880888

881889
#[instrument(skip(self), level = "debug")]

compiler/rustc_ty_utils/src/instance.rs

+18
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,24 @@ fn resolve_associated_item<'tcx>(
278278
def: ty::InstanceDef::FnPtrShim(trait_item_id, rcvr_args.type_at(0)),
279279
args: rcvr_args,
280280
}),
281+
ty::CoroutineClosure(coroutine_closure_def_id, args) => {
282+
// When a coroutine-closure implements the `Fn` traits, then it
283+
// always dispatches to the `FnOnce` implementation. This is to
284+
// ensure that the `closure_kind` of the resulting closure is in
285+
// sync with the built-in trait implementations (since all of the
286+
// implementations return `FnOnce::Output`).
287+
if ty::ClosureKind::FnOnce == args.as_coroutine_closure().kind() {
288+
Some(Instance::new(coroutine_closure_def_id, args))
289+
} else {
290+
Some(Instance {
291+
def: ty::InstanceDef::ConstructCoroutineInClosureShim {
292+
coroutine_closure_def_id,
293+
target_kind: ty::ClosureKind::FnOnce,
294+
},
295+
args,
296+
})
297+
}
298+
}
281299
_ => bug!(
282300
"no built-in definition for `{trait_ref}::{}` for non-fn type",
283301
tcx.item_name(trait_item_id)

0 commit comments

Comments
 (0)