Skip to content

Commit 9dc9a7a

Browse files
bors[bot]flodiebold
andcommitted
Merge #1496
1496: Add trait obligations for where clauses when calling functions/methods r=matklad a=flodiebold E.g. if we call `foo<T: Into<u32>>(x)`, that adds an obligation that `x: Into<u32>`, etc., which sometimes allows type inference to make further progress. Co-authored-by: Florian Diebold <[email protected]>
2 parents 219e0e8 + f854a29 commit 9dc9a7a

File tree

9 files changed

+171
-42
lines changed

9 files changed

+171
-42
lines changed

crates/ra_hir/src/db.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ pub trait HirDatabase: DefDatabase + AstDatabase {
163163
#[salsa::invoke(crate::ty::callable_item_sig)]
164164
fn callable_item_signature(&self, def: CallableDef) -> FnSig;
165165

166-
#[salsa::invoke(crate::ty::generic_predicates)]
166+
#[salsa::invoke(crate::ty::generic_predicates_query)]
167167
fn generic_predicates(&self, def: GenericDef) -> Arc<[GenericPredicate]>;
168168

169-
#[salsa::invoke(crate::ty::generic_defaults)]
169+
#[salsa::invoke(crate::ty::generic_defaults_query)]
170170
fn generic_defaults(&self, def: GenericDef) -> Substs;
171171

172172
#[salsa::invoke(crate::expr::body_with_source_map_query)]

crates/ra_hir/src/generics.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use crate::{
1111
db::{AstDatabase, DefDatabase, HirDatabase},
1212
path::Path,
1313
type_ref::TypeRef,
14-
AdtDef, AsName, Container, Enum, Function, HasSource, ImplBlock, Name, Struct, Trait,
15-
TypeAlias, Union,
14+
AdtDef, AsName, Container, Enum, EnumVariant, Function, HasSource, ImplBlock, Name, Struct,
15+
Trait, TypeAlias, Union,
1616
};
1717

1818
/// Data about a generic parameter (to a function, struct, impl, ...).
@@ -50,8 +50,11 @@ pub enum GenericDef {
5050
Trait(Trait),
5151
TypeAlias(TypeAlias),
5252
ImplBlock(ImplBlock),
53+
// enum variants cannot have generics themselves, but their parent enums
54+
// can, and this makes some code easier to write
55+
EnumVariant(EnumVariant),
5356
}
54-
impl_froms!(GenericDef: Function, Struct, Union, Enum, Trait, TypeAlias, ImplBlock);
57+
impl_froms!(GenericDef: Function, Struct, Union, Enum, Trait, TypeAlias, ImplBlock, EnumVariant);
5558

5659
impl GenericParams {
5760
pub(crate) fn generic_params_query(
@@ -62,6 +65,7 @@ impl GenericParams {
6265
let parent = match def {
6366
GenericDef::Function(it) => it.container(db).map(GenericDef::from),
6467
GenericDef::TypeAlias(it) => it.container(db).map(GenericDef::from),
68+
GenericDef::EnumVariant(it) => Some(it.parent_enum(db).into()),
6569
GenericDef::Struct(_)
6670
| GenericDef::Union(_)
6771
| GenericDef::Enum(_)
@@ -86,6 +90,7 @@ impl GenericParams {
8690
}
8791
GenericDef::TypeAlias(it) => generics.fill(&*it.source(db).ast, start),
8892
GenericDef::ImplBlock(it) => generics.fill(&*it.source(db).ast, start),
93+
GenericDef::EnumVariant(_) => {}
8994
}
9095

9196
Arc::new(generics)
@@ -184,6 +189,7 @@ impl GenericDef {
184189
GenericDef::Trait(inner) => inner.resolver(db),
185190
GenericDef::TypeAlias(inner) => inner.resolver(db),
186191
GenericDef::ImplBlock(inner) => inner.resolver(db),
192+
GenericDef::EnumVariant(inner) => inner.parent_enum(db).resolver(db),
187193
}
188194
}
189195
}

crates/ra_hir/src/resolve.rs

+12
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,18 @@ impl Resolver {
221221
pub(crate) fn krate(&self) -> Option<Crate> {
222222
self.module().map(|t| t.0.krate())
223223
}
224+
225+
pub(crate) fn where_predicates_in_scope<'a>(
226+
&'a self,
227+
) -> impl Iterator<Item = &'a crate::generics::WherePredicate> + 'a {
228+
self.scopes
229+
.iter()
230+
.filter_map(|scope| match scope {
231+
Scope::GenericParams(params) => Some(params),
232+
_ => None,
233+
})
234+
.flat_map(|params| params.where_predicates.iter())
235+
}
224236
}
225237

226238
impl Resolver {

crates/ra_hir/src/ty.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ pub(crate) use autoderef::autoderef;
2323
pub(crate) use infer::{infer_query, InferTy, InferenceResult};
2424
pub use lower::CallableDef;
2525
pub(crate) use lower::{
26-
callable_item_sig, generic_defaults, generic_predicates, type_for_def, type_for_field,
27-
TypableDef,
26+
callable_item_sig, generic_defaults_query, generic_predicates_query, type_for_def,
27+
type_for_field, TypableDef,
2828
};
2929
pub(crate) use traits::ProjectionPredicate;
3030

crates/ra_hir/src/ty/infer.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,14 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
849849
fn register_obligations_for_call(&mut self, callable_ty: &Ty) {
850850
if let Ty::Apply(a_ty) = callable_ty {
851851
if let TypeCtor::FnDef(def) = a_ty.ctor {
852+
let generic_predicates = self.db.generic_predicates(def.into());
853+
for predicate in generic_predicates.iter() {
854+
let predicate = predicate.clone().subst(&a_ty.parameters);
855+
if let Some(obligation) = Obligation::from_predicate(predicate) {
856+
self.obligations.push(obligation);
857+
}
858+
}
852859
// add obligation for trait implementation, if this is a trait method
853-
// FIXME also register obligations from where clauses from the trait or impl and method
854860
match def {
855861
CallableDef::Function(f) => {
856862
if let Some(trait_) = f.parent_trait(self.db) {
@@ -992,7 +998,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
992998
(Vec::new(), Ty::Unknown)
993999
}
9941000
};
995-
// FIXME register obligations from where clauses from the function
1001+
self.register_obligations_for_call(&callee_ty);
9961002
let param_iter = param_tys.into_iter().chain(repeat(Ty::Unknown));
9971003
for (arg, param) in args.iter().zip(param_iter) {
9981004
self.infer_expr(*arg, &Expectation::has_type(param));

crates/ra_hir/src/ty/lower.rs

+14-6
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,13 @@ pub(crate) fn type_for_field(db: &impl HirDatabase, field: StructField) -> Ty {
318318
}
319319

320320
/// Resolve the where clause(s) of an item with generics.
321-
pub(crate) fn generic_predicates(
321+
pub(crate) fn generic_predicates_query(
322322
db: &impl HirDatabase,
323323
def: GenericDef,
324324
) -> Arc<[GenericPredicate]> {
325325
let resolver = def.resolver(db);
326-
let generic_params = def.generic_params(db);
327-
let predicates = generic_params
328-
.where_predicates
329-
.iter()
326+
let predicates = resolver
327+
.where_predicates_in_scope()
330328
.map(|pred| {
331329
TraitRef::for_where_predicate(db, &resolver, pred)
332330
.map_or(GenericPredicate::Error, GenericPredicate::Implemented)
@@ -336,7 +334,7 @@ pub(crate) fn generic_predicates(
336334
}
337335

338336
/// Resolve the default type params from generics
339-
pub(crate) fn generic_defaults(db: &impl HirDatabase, def: GenericDef) -> Substs {
337+
pub(crate) fn generic_defaults_query(db: &impl HirDatabase, def: GenericDef) -> Substs {
340338
let resolver = def.resolver(db);
341339
let generic_params = def.generic_params(db);
342340

@@ -511,3 +509,13 @@ pub enum CallableDef {
511509
EnumVariant(EnumVariant),
512510
}
513511
impl_froms!(CallableDef: Function, Struct, EnumVariant);
512+
513+
impl From<CallableDef> for GenericDef {
514+
fn from(def: CallableDef) -> GenericDef {
515+
match def {
516+
CallableDef::Function(f) => f.into(),
517+
CallableDef::Struct(s) => s.into(),
518+
CallableDef::EnumVariant(e) => e.into(),
519+
}
520+
}
521+
}

crates/ra_hir/src/ty/tests.rs

+113-21
Original file line numberDiff line numberDiff line change
@@ -2232,16 +2232,18 @@ fn test() {
22322232
}
22332233
"#),
22342234
@r###"
2235-
[86; 87) 't': T
2236-
[92; 94) '{}': ()
2237-
[105; 144) '{ ...(s); }': ()
2238-
[115; 116) 's': S<{unknown}>
2239-
[119; 120) 'S': S<{unknown}>(T) -> S<T>
2240-
[119; 129) 'S(unknown)': S<{unknown}>
2241-
[121; 128) 'unknown': {unknown}
2242-
[135; 138) 'foo': fn foo<S<{unknown}>>(T) -> ()
2243-
[135; 141) 'foo(s)': ()
2244-
[139; 140) 's': S<{unknown}>"###
2235+
2236+
⋮[86; 87) 't': T
2237+
⋮[92; 94) '{}': ()
2238+
⋮[105; 144) '{ ...(s); }': ()
2239+
⋮[115; 116) 's': S<u32>
2240+
⋮[119; 120) 'S': S<u32>(T) -> S<T>
2241+
⋮[119; 129) 'S(unknown)': S<u32>
2242+
⋮[121; 128) 'unknown': u32
2243+
⋮[135; 138) 'foo': fn foo<S<u32>>(T) -> ()
2244+
⋮[135; 141) 'foo(s)': ()
2245+
⋮[139; 140) 's': S<u32>
2246+
"###
22452247
);
22462248
}
22472249

@@ -2259,17 +2261,19 @@ fn test() {
22592261
}
22602262
"#),
22612263
@r###"
2262-
[87; 88) 't': T
2263-
[98; 100) '{}': ()
2264-
[111; 163) '{ ...(s); }': ()
2265-
[121; 122) 's': S<{unknown}>
2266-
[125; 126) 'S': S<{unknown}>(T) -> S<T>
2267-
[125; 135) 'S(unknown)': S<{unknown}>
2268-
[127; 134) 'unknown': {unknown}
2269-
[145; 146) 'x': u32
2270-
[154; 157) 'foo': fn foo<u32, S<{unknown}>>(T) -> U
2271-
[154; 160) 'foo(s)': u32
2272-
[158; 159) 's': S<{unknown}>"###
2264+
2265+
⋮[87; 88) 't': T
2266+
⋮[98; 100) '{}': ()
2267+
⋮[111; 163) '{ ...(s); }': ()
2268+
⋮[121; 122) 's': S<u32>
2269+
⋮[125; 126) 'S': S<u32>(T) -> S<T>
2270+
⋮[125; 135) 'S(unknown)': S<u32>
2271+
⋮[127; 134) 'unknown': u32
2272+
⋮[145; 146) 'x': u32
2273+
⋮[154; 157) 'foo': fn foo<u32, S<u32>>(T) -> U
2274+
⋮[154; 160) 'foo(s)': u32
2275+
⋮[158; 159) 's': S<u32>
2276+
"###
22732277
);
22742278
}
22752279

@@ -2822,6 +2826,94 @@ fn test(s: S) {
28222826
assert_eq!(t, "{unknown}");
28232827
}
28242828

2829+
#[test]
2830+
fn obligation_from_function_clause() {
2831+
let t = type_at(
2832+
r#"
2833+
//- /main.rs
2834+
struct S;
2835+
2836+
trait Trait<T> {}
2837+
impl Trait<u32> for S {}
2838+
2839+
fn foo<T: Trait<U>, U>(t: T) -> U {}
2840+
2841+
fn test(s: S) {
2842+
foo(s)<|>;
2843+
}
2844+
"#,
2845+
);
2846+
assert_eq!(t, "u32");
2847+
}
2848+
2849+
#[test]
2850+
fn obligation_from_method_clause() {
2851+
let t = type_at(
2852+
r#"
2853+
//- /main.rs
2854+
struct S;
2855+
2856+
trait Trait<T> {}
2857+
impl Trait<isize> for S {}
2858+
2859+
struct O;
2860+
impl O {
2861+
fn foo<T: Trait<U>, U>(&self, t: T) -> U {}
2862+
}
2863+
2864+
fn test() {
2865+
O.foo(S)<|>;
2866+
}
2867+
"#,
2868+
);
2869+
assert_eq!(t, "isize");
2870+
}
2871+
2872+
#[test]
2873+
fn obligation_from_self_method_clause() {
2874+
let t = type_at(
2875+
r#"
2876+
//- /main.rs
2877+
struct S;
2878+
2879+
trait Trait<T> {}
2880+
impl Trait<i64> for S {}
2881+
2882+
impl S {
2883+
fn foo<U>(&self) -> U where Self: Trait<U> {}
2884+
}
2885+
2886+
fn test() {
2887+
S.foo()<|>;
2888+
}
2889+
"#,
2890+
);
2891+
assert_eq!(t, "i64");
2892+
}
2893+
2894+
#[test]
2895+
fn obligation_from_impl_clause() {
2896+
let t = type_at(
2897+
r#"
2898+
//- /main.rs
2899+
struct S;
2900+
2901+
trait Trait<T> {}
2902+
impl Trait<&str> for S {}
2903+
2904+
struct O<T>;
2905+
impl<U, T: Trait<U>> O<T> {
2906+
fn foo(&self) -> U {}
2907+
}
2908+
2909+
fn test(o: O<S>) {
2910+
o.foo()<|>;
2911+
}
2912+
"#,
2913+
);
2914+
assert_eq!(t, "&str");
2915+
}
2916+
28252917
fn type_at_pos(db: &MockDatabase, pos: FilePosition) -> String {
28262918
let file = db.parse(pos.file_id).ok().unwrap();
28272919
let expr = algo::find_node_at_offset::<ast::Expr>(file.syntax(), pos.offset).unwrap();

crates/ra_hir/src/ty/traits.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use parking_lot::Mutex;
77
use ra_prof::profile;
88
use rustc_hash::FxHashSet;
99

10-
use super::{Canonical, ProjectionTy, TraitRef, Ty};
10+
use super::{Canonical, GenericPredicate, ProjectionTy, TraitRef, Ty};
1111
use crate::{db::HirDatabase, Crate, ImplBlock, Trait};
1212

1313
use self::chalk::{from_chalk, ToChalk};
@@ -78,6 +78,15 @@ pub enum Obligation {
7878
// Projection(ProjectionPredicate),
7979
}
8080

81+
impl Obligation {
82+
pub fn from_predicate(predicate: GenericPredicate) -> Option<Obligation> {
83+
match predicate {
84+
GenericPredicate::Implemented(trait_ref) => Some(Obligation::Trait(trait_ref)),
85+
GenericPredicate::Error => None,
86+
}
87+
}
88+
}
89+
8190
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
8291
pub struct ProjectionPredicate {
8392
pub projection_ty: ProjectionTy,

crates/ra_hir/src/ty/traits/chalk.rs

+1-5
Original file line numberDiff line numberDiff line change
@@ -428,11 +428,7 @@ pub(crate) fn struct_datum_query(
428428
CallableDef::Struct(s) => s.module(db).krate(db),
429429
CallableDef::EnumVariant(v) => v.parent_enum(db).module(db).krate(db),
430430
} != Some(krate);
431-
let generic_def: GenericDef = match callable {
432-
CallableDef::Function(f) => f.into(),
433-
CallableDef::Struct(s) => s.into(),
434-
CallableDef::EnumVariant(v) => v.parent_enum(db).into(),
435-
};
431+
let generic_def: GenericDef = callable.into();
436432
let generic_params = generic_def.generic_params(db);
437433
let bound_vars = Substs::bound_vars(&generic_params);
438434
let where_clauses = convert_where_clauses(db, generic_def, &bound_vars);

0 commit comments

Comments
 (0)