Skip to content

Commit dd508f4

Browse files
derive(SmartPointer): rewrite bounds in where and generic bounds
1 parent a2d5819 commit dd508f4

File tree

2 files changed

+261
-11
lines changed

2 files changed

+261
-11
lines changed

compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs

+192-11
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
use std::mem::swap;
22

33
use ast::HasAttrs;
4+
use rustc_ast::mut_visit::MutVisitor;
45
use rustc_ast::{
56
self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem,
67
TraitBoundModifiers, VariantData,
78
};
89
use rustc_attr as attr;
10+
use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
911
use rustc_expand::base::{Annotatable, ExtCtxt};
1012
use rustc_span::symbol::{sym, Ident};
11-
use rustc_span::Span;
13+
use rustc_span::{Span, Symbol};
1214
use smallvec::{smallvec, SmallVec};
1315
use thin_vec::{thin_vec, ThinVec};
1416

17+
type AstTy = ast::ptr::P<ast::Ty>;
18+
1519
macro_rules! path {
1620
($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] }
1721
}
1822

23+
macro_rules! symbols {
24+
($($part:ident)::*) => { [$(sym::$part),*] }
25+
}
26+
1927
pub fn expand_deriving_smart_ptr(
2028
cx: &ExtCtxt<'_>,
2129
span: Span,
@@ -143,31 +151,204 @@ pub fn expand_deriving_smart_ptr(
143151

144152
// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
145153
let mut impl_generics = generics.clone();
154+
let pointee_ty_ident = generics.params[pointee_param_idx].ident;
155+
let mut self_bounds;
146156
{
147157
let p = &mut impl_generics.params[pointee_param_idx];
158+
self_bounds = p.bounds.clone();
148159
let arg = GenericArg::Type(s_ty.clone());
149160
let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
150161
p.bounds.push(cx.trait_bound(unsize, false));
151162
let mut attrs = thin_vec![];
152163
swap(&mut p.attrs, &mut attrs);
153164
p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect();
154165
}
166+
// We should not set default values to constant generic parameters
167+
for params in &mut impl_generics.params {
168+
if let ast::GenericParamKind::Const { default, .. } = &mut params.kind {
169+
*default = None;
170+
}
171+
}
155172

156173
// Add the `__S: ?Sized` extra parameter to the impl block.
174+
// We should also commute the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`.
157175
let sized = cx.path_global(span, path!(span, core::marker::Sized));
158-
let bound = GenericBound::Trait(
159-
cx.poly_trait_ref(span, sized),
160-
TraitBoundModifiers {
161-
polarity: ast::BoundPolarity::Maybe(span),
162-
constness: ast::BoundConstness::Never,
163-
asyncness: ast::BoundAsyncness::Normal,
164-
},
165-
);
166-
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None);
167-
impl_generics.params.push(extra_param);
176+
if self_bounds.iter().all(|bound| {
177+
if let GenericBound::Trait(
178+
trait_ref,
179+
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
180+
) = bound
181+
{
182+
!is_sized_marker(&trait_ref.trait_ref.path)
183+
} else {
184+
false
185+
}
186+
}) {
187+
self_bounds.push(GenericBound::Trait(
188+
cx.poly_trait_ref(span, sized),
189+
TraitBoundModifiers {
190+
polarity: ast::BoundPolarity::Maybe(span),
191+
constness: ast::BoundConstness::Never,
192+
asyncness: ast::BoundAsyncness::Normal,
193+
},
194+
));
195+
}
196+
{
197+
let mut substitution =
198+
TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
199+
for bound in &mut self_bounds {
200+
substitution.visit_param_bound(bound);
201+
}
202+
}
203+
204+
// We should also commute the where bounds from `#[pointee]` to `__S`
205+
// as well as any bound that implicitly involves the pointee type.
206+
for bound in &generics.where_clause.predicates {
207+
if let ast::WherePredicate::BoundPredicate(bound) = bound {
208+
let bound_on_pointee = bound
209+
.bounded_ty
210+
.kind
211+
.is_simple_path()
212+
.map_or(false, |name| name == pointee_ty_ident.name);
213+
214+
let bounds: Vec<_> = bound
215+
.bounds
216+
.iter()
217+
.filter(|bound| {
218+
if let GenericBound::Trait(
219+
trait_ref,
220+
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
221+
) = bound
222+
{
223+
!bound_on_pointee || !is_sized_marker(&trait_ref.trait_ref.path)
224+
} else {
225+
true
226+
}
227+
})
228+
.cloned()
229+
.collect();
230+
let mut substitution = TypeSubstitution {
231+
from_name: pointee_ty_ident.name,
232+
to_ty: &s_ty,
233+
rewritten: bounds.len() != bound.bounds.len(),
234+
};
235+
let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate {
236+
span: bound.span,
237+
bound_generic_params: bound.bound_generic_params.clone(),
238+
bounded_ty: bound.bounded_ty.clone(),
239+
bounds,
240+
});
241+
substitution.visit_where_predicate(&mut predicate);
242+
if substitution.rewritten {
243+
impl_generics.where_clause.predicates.push(predicate);
244+
}
245+
}
246+
}
247+
248+
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
249+
impl_generics.params.insert(pointee_param_idx + 1, extra_param);
168250

169251
// Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
170252
let gen_args = vec![GenericArg::Type(alt_self_type.clone())];
171253
add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
172254
add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone());
173255
}
256+
257+
fn is_sized_marker(path: &ast::Path) -> bool {
258+
const CORE_UNSIZE: [Symbol; 3] = symbols!(core::marker::Sized);
259+
const STD_UNSIZE: [Symbol; 3] = symbols!(std::marker::Sized);
260+
if path.segments.len() == 3 {
261+
path.segments.iter().zip(CORE_UNSIZE).all(|(segment, symbol)| segment.ident.name == symbol)
262+
|| path
263+
.segments
264+
.iter()
265+
.zip(STD_UNSIZE)
266+
.all(|(segment, symbol)| segment.ident.name == symbol)
267+
} else {
268+
*path == sym::Sized
269+
}
270+
}
271+
272+
struct TypeSubstitution<'a> {
273+
from_name: Symbol,
274+
to_ty: &'a AstTy,
275+
rewritten: bool,
276+
}
277+
278+
impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
279+
fn visit_ty(&mut self, ty: &mut AstTy) {
280+
if let Some(name) = ty.kind.is_simple_path()
281+
&& name == self.from_name
282+
{
283+
*ty = self.to_ty.clone();
284+
self.rewritten = true;
285+
return;
286+
}
287+
match &mut ty.kind {
288+
ast::TyKind::Slice(_)
289+
| ast::TyKind::Array(_, _)
290+
| ast::TyKind::Ptr(_)
291+
| ast::TyKind::Ref(_, _)
292+
| ast::TyKind::BareFn(_)
293+
| ast::TyKind::Never
294+
| ast::TyKind::Tup(_)
295+
| ast::TyKind::AnonStruct(_, _)
296+
| ast::TyKind::AnonUnion(_, _)
297+
| ast::TyKind::Path(_, _)
298+
| ast::TyKind::TraitObject(_, _)
299+
| ast::TyKind::ImplTrait(_, _)
300+
| ast::TyKind::Paren(_)
301+
| ast::TyKind::Typeof(_)
302+
| ast::TyKind::Infer
303+
| ast::TyKind::MacCall(_)
304+
| ast::TyKind::Pat(_, _) => ast::mut_visit::noop_visit_ty(ty, self),
305+
ast::TyKind::ImplicitSelf
306+
| ast::TyKind::CVarArgs
307+
| ast::TyKind::Dummy
308+
| ast::TyKind::Err(_) => {}
309+
}
310+
}
311+
312+
fn visit_param_bound(&mut self, bound: &mut GenericBound) {
313+
match bound {
314+
GenericBound::Trait(trait_ref, _) => {
315+
if trait_ref
316+
.bound_generic_params
317+
.iter()
318+
.any(|param| param.ident.name == self.from_name)
319+
{
320+
return;
321+
}
322+
self.visit_poly_trait_ref(trait_ref);
323+
}
324+
325+
GenericBound::Use(args, _span) => {
326+
for arg in args {
327+
self.visit_precise_capturing_arg(arg);
328+
}
329+
}
330+
GenericBound::Outlives(_) => {}
331+
}
332+
}
333+
334+
fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) {
335+
match where_predicate {
336+
rustc_ast::WherePredicate::BoundPredicate(bound) => {
337+
if bound.bound_generic_params.iter().any(|param| param.ident.name == self.from_name)
338+
{
339+
// Name is shadowed so we must skip the rest
340+
return;
341+
}
342+
bound
343+
.bound_generic_params
344+
.flat_map_in_place(|param| self.flat_map_generic_param(param));
345+
self.visit_ty(&mut bound.bounded_ty);
346+
for bound in &mut bound.bounds {
347+
self.visit_param_bound(bound)
348+
}
349+
}
350+
rustc_ast::WherePredicate::RegionPredicate(_)
351+
| rustc_ast::WherePredicate::EqPredicate(_) => {}
352+
}
353+
}
354+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
//@ check-pass
2+
3+
#![feature(derive_smart_pointer)]
4+
5+
#[derive(core::marker::SmartPointer)]
6+
#[repr(transparent)]
7+
pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> {
8+
data: &'a mut T,
9+
x: core::marker::PhantomData<X>,
10+
}
11+
12+
pub trait OnDrop {
13+
fn on_drop(&mut self);
14+
}
15+
16+
#[derive(core::marker::SmartPointer)]
17+
#[repr(transparent)]
18+
pub struct Ptr2<'a, #[pointee] T: ?Sized, X>
19+
where
20+
T: OnDrop,
21+
{
22+
data: &'a mut T,
23+
x: core::marker::PhantomData<X>,
24+
}
25+
26+
pub trait MyTrait<T: ?Sized> {}
27+
28+
#[derive(core::marker::SmartPointer)]
29+
#[repr(transparent)]
30+
pub struct Ptr3<'a, #[pointee] T: ?Sized, X>
31+
where
32+
T: MyTrait<T>,
33+
{
34+
data: &'a mut T,
35+
x: core::marker::PhantomData<X>,
36+
}
37+
38+
#[derive(core::marker::SmartPointer)]
39+
#[repr(transparent)]
40+
pub struct Ptr4<'a, #[pointee] T: MyTrait<T> + ?Sized, X> {
41+
data: &'a mut T,
42+
x: core::marker::PhantomData<X>,
43+
}
44+
45+
#[derive(core::marker::SmartPointer)]
46+
#[repr(transparent)]
47+
pub struct Ptr5<'a, #[pointee] T: ?Sized, X>
48+
where
49+
Ptr5Companion<T>: MyTrait<T>,
50+
{
51+
data: &'a mut T,
52+
x: core::marker::PhantomData<X>,
53+
}
54+
55+
pub struct Ptr5Companion<T: ?Sized>(core::marker::PhantomData<T>);
56+
57+
// a reduced example from https://lore.kernel.org/all/[email protected]/
58+
#[repr(transparent)]
59+
#[derive(core::marker::SmartPointer)]
60+
pub struct ListArc<#[pointee] T, const ID: u64 = 0>
61+
where
62+
T: ListArcSafe<ID> + ?Sized,
63+
{
64+
arc: *const T,
65+
}
66+
67+
pub trait ListArcSafe<const ID: u64> {}
68+
69+
fn main() {}

0 commit comments

Comments
 (0)