Skip to content

pattern analysis: fix union handling #123301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,8 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
printed += 1;
}

if printed < variant.fields.len() {
let is_union = self.ty.ty_adt_def().is_some_and(|adt| adt.is_union());
if printed < variant.fields.len() && (!is_union || printed == 0) {
write!(f, "{}..", start_or_comma())?;
}

Expand Down
29 changes: 28 additions & 1 deletion compiler/rustc_pattern_analysis/src/constructor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,34 @@
//! [`ConstructorSet::split`]. The invariants of [`SplitConstructorSet`] are also of interest.
//!
//!
//! ## Unions
//!
//! Unions allow us to match a value via several overlapping representations at the same time. For
//! example, the following is exhaustive because when seeing the value as a boolean we handled all
//! possible cases (other cases such as `n == 3` would trigger UB).
//!
//! ```rust
//! # fn main() {
//! union U8AsBool {
//! n: u8,
//! b: bool,
//! }
//! let x = U8AsBool { n: 1 };
//! unsafe {
//! match x {
//! U8AsBool { n: 2 } => {}
//! U8AsBool { b: true } => {}
//! U8AsBool { b: false } => {}
//! }
//! }
//! # }
//! ```
//!
//! Pattern-matching has no knowledge that e.g. `false as u8 == 0`, so the values we consider in the
//! algorithm look like `U8AsBool { b: true, n: 2 }`. In other words, for the most part a union is
//! treated like a struct with the same fields. The difference lies in how we construct witnesses of
//! non-exhaustiveness.
//!
//!
//! ## Opaque patterns
//!
Expand Down Expand Up @@ -974,7 +1002,6 @@ impl<Cx: PatCx> ConstructorSet<Cx> {
/// any) are missing; 2/ split constructors to handle non-trivial intersections e.g. on ranges
/// or slices. This can get subtle; see [`SplitConstructorSet`] for details of this operation
/// and its invariants.
#[instrument(level = "debug", skip(self, ctors), ret)]
pub fn split<'a>(
&self,
ctors: impl Iterator<Item = &'a Constructor<Cx>> + Clone,
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_pattern_analysis/src/rustc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {

/// Returns the types of the fields for a given constructor. The result must have a length of
/// `ctor.arity()`.
#[instrument(level = "trace", skip(self))]
pub(crate) fn ctor_sub_tys<'a>(
&'a self,
ctor: &'a Constructor<'p, 'tcx>,
Expand Down Expand Up @@ -283,7 +282,6 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
/// Creates a set that represents all the constructors of `ty`.
///
/// See [`crate::constructor`] for considerations of emptiness.
#[instrument(level = "debug", skip(self), ret)]
pub fn ctors_for_ty(
&self,
ty: RevealedTy<'tcx>,
Expand Down
49 changes: 39 additions & 10 deletions compiler/rustc_pattern_analysis/src/usefulness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -871,12 +871,14 @@ impl<Cx: PatCx> PlaceInfo<Cx> {
where
Cx: 'a,
{
debug!(?self.ty);
if self.private_uninhabited {
// Skip the whole column
return Ok((smallvec![Constructor::PrivateUninhabited], vec![]));
}

let ctors_for_ty = cx.ctors_for_ty(&self.ty)?;
debug!(?ctors_for_ty);

// We treat match scrutinees of type `!` or `EmptyEnum` differently.
let is_toplevel_exception =
Expand All @@ -895,6 +897,7 @@ impl<Cx: PatCx> PlaceInfo<Cx> {

// Analyze the constructors present in this column.
let mut split_set = ctors_for_ty.split(ctors);
debug!(?split_set);
let all_missing = split_set.present.is_empty();

// Build the set of constructors we will specialize with. It must cover the whole type, so
Expand Down Expand Up @@ -1254,7 +1257,7 @@ impl<'p, Cx: PatCx> Matrix<'p, Cx> {
/// + true + [Second(true)] +
/// + false + [_] +
/// + _ + [_, _, tail @ ..] +
/// | ✓ | ? | // column validity
/// | ✓ | ? | // validity
/// ```
impl<'p, Cx: PatCx> fmt::Debug for Matrix<'p, Cx> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand Down Expand Up @@ -1285,7 +1288,7 @@ impl<'p, Cx: PatCx> fmt::Debug for Matrix<'p, Cx> {
write!(f, " {sep}")?;
}
if is_validity_row {
write!(f, " // column validity")?;
write!(f, " // validity")?;
}
write!(f, "\n")?;
}
Expand Down Expand Up @@ -1381,12 +1384,35 @@ impl<Cx: PatCx> WitnessStack<Cx> {
/// pats: [(false, "foo"), _, true]
/// result: [Enum::Variant { a: (false, "foo"), b: _ }, true]
/// ```
fn apply_constructor(&mut self, pcx: &PlaceCtxt<'_, Cx>, ctor: &Constructor<Cx>) {
fn apply_constructor(
mut self,
pcx: &PlaceCtxt<'_, Cx>,
ctor: &Constructor<Cx>,
) -> SmallVec<[Self; 1]> {
let len = self.0.len();
let arity = pcx.ctor_arity(ctor);
let fields = self.0.drain((len - arity)..).rev().collect();
let pat = WitnessPat::new(ctor.clone(), fields, pcx.ty.clone());
self.0.push(pat);
let fields: Vec<_> = self.0.drain((len - arity)..).rev().collect();
if matches!(ctor, Constructor::UnionField)
&& fields.iter().filter(|p| !matches!(p.ctor(), Constructor::Wildcard)).count() >= 2
{
// Convert a `Union { a: p, b: q }` witness into `Union { a: p }` and `Union { b: q }`.
// First add `Union { .. }` to `self`.
self.0.push(WitnessPat::wild_from_ctor(pcx.cx, ctor.clone(), pcx.ty.clone()));
fields
.into_iter()
.enumerate()
.filter(|(_, p)| !matches!(p.ctor(), Constructor::Wildcard))
.map(|(i, p)| {
let mut ret = self.clone();
// Fill the `i`th field of the union with `p`.
ret.0.last_mut().unwrap().fields[i] = p;
ret
})
.collect()
} else {
self.0.push(WitnessPat::new(ctor.clone(), fields, pcx.ty.clone()));
smallvec![self]
}
}
}

Expand Down Expand Up @@ -1459,8 +1485,8 @@ impl<Cx: PatCx> WitnessMatrix<Cx> {
*self = ret;
} else {
// Any other constructor we unspecialize as expected.
for witness in self.0.iter_mut() {
witness.apply_constructor(pcx, ctor)
for witness in std::mem::take(&mut self.0) {
self.0.extend(witness.apply_constructor(pcx, ctor));
}
}
}
Expand Down Expand Up @@ -1617,7 +1643,6 @@ fn compute_exhaustiveness_and_usefulness<'a, 'p, Cx: PatCx>(
};

// Analyze the constructors present in this column.
debug!("ty: {:?}", place.ty);
let ctors = matrix.heads().map(|p| p.ctor());
let (split_ctors, missing_ctors) = place.split_column_ctors(mcx.tycx, ctors)?;

Expand Down Expand Up @@ -1669,7 +1694,10 @@ fn compute_exhaustiveness_and_usefulness<'a, 'p, Cx: PatCx>(
for row in matrix.rows() {
if row.useful {
if let PatOrWild::Pat(pat) = row.head() {
mcx.useful_subpatterns.insert(pat.uid);
let newly_useful = mcx.useful_subpatterns.insert(pat.uid);
if newly_useful {
debug!("newly useful: {pat:?}");
}
}
}
}
Expand Down Expand Up @@ -1768,6 +1796,7 @@ pub fn compute_match_usefulness<'p, Cx: PatCx>(
.map(|arm| {
debug!(?arm);
let usefulness = collect_pattern_usefulness(&cx.useful_subpatterns, arm.pat);
debug!(?usefulness);
(arm, usefulness)
})
.collect();
Expand Down
35 changes: 35 additions & 0 deletions tests/ui/pattern/usefulness/unions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
fn main() {
#[derive(Copy, Clone)]
union U8AsBool {
n: u8,
b: bool,
}

let x = U8AsBool { n: 1 };
unsafe {
match x {
// exhaustive
U8AsBool { n: 2 } => {}
U8AsBool { b: true } => {}
U8AsBool { b: false } => {}
}
match x {
// exhaustive
U8AsBool { b: true } => {}
U8AsBool { n: 0 } => {}
U8AsBool { n: 1.. } => {}
}
match x {
//~^ ERROR non-exhaustive patterns: `U8AsBool { n: 0_u8 }` and `U8AsBool { b: false }` not covered
U8AsBool { b: true } => {}
U8AsBool { n: 1.. } => {}
}
// Our approach can report duplicate witnesses sometimes.
match (x, true) {
//~^ ERROR non-exhaustive patterns: `(U8AsBool { n: 0_u8 }, false)`, `(U8AsBool { b: false }, false)`, `(U8AsBool { n: 0_u8 }, false)` and 1 more not covered
(U8AsBool { b: true }, true) => {}
(U8AsBool { b: false }, true) => {}
(U8AsBool { n: 1.. }, true) => {}
}
}
}
34 changes: 34 additions & 0 deletions tests/ui/pattern/usefulness/unions.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
error[E0004]: non-exhaustive patterns: `U8AsBool { n: 0_u8 }` and `U8AsBool { b: false }` not covered
--> $DIR/unions.rs:22:15
|
LL | match x {
| ^ patterns `U8AsBool { n: 0_u8 }` and `U8AsBool { b: false }` not covered
|
note: `U8AsBool` defined here
--> $DIR/unions.rs:3:11
|
LL | union U8AsBool {
| ^^^^^^^^
= note: the matched value is of type `U8AsBool`
help: ensure that all possible cases are being handled by adding a match arm with a wildcard pattern, a match arm with multiple or-patterns as shown, or multiple match arms
|
LL ~ U8AsBool { n: 1.. } => {},
LL + U8AsBool { n: 0_u8 } | U8AsBool { b: false } => todo!()
|

error[E0004]: non-exhaustive patterns: `(U8AsBool { n: 0_u8 }, false)`, `(U8AsBool { b: false }, false)`, `(U8AsBool { n: 0_u8 }, false)` and 1 more not covered
--> $DIR/unions.rs:28:15
|
LL | match (x, true) {
| ^^^^^^^^^ patterns `(U8AsBool { n: 0_u8 }, false)`, `(U8AsBool { b: false }, false)`, `(U8AsBool { n: 0_u8 }, false)` and 1 more not covered
|
= note: the matched value is of type `(U8AsBool, bool)`
help: ensure that all possible cases are being handled by adding a match arm with a wildcard pattern as shown, or multiple match arms
|
LL ~ (U8AsBool { n: 1.. }, true) => {},
LL + _ => todo!()
|

error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0004`.