Skip to content

Commit

Permalink
Add or-patterns to pattern types
Browse files Browse the repository at this point in the history
  • Loading branch information
oli-obk committed Mar 11, 2025
1 parent e578e31 commit d0db6bc
Show file tree
Hide file tree
Showing 35 changed files with 477 additions and 10 deletions.
2 changes: 2 additions & 0 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,8 @@ pub enum TyPatKind {
/// A range pattern (e.g., `1...2`, `1..2`, `1..`, `..2`, `1..=2`, `..=2`).
Range(Option<P<AnonConst>>, Option<P<AnonConst>>, Spanned<RangeEnd>),

Or(ThinVec<P<TyPat>>),

/// A `!null` pattern for raw pointers.
NotNull,

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ pub fn walk_ty_pat<T: MutVisitor>(vis: &mut T, ty: &mut P<TyPat>) {
visit_opt(start, |c| vis.visit_anon_const(c));
visit_opt(end, |c| vis.visit_anon_const(c));
}
TyPatKind::Or(variants) => visit_thin_vec(variants, |p| vis.visit_ty_pat(p)),
TyPatKind::NotNull | TyPatKind::Err(_) => {}
}
visit_lazy_tts(vis, tokens);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ pub fn walk_ty_pat<'a, V: Visitor<'a>>(visitor: &mut V, tp: &'a TyPat) -> V::Res
visit_opt!(visitor, visit_anon_const, start);
visit_opt!(visitor, visit_anon_const, end);
}
TyPatKind::Or(variants) => walk_list!(visitor, visit_ty_pat, variants),
TyPatKind::NotNull | TyPatKind::Err(_) => {}
}
V::Result::output()
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_ast_lowering/src/pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,11 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
)
}),
),
TyPatKind::Or(variants) => {
hir::TyPatKind::Or(self.arena.alloc_from_iter(
variants.iter().map(|pat| self.lower_ty_pat_mut(pat, base_type)),
))
}
TyPatKind::NotNull => hir::TyPatKind::NotNull,
TyPatKind::Err(guar) => hir::TyPatKind::Err(*guar),
};
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_ast_pretty/src/pprust/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,17 @@ impl<'a> State<'a> {
self.print_expr_anon_const(end, &[]);
}
}
rustc_ast::TyPatKind::Or(variants) => {
let mut first = true;
for pat in variants {
if first {
first = false
} else {
self.word(" | ");
}
self.print_ty_pat(pat);
}
}
rustc_ast::TyPatKind::NotNull => self.word("!null"),
rustc_ast::TyPatKind::Err(_) => {
self.popen();
Expand Down
16 changes: 15 additions & 1 deletion compiler/rustc_builtin_macros/src/pattern_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rustc_ast::{AnonConst, DUMMY_NODE_ID, Ty, TyPat, TyPatKind, ast, token};
use rustc_errors::PResult;
use rustc_expand::base::{self, DummyResult, ExpandResult, ExtCtxt, MacroExpanderResult};
use rustc_parse::exp;
use rustc_parse::parser::{CommaRecoveryMode, RecoverColon, RecoverComma};
use rustc_span::Span;

pub(crate) fn expand<'cx>(
Expand Down Expand Up @@ -33,7 +34,17 @@ fn parse_pat_ty<'a>(cx: &mut ExtCtxt<'a>, stream: TokenStream) -> PResult<'a, (P
let span = start.to(parser.token.span);
ty_pat(TyPatKind::NotNull, span)
} else {
pat_to_ty_pat(cx, parser.parse_pat_no_top_alt(None, None)?.into_inner())
pat_to_ty_pat(
cx,
parser
.parse_pat_no_top_guard(
None,
RecoverComma::No,
RecoverColon::No,
CommaRecoveryMode::EitherTupleOrPipe,
)?
.into_inner(),
)
};

if parser.token != token::Eof {
Expand All @@ -53,6 +64,9 @@ fn pat_to_ty_pat(cx: &mut ExtCtxt<'_>, pat: ast::Pat) -> P<TyPat> {
end.map(|value| P(AnonConst { id: DUMMY_NODE_ID, value })),
include_end,
),
ast::PatKind::Or(variants) => TyPatKind::Or(
variants.into_iter().map(|pat| pat_to_ty_pat(cx, pat.into_inner())).collect(),
),
ast::PatKind::Err(guar) => TyPatKind::Err(guar),
_ => TyPatKind::Err(cx.dcx().span_err(pat.span, "pattern not supported in pattern types")),
};
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ pub(crate) fn eval_nullary_intrinsic<'tcx>(
ty::Pat(_, pat) => match **pat {
ty::PatternKind::Range { .. } => ConstValue::from_target_usize(0u64, &tcx),
ty::PatternKind::NotNull => ConstValue::from_target_usize(0_u64, &tcx),
// FIXME(pattern_types): make this report the number of distinct variants used in the
// or pattern in case the base type is an enum.
ty::PatternKind::Or(_) => ConstValue::from_target_usize(0_u64, &tcx),
},
ty::Bound(_, _) => bug!("bound ty during ctfe"),
ty::Bool
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_const_eval/src/interpret/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,10 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt,
// handled fully by `visit_scalar` (called below).
ty::PatternKind::Range { .. } => {},
ty::PatternKind::NotNull => {},

// FIXME(pattern_types): check that the value is covered by one of the variants.
// The layout may pessimistically cover actually illegal ranges.
ty::PatternKind::Or(_patterns) => {}
}
}
_ => {
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,9 @@ pub enum TyPatKind<'hir> {
/// A range pattern (e.g., `1..=2` or `1..2`).
Range(&'hir ConstArg<'hir>, &'hir ConstArg<'hir>),

/// A list of patterns where only one needs to be satisfied
Or(&'hir [TyPat<'hir>]),

/// A pattern that excludes null pointers
NotNull,

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir/src/intravisit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ pub fn walk_ty_pat<'v, V: Visitor<'v>>(visitor: &mut V, pattern: &'v TyPat<'v>)
try_visit!(visitor.visit_const_arg_unambig(lower_bound));
try_visit!(visitor.visit_const_arg_unambig(upper_bound));
}
TyPatKind::Or(patterns) => walk_list!(visitor, visit_pattern_type_pattern, patterns),
TyPatKind::NotNull | TyPatKind::Err(_) => (),
}
V::Result::output()
Expand Down
8 changes: 5 additions & 3 deletions compiler/rustc_hir_analysis/src/collect/type_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,12 @@ fn const_arg_anon_type_of<'tcx>(icx: &ItemCtxt<'tcx>, arg_hir_id: HirId, span: S
}

Node::TyPat(pat) => {
let hir::TyKind::Pat(ty, p) = tcx.parent_hir_node(pat.hir_id).expect_ty().kind else {
bug!()
let node = match tcx.parent_hir_node(pat.hir_id) {
// Or patterns can be nested one level deep
Node::TyPat(p) => tcx.parent_hir_node(p.hir_id),
other => other,
};
assert_eq!(p.hir_id, pat.hir_id);
let hir::TyKind::Pat(ty, _) = node.expect_ty().kind else { bug!() };
icx.lower_ty(ty)
}

Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_hir_analysis/src/hir_ty_lowering/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2709,6 +2709,7 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
ty_span: Span,
pat: &hir::TyPat<'tcx>,
) -> Result<ty::PatternKind<'tcx>, ErrorGuaranteed> {
let tcx = self.tcx();
match pat.kind {
hir::TyPatKind::Range(start, end) => {
match ty.kind() {
Expand All @@ -2724,6 +2725,13 @@ impl<'tcx> dyn HirTyLowerer<'tcx> + '_ {
.span_delayed_bug(ty_span, "invalid base type for range pattern")),
}
}
hir::TyPatKind::Or(patterns) => {
self.tcx()
.mk_patterns_from_iter(patterns.iter().map(|pat| {
self.lower_pat_ty_pat(ty, ty_span, pat).map(|pat| tcx.mk_pat(pat))
}))
.map(ty::PatternKind::Or)
}
hir::TyPatKind::NotNull => Ok(ty::PatternKind::NotNull),
hir::TyPatKind::Err(e) => Err(e),
}
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_hir_analysis/src/variance/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,11 @@ impl<'a, 'tcx> ConstraintContext<'a, 'tcx> {
self.add_constraints_from_const(current, start, variance);
self.add_constraints_from_const(current, end, variance);
}
ty::PatternKind::Or(patterns) => {
for pat in patterns {
self.add_constraints_from_pat(current, variance, pat)
}
}
ty::PatternKind::NotNull => {}
}
}
Expand Down
13 changes: 13 additions & 0 deletions compiler/rustc_hir_pretty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1877,6 +1877,19 @@ impl<'a> State<'a> {
self.word("..=");
self.print_const_arg(end);
}
TyPatKind::Or(patterns) => {
self.popen();
let mut first = true;
for pat in patterns {
if first {
first = false;
} else {
self.word(" | ");
}
self.print_ty_pat(pat);
}
self.pclose();
}
TyPatKind::NotNull => {
self.word_space("not");
self.word("null");
Expand Down
10 changes: 10 additions & 0 deletions compiler/rustc_lint/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,9 @@ fn pat_ty_is_known_nonnull<'tcx>(
// to ensure we aren't wrapping over zero.
start > 0 && end >= start
}
ty::PatternKind::Or(patterns) => {
patterns.iter().all(|pat| pat_ty_is_known_nonnull(tcx, typing_env, pat))
}
ty::PatternKind::NotNull => true,
}
},
Expand Down Expand Up @@ -1063,6 +1066,13 @@ fn get_nullable_type_from_pat<'tcx>(
ty::PatternKind::NotNull | ty::PatternKind::Range { .. } => {
get_nullable_type(tcx, typing_env, base)
}
ty::PatternKind::Or(patterns) => {
let first = get_nullable_type_from_pat(tcx, typing_env, base, patterns[0])?;
for &pat in &patterns[1..] {
assert_eq!(first, get_nullable_type_from_pat(tcx, typing_env, base, pat)?);
}
Some(first)
}
}
}

Expand Down
10 changes: 10 additions & 0 deletions compiler/rustc_middle/src/ty/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,15 @@ impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D>
}
}

impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D> for ty::List<ty::Pattern<'tcx>> {
fn decode(decoder: &mut D) -> &'tcx Self {
let len = decoder.read_usize();
decoder.interner().mk_patterns_from_iter(
(0..len).map::<ty::Pattern<'tcx>, _>(|_| Decodable::decode(decoder)),
)
}
}

impl<'tcx, D: TyDecoder<I = TyCtxt<'tcx>>> RefDecodable<'tcx, D> for ty::List<ty::Const<'tcx>> {
fn decode(decoder: &mut D) -> &'tcx Self {
let len = decoder.read_usize();
Expand Down Expand Up @@ -482,6 +491,7 @@ impl_decodable_via_ref! {
&'tcx mir::Body<'tcx>,
&'tcx mir::BorrowCheckResult<'tcx>,
&'tcx ty::List<ty::BoundVariableKind>,
&'tcx ty::List<ty::Pattern<'tcx>>,
&'tcx ty::ListWithCachedTypeInfo<ty::Clause<'tcx>>,
&'tcx ty::List<FieldIdx>,
&'tcx ty::List<(VariantIdx, FieldIdx)>,
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ pub struct CtxtInterners<'tcx> {
captures: InternedSet<'tcx, List<&'tcx ty::CapturedPlace<'tcx>>>,
offset_of: InternedSet<'tcx, List<(VariantIdx, FieldIdx)>>,
valtree: InternedSet<'tcx, ty::ValTreeKind<'tcx>>,
patterns: InternedSet<'tcx, List<ty::Pattern<'tcx>>>,
}

impl<'tcx> CtxtInterners<'tcx> {
Expand Down Expand Up @@ -848,6 +849,7 @@ impl<'tcx> CtxtInterners<'tcx> {
captures: InternedSet::with_capacity(N),
offset_of: InternedSet::with_capacity(N),
valtree: InternedSet::with_capacity(N),
patterns: InternedSet::with_capacity(N),
}
}

Expand Down Expand Up @@ -2594,6 +2596,7 @@ slice_interners!(
local_def_ids: intern_local_def_ids(LocalDefId),
captures: intern_captures(&'tcx ty::CapturedPlace<'tcx>),
offset_of: pub mk_offset_of((VariantIdx, FieldIdx)),
patterns: pub mk_patterns(Pattern<'tcx>),
);

impl<'tcx> TyCtxt<'tcx> {
Expand Down Expand Up @@ -2867,6 +2870,14 @@ impl<'tcx> TyCtxt<'tcx> {
self.intern_local_def_ids(clauses)
}

pub fn mk_patterns_from_iter<I, T>(self, iter: I) -> T::Output
where
I: Iterator<Item = T>,
T: CollectAndApply<ty::Pattern<'tcx>, &'tcx List<ty::Pattern<'tcx>>>,
{
T::collect_and_apply(iter, |xs| self.mk_patterns(xs))
}

pub fn mk_local_def_ids_from_iter<I, T>(self, iter: I) -> T::Output
where
I: Iterator<Item = T>,
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/ty/flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ impl FlagComputation {
self.add_const(start);
self.add_const(end);
}
ty::PatternKind::Or(patterns) => {
for pat in patterns {
self.add_pat(pat);
}
}
ty::PatternKind::NotNull => {}
}
}
Expand Down
14 changes: 14 additions & 0 deletions compiler/rustc_middle/src/ty/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ impl<'tcx> fmt::Debug for PatternKind<'tcx> {

write!(f, "..={end}")
}
PatternKind::Or(patterns) => {
write!(f, "(")?;
let mut first = true;
for pat in patterns {
if first {
first = false
} else {
write!(f, " | ")?;
}
write!(f, "{pat:?}")?;
}
write!(f, ")")
}
PatternKind::NotNull => write!(f, "!null"),
}
}
Expand All @@ -60,5 +73,6 @@ impl<'tcx> fmt::Debug for PatternKind<'tcx> {
#[derive(HashStable, TyEncodable, TyDecodable, TypeVisitable, TypeFoldable)]
pub enum PatternKind<'tcx> {
Range { start: ty::Const<'tcx>, end: ty::Const<'tcx> },
Or(&'tcx ty::List<Pattern<'tcx>>),
NotNull,
}
15 changes: 12 additions & 3 deletions compiler/rustc_middle/src/ty/relate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,19 @@ impl<'tcx> Relate<TyCtxt<'tcx>> for ty::Pattern<'tcx> {
let end = relation.relate(end_a, end_b)?;
Ok(tcx.mk_pat(ty::PatternKind::Range { start, end }))
}
(ty::PatternKind::NotNull, ty::PatternKind::NotNull) => Ok(a),
(ty::PatternKind::NotNull | ty::PatternKind::Range { .. }, _) => {
Err(TypeError::Mismatch)
(&ty::PatternKind::Or(a), &ty::PatternKind::Or(b)) => {
if a.len() != b.len() {
return Err(TypeError::Mismatch);
}
let v = iter::zip(a, b).map(|(a, b)| relation.relate(a, b));
let patterns = tcx.mk_patterns_from_iter(v)?;
Ok(tcx.mk_pat(ty::PatternKind::Or(patterns)))
}
(ty::PatternKind::NotNull, ty::PatternKind::NotNull) => Ok(a),
(
ty::PatternKind::NotNull | ty::PatternKind::Range { .. } | ty::PatternKind::Or(_),
_,
) => Err(TypeError::Mismatch),
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/ty/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,12 @@ impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for &'tcx ty::List<PlaceElem<'tcx>> {
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_place_elems(v))
}
}

impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for &'tcx ty::List<ty::Pattern<'tcx>> {
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
self,
folder: &mut F,
) -> Result<Self, F::Error> {
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_patterns(v))
}
}
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/ty/walk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ fn push_pat<'tcx>(stack: &mut SmallVec<[GenericArg<'tcx>; 8]>, pat: ty::Pattern<
stack.push(end.into());
stack.push(start.into());
}
ty::PatternKind::Or(patterns) => {
for pat in patterns {
push_pat(stack, pat)
}
}
ty::PatternKind::NotNull => {}
}
}
5 changes: 5 additions & 0 deletions compiler/rustc_resolve/src/late.rs
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,11 @@ impl<'ra: 'ast, 'ast, 'tcx> Visitor<'ast> for LateResolutionVisitor<'_, 'ast, 'r
self.resolve_anon_const(end, AnonConstKind::ConstArg(IsRepeatExpr::No));
}
}
TyPatKind::Or(patterns) => {
for pat in patterns {
self.visit_ty_pat(pat)
}
}
TyPatKind::NotNull | TyPatKind::Err(_) => {}
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_smir/src/rustc_smir/convert/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ impl<'tcx> Stable<'tcx> for ty::Pattern<'tcx> {
end: Some(end.stable(tables)),
include_end: true,
},
ty::PatternKind::Or(_) => todo!(),
ty::PatternKind::NotNull => stable_mir::ty::Pattern::NotNull,
}
}
Expand Down
Loading

0 comments on commit d0db6bc

Please sign in to comment.