Skip to content

Compile unicode-normalization faster #97936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
21 changes: 12 additions & 9 deletions compiler/rustc_mir_build/src/build/matches/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,18 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
_ => (None, 0),
};
if let Some((min, max, sz)) = range {
if let (Some(lo), Some(hi)) = (lo.try_to_bits(sz), hi.try_to_bits(sz)) {
// We want to compare ranges numerically, but the order of the bitwise
// representation of signed integers does not match their numeric order.
// Thus, to correct the ordering, we need to shift the range of signed
// integers to correct the comparison. This is achieved by XORing with a
// bias (see pattern/_match.rs for another pertinent example of this
// pattern).
let (lo, hi) = (lo ^ bias, hi ^ bias);
if lo <= min && (hi > max || hi == max && end == RangeEnd::Included) {
// We want to compare ranges numerically, but the order of the bitwise
// representation of signed integers does not match their numeric order. Thus,
// to correct the ordering, we need to shift the range of signed integers to
// correct the comparison. This is achieved by XORing with a bias (see
// pattern/_match.rs for another pertinent example of this pattern).
//
// Also, for performance, it's important to only do the second `try_to_bits` if
// necessary.
let lo = lo.try_to_bits(sz).unwrap() ^ bias;
if lo <= min {
let hi = hi.try_to_bits(sz).unwrap() ^ bias;
if hi > max || hi == max && end == RangeEnd::Included {
// Irrefutable pattern match.
return Ok(());
}
Expand Down
59 changes: 25 additions & 34 deletions compiler/rustc_mir_build/src/build/matches/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,39 +632,30 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

(&TestKind::Range(test), &PatKind::Range(pat)) => {
use std::cmp::Ordering::*;

if test == pat {
self.candidate_without_match_pair(match_pair_index, candidate);
return Some(0);
}

let no_overlap = (|| {
use rustc_hir::RangeEnd::*;
use std::cmp::Ordering::*;

let tcx = self.tcx;

let test_ty = test.lo.ty();
let lo = compare_const_vals(tcx, test.lo, pat.hi, self.param_env, test_ty)?;
let hi = compare_const_vals(tcx, test.hi, pat.lo, self.param_env, test_ty)?;

match (test.end, pat.end, lo, hi) {
// pat < test
(_, _, Greater, _) |
(_, Excluded, Equal, _) |
// pat > test
(_, _, _, Less) |
(Excluded, _, _, Equal) => Some(true),
_ => Some(false),
}
})();

if let Some(true) = no_overlap {
// Testing range does not overlap with pattern range,
// so the pattern can be matched only if this test fails.
// For performance, it's important to only do the second
// `compare_const_vals` if necessary.
let no_overlap = if matches!(
(compare_const_vals(self.tcx, test.hi, pat.lo, self.param_env)?, test.end),
(Less, _) | (Equal, RangeEnd::Excluded) // test < pat
) || matches!(
(compare_const_vals(self.tcx, test.lo, pat.hi, self.param_env)?, pat.end),
(Greater, _) | (Equal, RangeEnd::Excluded) // test > pat
) {
Some(1)
} else {
None
}
};

// If the testing range does not overlap with pattern range,
// the pattern can be matched only if this test fails.
no_overlap
}

(&TestKind::Range(range), &PatKind::Constant { value }) => {
Expand Down Expand Up @@ -768,15 +759,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
) -> Option<bool> {
use std::cmp::Ordering::*;

let tcx = self.tcx;

let a = compare_const_vals(tcx, range.lo, value, self.param_env, range.lo.ty())?;
let b = compare_const_vals(tcx, value, range.hi, self.param_env, range.lo.ty())?;

match (b, range.end) {
(Less, _) | (Equal, RangeEnd::Included) if a != Greater => Some(true),
_ => Some(false),
}
// For performance, it's important to only do the second
// `compare_const_vals` if necessary.
Some(
matches!(compare_const_vals(self.tcx, range.lo, value, self.param_env)?, Less | Equal)
&& matches!(
(compare_const_vals(self.tcx, value, range.hi, self.param_env)?, range.end),
(Less, _) | (Equal, RangeEnd::Included)
),
)
}

fn values_not_contained_in_range(
Expand Down
21 changes: 3 additions & 18 deletions compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -828,14 +828,8 @@ impl<'tcx> Constructor<'tcx> {
FloatRange(other_from, other_to, other_end),
) => {
match (
compare_const_vals(pcx.cx.tcx, *self_to, *other_to, pcx.cx.param_env, pcx.ty),
compare_const_vals(
pcx.cx.tcx,
*self_from,
*other_from,
pcx.cx.param_env,
pcx.ty,
),
compare_const_vals(pcx.cx.tcx, *self_to, *other_to, pcx.cx.param_env),
compare_const_vals(pcx.cx.tcx, *self_from, *other_from, pcx.cx.param_env),
) {
(Some(to), Some(from)) => {
(from == Ordering::Greater || from == Ordering::Equal)
Expand All @@ -848,16 +842,7 @@ impl<'tcx> Constructor<'tcx> {
(Str(self_val), Str(other_val)) => {
// FIXME Once valtrees are available we can directly use the bytes
// in the `Str` variant of the valtree for the comparison here.
match compare_const_vals(
pcx.cx.tcx,
*self_val,
*other_val,
pcx.cx.param_env,
pcx.ty,
) {
Some(comparison) => comparison == Ordering::Equal,
None => false,
}
self_val == other_val
}
(Slice(self_slice), Slice(other_slice)) => self_slice.is_covered_by(*other_slice),

Expand Down
101 changes: 47 additions & 54 deletions compiler/rustc_mir_build/src/thir/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ use rustc_hir::def::{CtorOf, DefKind, Res};
use rustc_hir::pat_util::EnumerateAndAdjustIterator;
use rustc_hir::RangeEnd;
use rustc_index::vec::Idx;
use rustc_middle::mir::interpret::{get_slice_bytes, ConstValue};
use rustc_middle::mir::interpret::{ErrorHandled, LitToConstError, LitToConstInput};
use rustc_middle::mir::interpret::{
ConstValue, ErrorHandled, LitToConstError, LitToConstInput, Scalar,
};
use rustc_middle::mir::{self, UserTypeProjection};
use rustc_middle::mir::{BorrowKind, Field, Mutability};
use rustc_middle::thir::{Ascription, BindingMode, FieldPat, LocalVarId, Pat, PatKind, PatRange};
Expand Down Expand Up @@ -129,7 +130,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
) -> PatKind<'tcx> {
assert_eq!(lo.ty(), ty);
assert_eq!(hi.ty(), ty);
let cmp = compare_const_vals(self.tcx, lo, hi, self.param_env, ty);
let cmp = compare_const_vals(self.tcx, lo, hi, self.param_env);
match (end, cmp) {
// `x..y` where `x < y`.
// Non-empty because the range includes at least `x`.
Expand Down Expand Up @@ -753,57 +754,49 @@ pub(crate) fn compare_const_vals<'tcx>(
a: mir::ConstantKind<'tcx>,
b: mir::ConstantKind<'tcx>,
param_env: ty::ParamEnv<'tcx>,
ty: Ty<'tcx>,
) -> Option<Ordering> {
let from_bool = |v: bool| v.then_some(Ordering::Equal);

let fallback = || from_bool(a == b);

// Use the fallback if any type differs
if a.ty() != b.ty() || a.ty() != ty {
return fallback();
}

if a == b {
return from_bool(true);
}

let a_bits = a.try_eval_bits(tcx, param_env, ty);
let b_bits = b.try_eval_bits(tcx, param_env, ty);

if let (Some(a), Some(b)) = (a_bits, b_bits) {
use rustc_apfloat::Float;
return match *ty.kind() {
ty::Float(ty::FloatTy::F32) => {
let l = rustc_apfloat::ieee::Single::from_bits(a);
let r = rustc_apfloat::ieee::Single::from_bits(b);
l.partial_cmp(&r)
}
ty::Float(ty::FloatTy::F64) => {
let l = rustc_apfloat::ieee::Double::from_bits(a);
let r = rustc_apfloat::ieee::Double::from_bits(b);
l.partial_cmp(&r)
}
ty::Int(ity) => {
use rustc_middle::ty::layout::IntegerExt;
let size = rustc_target::abi::Integer::from_int_ty(&tcx, ity).size();
let a = size.sign_extend(a);
let b = size.sign_extend(b);
Some((a as i128).cmp(&(b as i128)))
}
_ => Some(a.cmp(&b)),
};
}

if let ty::Str = ty.kind() && let (
Some(a_val @ ConstValue::Slice { .. }),
Some(b_val @ ConstValue::Slice { .. }),
) = (a.try_to_value(tcx), b.try_to_value(tcx))
{
let a_bytes = get_slice_bytes(&tcx, a_val);
let b_bytes = get_slice_bytes(&tcx, b_val);
return from_bool(a_bytes == b_bytes);
assert_eq!(a.ty(), b.ty());

let ty = a.ty();

// This code is hot when compiling matches with many ranges. So we
// special-case extraction of evaluated scalars for speed, for types where
// raw data comparisons are appropriate. E.g. `unicode-normalization` has
// many ranges such as '\u{037A}'..='\u{037F}', and chars can be compared
// in this way.
match ty.kind() {
ty::Float(_) | ty::Int(_) => {} // require special handling, see below
_ => match (a, b) {
(
mir::ConstantKind::Val(ConstValue::Scalar(Scalar::Int(a)), _a_ty),
mir::ConstantKind::Val(ConstValue::Scalar(Scalar::Int(b)), _b_ty),
) => return Some(a.cmp(&b)),
_ => {}
},
}

let a = a.eval_bits(tcx, param_env, ty);
let b = b.eval_bits(tcx, param_env, ty);

use rustc_apfloat::Float;
match *ty.kind() {
ty::Float(ty::FloatTy::F32) => {
let a = rustc_apfloat::ieee::Single::from_bits(a);
let b = rustc_apfloat::ieee::Single::from_bits(b);
a.partial_cmp(&b)
}
ty::Float(ty::FloatTy::F64) => {
let a = rustc_apfloat::ieee::Double::from_bits(a);
let b = rustc_apfloat::ieee::Double::from_bits(b);
a.partial_cmp(&b)
}
ty::Int(ity) => {
use rustc_middle::ty::layout::IntegerExt;
let size = rustc_target::abi::Integer::from_int_ty(&tcx, ity).size();
let a = size.sign_extend(a);
let b = size.sign_extend(b);
Some((a as i128).cmp(&(b as i128)))
}
_ => Some(a.cmp(&b)),
}

fallback()
}