Skip to content

Commit 6a30a7a

Browse files
Consider principal trait ref's auto-trait super-traits in dyn upcasting
1 parent 8c0b4f6 commit 6a30a7a

File tree

4 files changed

+93
-51
lines changed

4 files changed

+93
-51
lines changed

compiler/rustc_trait_selection/src/solve/trait_goals.rs

+15-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
//! Dealing with trait goals, i.e. `T: Trait<'a, U>`.
22
3+
use crate::traits::supertrait_def_ids;
4+
35
use super::assembly::{self, structural_traits, Candidate};
46
use super::{EvalCtxt, GoalSource, SolverMode};
7+
use rustc_data_structures::fx::FxIndexSet;
58
use rustc_hir::def_id::DefId;
69
use rustc_hir::{LangItem, Movability};
710
use rustc_infer::traits::query::NoSolution;
@@ -600,13 +603,6 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
600603
let tcx = self.tcx();
601604
let Goal { predicate: (a_ty, _b_ty), .. } = goal;
602605

603-
// All of a's auto traits need to be in b's auto traits.
604-
let auto_traits_compatible =
605-
b_data.auto_traits().all(|b| a_data.auto_traits().any(|a| a == b));
606-
if !auto_traits_compatible {
607-
return vec![];
608-
}
609-
610606
let mut responses = vec![];
611607
// If the principal def ids match (or are both none), then we're not doing
612608
// trait upcasting. We're just removing auto traits (or shortening the lifetime).
@@ -694,6 +690,17 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
694690
) -> QueryResult<'tcx> {
695691
let param_env = goal.param_env;
696692

693+
// We may upcast to auto traits that are either explicitly listed in
694+
// the object type's bounds, or implied by the principal trait ref's
695+
// supertraits.
696+
let a_auto_traits: FxIndexSet<DefId> = a_data
697+
.auto_traits()
698+
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
699+
supertrait_def_ids(self.tcx(), principal_def_id)
700+
.filter(|def_id| self.tcx().trait_is_auto(*def_id))
701+
}))
702+
.collect();
703+
697704
// More than one projection in a_ty's bounds may match the projection
698705
// in b_ty's bound. Use this to first determine *which* apply without
699706
// having any inference side-effects. We process obligations because
@@ -743,7 +750,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
743750
}
744751
// Check that b_ty's auto traits are present in a_ty's bounds.
745752
ty::ExistentialPredicate::AutoTrait(def_id) => {
746-
if !a_data.auto_traits().any(|source_def_id| source_def_id == def_id) {
753+
if !a_auto_traits.contains(&def_id) {
747754
return Err(NoSolution);
748755
}
749756
}

compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs

+52-42
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
99
use hir::def_id::DefId;
1010
use hir::LangItem;
11+
use rustc_data_structures::fx::FxIndexSet;
1112
use rustc_hir as hir;
1213
use rustc_infer::traits::ObligationCause;
1314
use rustc_infer::traits::{Obligation, PolyTraitObligation, SelectionError};
@@ -807,52 +808,61 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
807808
//
808809
// We always perform upcasting coercions when we can because of reason
809810
// #2 (region bounds).
810-
let auto_traits_compatible = b_data
811-
.auto_traits()
812-
// All of a's auto traits need to be in b's auto traits.
813-
.all(|b| a_data.auto_traits().any(|a| a == b));
814-
if auto_traits_compatible {
815-
let principal_def_id_a = a_data.principal_def_id();
816-
let principal_def_id_b = b_data.principal_def_id();
817-
if principal_def_id_a == principal_def_id_b {
818-
// no cyclic
811+
let principal_def_id_a = a_data.principal_def_id();
812+
let principal_def_id_b = b_data.principal_def_id();
813+
if principal_def_id_a == principal_def_id_b {
814+
// We may upcast to auto traits that are either explicitly listed in
815+
// the object type's bounds, or implied by the principal trait ref's
816+
// supertraits.
817+
let a_auto_traits: FxIndexSet<DefId> = a_data
818+
.auto_traits()
819+
.chain(principal_def_id_a.into_iter().flat_map(|principal_def_id| {
820+
util::supertrait_def_ids(self.tcx(), principal_def_id)
821+
.filter(|def_id| self.tcx().trait_is_auto(*def_id))
822+
}))
823+
.collect();
824+
let auto_traits_compatible = b_data
825+
.auto_traits()
826+
// All of a's auto traits need to be in b's auto traits.
827+
.all(|b| a_auto_traits.contains(&b));
828+
if auto_traits_compatible {
819829
candidates.vec.push(BuiltinUnsizeCandidate);
820-
} else if principal_def_id_a.is_some() && principal_def_id_b.is_some() {
821-
// not casual unsizing, now check whether this is trait upcasting coercion.
822-
let principal_a = a_data.principal().unwrap();
823-
let target_trait_did = principal_def_id_b.unwrap();
824-
let source_trait_ref = principal_a.with_self_ty(self.tcx(), source);
825-
if let Some(deref_trait_ref) = self.need_migrate_deref_output_trait_object(
826-
source,
827-
obligation.param_env,
828-
&obligation.cause,
829-
) {
830-
if deref_trait_ref.def_id() == target_trait_did {
831-
return;
832-
}
830+
}
831+
} else if principal_def_id_a.is_some() && principal_def_id_b.is_some() {
832+
// not casual unsizing, now check whether this is trait upcasting coercion.
833+
let principal_a = a_data.principal().unwrap();
834+
let target_trait_did = principal_def_id_b.unwrap();
835+
let source_trait_ref = principal_a.with_self_ty(self.tcx(), source);
836+
if let Some(deref_trait_ref) = self.need_migrate_deref_output_trait_object(
837+
source,
838+
obligation.param_env,
839+
&obligation.cause,
840+
) {
841+
if deref_trait_ref.def_id() == target_trait_did {
842+
return;
833843
}
844+
}
834845

835-
for (idx, upcast_trait_ref) in
836-
util::supertraits(self.tcx(), source_trait_ref).enumerate()
837-
{
838-
self.infcx.probe(|_| {
839-
if upcast_trait_ref.def_id() == target_trait_did
840-
&& let Ok(nested) = self.match_upcast_principal(
841-
obligation,
842-
upcast_trait_ref,
843-
a_data,
844-
b_data,
845-
a_region,
846-
b_region,
847-
)
848-
{
849-
if nested.is_none() {
850-
candidates.ambiguous = true;
851-
}
852-
candidates.vec.push(TraitUpcastingUnsizeCandidate(idx));
846+
for (idx, upcast_trait_ref) in
847+
util::supertraits(self.tcx(), source_trait_ref).enumerate()
848+
{
849+
self.infcx.probe(|_| {
850+
if upcast_trait_ref.def_id() == target_trait_did
851+
&& let Ok(nested) = self.match_upcast_principal(
852+
obligation,
853+
upcast_trait_ref,
854+
a_data,
855+
b_data,
856+
a_region,
857+
b_region,
858+
)
859+
{
860+
if nested.is_none() {
861+
candidates.ambiguous = true;
853862
}
854-
})
855-
}
863+
candidates.vec.push(TraitUpcastingUnsizeCandidate(idx));
864+
}
865+
})
856866
}
857867
}
858868
}

compiler/rustc_trait_selection/src/traits/select/mod.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -2513,6 +2513,17 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
25132513
let tcx = self.tcx();
25142514
let mut nested = vec![];
25152515

2516+
// We may upcast to auto traits that are either explicitly listed in
2517+
// the object type's bounds, or implied by the principal trait ref's
2518+
// supertraits.
2519+
let a_auto_traits: FxIndexSet<DefId> = a_data
2520+
.auto_traits()
2521+
.chain(a_data.principal_def_id().into_iter().flat_map(|principal_def_id| {
2522+
util::supertrait_def_ids(tcx, principal_def_id)
2523+
.filter(|def_id| tcx.trait_is_auto(*def_id))
2524+
}))
2525+
.collect();
2526+
25162527
let upcast_principal = normalize_with_depth_to(
25172528
self,
25182529
obligation.param_env,
@@ -2575,7 +2586,7 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
25752586
}
25762587
// Check that b_ty's auto traits are present in a_ty's bounds.
25772588
ty::ExistentialPredicate::AutoTrait(def_id) => {
2578-
if !a_data.auto_traits().any(|source_def_id| source_def_id == def_id) {
2589+
if !a_auto_traits.contains(&def_id) {
25792590
return Err(SelectionError::Unimplemented);
25802591
}
25812592
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// check-pass
2+
// revisions: current next
3+
//[next] compile-flags: -Znext-solver
4+
5+
#![feature(trait_upcasting)]
6+
7+
trait Target {}
8+
trait Source: Send + Target {}
9+
10+
fn upcast(x: &dyn Source) -> &(dyn Target + Send) { x }
11+
12+
fn same(x: &dyn Source) -> &(dyn Source + Send) { x }
13+
14+
fn main() {}

0 commit comments

Comments
 (0)