Skip to content

Commit 2393e44

Browse files
committed
miri: algebraic intrinsics: bring back float non-determinism
1 parent 1b8ab72 commit 2393e44

File tree

6 files changed

+107
-86
lines changed

6 files changed

+107
-86
lines changed

compiler/rustc_const_eval/src/interpret/intrinsics.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
178178

179179
let res = self.binary_op(op, &a, &b)?;
180180
// `binary_op` already called `generate_nan` if needed.
181-
182-
// FIXME: Miri should add some non-determinism to the result here to catch any dependences on exact computations. This has previously been done, but the behaviour was removed as part of constification.
181+
let res = M::apply_float_nondet(self, res)?;
183182
self.write_immediate(*res, dest)?;
184183
}
185184

compiler/rustc_const_eval/src/interpret/machine.rs

+8
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,14 @@ pub trait Machine<'tcx>: Sized {
276276
F2::NAN
277277
}
278278

279+
/// Apply non-determinism to float operations that do not return a precise result.
280+
fn apply_float_nondet(
281+
_ecx: &mut InterpCx<'tcx, Self>,
282+
val: ImmTy<'tcx, Self::Provenance>,
283+
) -> InterpResult<'tcx, ImmTy<'tcx, Self::Provenance>> {
284+
interp_ok(val)
285+
}
286+
279287
/// Determines the result of `min`/`max` on floats when the arguments are equal.
280288
fn equal_float_min_max<F: Float>(_ecx: &InterpCx<'tcx, Self>, a: F, _b: F) -> F {
281289
// By default, we pick the left argument.

src/tools/miri/src/intrinsics/mod.rs

+2-25
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ use rand::Rng;
77
use rustc_abi::Size;
88
use rustc_apfloat::{Float, Round};
99
use rustc_middle::mir;
10-
use rustc_middle::ty::{self, FloatTy, ScalarInt};
10+
use rustc_middle::ty::{self, FloatTy};
1111
use rustc_span::{Symbol, sym};
1212

1313
use self::atomic::EvalContextExt as _;
1414
use self::helpers::{ToHost, ToSoft, check_intrinsic_arg_count};
1515
use self::simd::EvalContextExt as _;
16-
use crate::math::apply_random_float_error_ulp;
16+
use crate::math::apply_random_float_error_to_imm;
1717
use crate::*;
1818

1919
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -473,26 +473,3 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
473473
interp_ok(EmulateItemResult::NeedsReturn)
474474
}
475475
}
476-
477-
/// Applies a random 16ULP floating point error to `val` and returns the new value.
478-
/// Will fail if `val` is not a floating point number.
479-
fn apply_random_float_error_to_imm<'tcx>(
480-
ecx: &mut MiriInterpCx<'tcx>,
481-
val: ImmTy<'tcx>,
482-
ulp_exponent: u32,
483-
) -> InterpResult<'tcx, ImmTy<'tcx>> {
484-
let scalar = val.to_scalar_int()?;
485-
let res: ScalarInt = match val.layout.ty.kind() {
486-
ty::Float(FloatTy::F16) =>
487-
apply_random_float_error_ulp(ecx, scalar.to_f16(), ulp_exponent).into(),
488-
ty::Float(FloatTy::F32) =>
489-
apply_random_float_error_ulp(ecx, scalar.to_f32(), ulp_exponent).into(),
490-
ty::Float(FloatTy::F64) =>
491-
apply_random_float_error_ulp(ecx, scalar.to_f64(), ulp_exponent).into(),
492-
ty::Float(FloatTy::F128) =>
493-
apply_random_float_error_ulp(ecx, scalar.to_f128(), ulp_exponent).into(),
494-
_ => bug!("intrinsic called with non-float input type"),
495-
};
496-
497-
interp_ok(ImmTy::from_scalar_int(res, val.layout))
498-
}

src/tools/miri/src/machine.rs

+10
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,16 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
11991199
ecx.generate_nan(inputs)
12001200
}
12011201

1202+
#[inline(always)]
1203+
fn apply_float_nondet(
1204+
ecx: &mut InterpCx<'tcx, Self>,
1205+
val: ImmTy<'tcx>,
1206+
) -> InterpResult<'tcx, ImmTy<'tcx>> {
1207+
crate::math::apply_random_float_error_to_imm(
1208+
ecx, val, 2 /* log2(4) */
1209+
)
1210+
}
1211+
12021212
#[inline(always)]
12031213
fn equal_float_min_max<F: Float>(ecx: &MiriInterpCx<'tcx>, a: F, b: F) -> F {
12041214
ecx.equal_float_min_max(a, b)

src/tools/miri/src/math.rs

+26
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use rand::Rng as _;
22
use rustc_apfloat::Float as _;
33
use rustc_apfloat::ieee::IeeeFloat;
4+
use rustc_middle::ty::{self, FloatTy, ScalarInt};
5+
6+
use crate::*;
47

58
/// Disturbes a floating-point result by a relative error in the range (-2^scale, 2^scale).
69
///
@@ -43,6 +46,29 @@ pub(crate) fn apply_random_float_error_ulp<F: rustc_apfloat::Float>(
4346
apply_random_float_error(ecx, val, err_scale)
4447
}
4548

49+
/// Applies a random 16ULP floating point error to `val` and returns the new value.
50+
/// Will fail if `val` is not a floating point number.
51+
pub(crate) fn apply_random_float_error_to_imm<'tcx>(
52+
ecx: &mut MiriInterpCx<'tcx>,
53+
val: ImmTy<'tcx>,
54+
ulp_exponent: u32,
55+
) -> InterpResult<'tcx, ImmTy<'tcx>> {
56+
let scalar = val.to_scalar_int()?;
57+
let res: ScalarInt = match val.layout.ty.kind() {
58+
ty::Float(FloatTy::F16) =>
59+
apply_random_float_error_ulp(ecx, scalar.to_f16(), ulp_exponent).into(),
60+
ty::Float(FloatTy::F32) =>
61+
apply_random_float_error_ulp(ecx, scalar.to_f32(), ulp_exponent).into(),
62+
ty::Float(FloatTy::F64) =>
63+
apply_random_float_error_ulp(ecx, scalar.to_f64(), ulp_exponent).into(),
64+
ty::Float(FloatTy::F128) =>
65+
apply_random_float_error_ulp(ecx, scalar.to_f128(), ulp_exponent).into(),
66+
_ => bug!("intrinsic called with non-float input type"),
67+
};
68+
69+
interp_ok(ImmTy::from_scalar_int(res, val.layout))
70+
}
71+
4672
pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
4773
match x.category() {
4874
// preserve zero sign

src/tools/miri/tests/pass/float.rs

+60-59
Original file line numberDiff line numberDiff line change
@@ -1292,8 +1292,7 @@ fn test_non_determinism() {
12921292
}
12931293
}
12941294
// We saw the same thing N times.
1295-
// FIXME: temporarily disabled as it breaks std tests.
1296-
//panic!("expected non-determinism, got {rounds} times the same result: {first:?}");
1295+
panic!("expected non-determinism, got {rounds} times the same result: {first:?}");
12971296
}
12981297

12991298
macro_rules! test_operations_f {
@@ -1319,66 +1318,68 @@ fn test_non_determinism() {
13191318
}
13201319
pub fn test_operations_f32(a: f32, b: f32) {
13211320
test_operations_f!(a, b);
1322-
ensure_nondet(|| a.log(b));
1323-
ensure_nondet(|| a.exp());
1324-
ensure_nondet(|| 10f32.exp2());
1325-
ensure_nondet(|| f32::consts::E.ln());
1326-
ensure_nondet(|| 1f32.ln_1p());
1327-
ensure_nondet(|| 10f32.log10());
1328-
ensure_nondet(|| 8f32.log2());
1329-
ensure_nondet(|| 27.0f32.cbrt());
1330-
ensure_nondet(|| 3.0f32.hypot(4.0f32));
1331-
ensure_nondet(|| 1f32.sin());
1332-
ensure_nondet(|| 0f32.cos());
1333-
// On i686-pc-windows-msvc , these functions are implemented by calling the `f64` version,
1334-
// which means the little rounding errors Miri introduces are discard by the cast down to `f32`.
1335-
// Just skip the test for them.
1336-
if !cfg!(all(target_os = "windows", target_env = "msvc", target_arch = "x86")) {
1337-
ensure_nondet(|| 1.0f32.tan());
1338-
ensure_nondet(|| 1.0f32.asin());
1339-
ensure_nondet(|| 5.0f32.acos());
1340-
ensure_nondet(|| 1.0f32.atan());
1341-
ensure_nondet(|| 1.0f32.atan2(2.0f32));
1342-
ensure_nondet(|| 1.0f32.sinh());
1343-
ensure_nondet(|| 1.0f32.cosh());
1344-
ensure_nondet(|| 1.0f32.tanh());
1345-
}
1346-
ensure_nondet(|| 1.0f32.asinh());
1347-
ensure_nondet(|| 2.0f32.acosh());
1348-
ensure_nondet(|| 0.5f32.atanh());
1349-
ensure_nondet(|| 5.0f32.gamma());
1350-
ensure_nondet(|| 5.0f32.ln_gamma());
1351-
ensure_nondet(|| 5.0f32.erf());
1352-
ensure_nondet(|| 5.0f32.erfc());
1321+
// FIXME: temporarily disabled as it breaks std tests.
1322+
// ensure_nondet(|| a.log(b));
1323+
// ensure_nondet(|| a.exp());
1324+
// ensure_nondet(|| 10f32.exp2());
1325+
// ensure_nondet(|| f32::consts::E.ln());
1326+
// ensure_nondet(|| 1f32.ln_1p());
1327+
// ensure_nondet(|| 10f32.log10());
1328+
// ensure_nondet(|| 8f32.log2());
1329+
// ensure_nondet(|| 27.0f32.cbrt());
1330+
// ensure_nondet(|| 3.0f32.hypot(4.0f32));
1331+
// ensure_nondet(|| 1f32.sin());
1332+
// ensure_nondet(|| 0f32.cos());
1333+
// // On i686-pc-windows-msvc , these functions are implemented by calling the `f64` version,
1334+
// // which means the little rounding errors Miri introduces are discard by the cast down to `f32`.
1335+
// // Just skip the test for them.
1336+
// if !cfg!(all(target_os = "windows", target_env = "msvc", target_arch = "x86")) {
1337+
// ensure_nondet(|| 1.0f32.tan());
1338+
// ensure_nondet(|| 1.0f32.asin());
1339+
// ensure_nondet(|| 5.0f32.acos());
1340+
// ensure_nondet(|| 1.0f32.atan());
1341+
// ensure_nondet(|| 1.0f32.atan2(2.0f32));
1342+
// ensure_nondet(|| 1.0f32.sinh());
1343+
// ensure_nondet(|| 1.0f32.cosh());
1344+
// ensure_nondet(|| 1.0f32.tanh());
1345+
// }
1346+
// ensure_nondet(|| 1.0f32.asinh());
1347+
// ensure_nondet(|| 2.0f32.acosh());
1348+
// ensure_nondet(|| 0.5f32.atanh());
1349+
// ensure_nondet(|| 5.0f32.gamma());
1350+
// ensure_nondet(|| 5.0f32.ln_gamma());
1351+
// ensure_nondet(|| 5.0f32.erf());
1352+
// ensure_nondet(|| 5.0f32.erfc());
13531353
}
13541354
pub fn test_operations_f64(a: f64, b: f64) {
13551355
test_operations_f!(a, b);
1356-
ensure_nondet(|| a.log(b));
1357-
ensure_nondet(|| a.exp());
1358-
ensure_nondet(|| 50f64.exp2());
1359-
ensure_nondet(|| 3f64.ln());
1360-
ensure_nondet(|| 1f64.ln_1p());
1361-
ensure_nondet(|| f64::consts::E.log10());
1362-
ensure_nondet(|| f64::consts::E.log2());
1363-
ensure_nondet(|| 27.0f64.cbrt());
1364-
ensure_nondet(|| 3.0f64.hypot(4.0f64));
1365-
ensure_nondet(|| 1f64.sin());
1366-
ensure_nondet(|| 0f64.cos());
1367-
ensure_nondet(|| 1.0f64.tan());
1368-
ensure_nondet(|| 1.0f64.asin());
1369-
ensure_nondet(|| 5.0f64.acos());
1370-
ensure_nondet(|| 1.0f64.atan());
1371-
ensure_nondet(|| 1.0f64.atan2(2.0f64));
1372-
ensure_nondet(|| 1.0f64.sinh());
1373-
ensure_nondet(|| 1.0f64.cosh());
1374-
ensure_nondet(|| 1.0f64.tanh());
1375-
ensure_nondet(|| 1.0f64.asinh());
1376-
ensure_nondet(|| 3.0f64.acosh());
1377-
ensure_nondet(|| 0.5f64.atanh());
1378-
ensure_nondet(|| 5.0f64.gamma());
1379-
ensure_nondet(|| 5.0f64.ln_gamma());
1380-
ensure_nondet(|| 5.0f64.erf());
1381-
ensure_nondet(|| 5.0f64.erfc());
1356+
// FIXME: temporarily disabled as it breaks std tests.
1357+
// ensure_nondet(|| a.log(b));
1358+
// ensure_nondet(|| a.exp());
1359+
// ensure_nondet(|| 50f64.exp2());
1360+
// ensure_nondet(|| 3f64.ln());
1361+
// ensure_nondet(|| 1f64.ln_1p());
1362+
// ensure_nondet(|| f64::consts::E.log10());
1363+
// ensure_nondet(|| f64::consts::E.log2());
1364+
// ensure_nondet(|| 27.0f64.cbrt());
1365+
// ensure_nondet(|| 3.0f64.hypot(4.0f64));
1366+
// ensure_nondet(|| 1f64.sin());
1367+
// ensure_nondet(|| 0f64.cos());
1368+
// ensure_nondet(|| 1.0f64.tan());
1369+
// ensure_nondet(|| 1.0f64.asin());
1370+
// ensure_nondet(|| 5.0f64.acos());
1371+
// ensure_nondet(|| 1.0f64.atan());
1372+
// ensure_nondet(|| 1.0f64.atan2(2.0f64));
1373+
// ensure_nondet(|| 1.0f64.sinh());
1374+
// ensure_nondet(|| 1.0f64.cosh());
1375+
// ensure_nondet(|| 1.0f64.tanh());
1376+
// ensure_nondet(|| 1.0f64.asinh());
1377+
// ensure_nondet(|| 3.0f64.acosh());
1378+
// ensure_nondet(|| 0.5f64.atanh());
1379+
// ensure_nondet(|| 5.0f64.gamma());
1380+
// ensure_nondet(|| 5.0f64.ln_gamma());
1381+
// ensure_nondet(|| 5.0f64.erf());
1382+
// ensure_nondet(|| 5.0f64.erfc());
13821383
}
13831384
pub fn test_operations_f128(a: f128, b: f128) {
13841385
test_operations_f!(a, b);

0 commit comments

Comments
 (0)