Skip to content

Commit 64b086a

Browse files
committed
Auto merge of #2004 - RalfJung:simd, r=RalfJung
implement more SIMD intrinsics Requires rust-lang/rust#94681 With this, the cast, i32_ops, and f32_ops test suites of portable-simd pass. :) Cc #1912
2 parents dd42a47 + b87a9c9 commit 64b086a

File tree

4 files changed

+346
-78
lines changed

4 files changed

+346
-78
lines changed

rust-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
8876ca3dd46b99fe7e6ad937f11493d37996231e
1+
297273c45b205820a4c055082c71677197a40b55

src/shims/intrinsics.rs

Lines changed: 171 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
345345
bug!("simd_fabs operand is not a float")
346346
};
347347
let op = op.to_scalar()?;
348-
// FIXME: Using host floats.
349348
match float_ty {
350349
FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
351350
FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
@@ -371,7 +370,9 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
371370
| "simd_lt"
372371
| "simd_le"
373372
| "simd_gt"
374-
| "simd_ge" => {
373+
| "simd_ge"
374+
| "simd_fmax"
375+
| "simd_fmin" => {
375376
use mir::BinOp;
376377

377378
let &[ref left, ref right] = check_arg_count(args)?;
@@ -382,58 +383,77 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
382383
assert_eq!(dest_len, left_len);
383384
assert_eq!(dest_len, right_len);
384385

385-
let mir_op = match intrinsic_name {
386-
"simd_add" => BinOp::Add,
387-
"simd_sub" => BinOp::Sub,
388-
"simd_mul" => BinOp::Mul,
389-
"simd_div" => BinOp::Div,
390-
"simd_rem" => BinOp::Rem,
391-
"simd_shl" => BinOp::Shl,
392-
"simd_shr" => BinOp::Shr,
393-
"simd_and" => BinOp::BitAnd,
394-
"simd_or" => BinOp::BitOr,
395-
"simd_xor" => BinOp::BitXor,
396-
"simd_eq" => BinOp::Eq,
397-
"simd_ne" => BinOp::Ne,
398-
"simd_lt" => BinOp::Lt,
399-
"simd_le" => BinOp::Le,
400-
"simd_gt" => BinOp::Gt,
401-
"simd_ge" => BinOp::Ge,
386+
enum Op {
387+
MirOp(BinOp),
388+
FMax,
389+
FMin,
390+
}
391+
let which = match intrinsic_name {
392+
"simd_add" => Op::MirOp(BinOp::Add),
393+
"simd_sub" => Op::MirOp(BinOp::Sub),
394+
"simd_mul" => Op::MirOp(BinOp::Mul),
395+
"simd_div" => Op::MirOp(BinOp::Div),
396+
"simd_rem" => Op::MirOp(BinOp::Rem),
397+
"simd_shl" => Op::MirOp(BinOp::Shl),
398+
"simd_shr" => Op::MirOp(BinOp::Shr),
399+
"simd_and" => Op::MirOp(BinOp::BitAnd),
400+
"simd_or" => Op::MirOp(BinOp::BitOr),
401+
"simd_xor" => Op::MirOp(BinOp::BitXor),
402+
"simd_eq" => Op::MirOp(BinOp::Eq),
403+
"simd_ne" => Op::MirOp(BinOp::Ne),
404+
"simd_lt" => Op::MirOp(BinOp::Lt),
405+
"simd_le" => Op::MirOp(BinOp::Le),
406+
"simd_gt" => Op::MirOp(BinOp::Gt),
407+
"simd_ge" => Op::MirOp(BinOp::Ge),
408+
"simd_fmax" => Op::FMax,
409+
"simd_fmin" => Op::FMin,
402410
_ => unreachable!(),
403411
};
404412

405413
for i in 0..dest_len {
406414
let left = this.read_immediate(&this.mplace_index(&left, i)?.into())?;
407415
let right = this.read_immediate(&this.mplace_index(&right, i)?.into())?;
408416
let dest = this.mplace_index(&dest, i)?;
409-
let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?;
410-
if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
411-
// Shifts have extra UB as SIMD operations that the MIR binop does not have.
412-
// See <https://github.com/rust-lang/rust/issues/91237>.
413-
if overflowed {
414-
let r_val = right.to_scalar()?.to_bits(right.layout.size)?;
415-
throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i);
417+
let val = match which {
418+
Op::MirOp(mir_op) => {
419+
let (val, overflowed, ty) = this.overflowing_binary_op(mir_op, &left, &right)?;
420+
if matches!(mir_op, BinOp::Shl | BinOp::Shr) {
421+
// Shifts have extra UB as SIMD operations that the MIR binop does not have.
422+
// See <https://github.com/rust-lang/rust/issues/91237>.
423+
if overflowed {
424+
let r_val = right.to_scalar()?.to_bits(right.layout.size)?;
425+
throw_ub_format!("overflowing shift by {} in `{}` in SIMD lane {}", r_val, intrinsic_name, i);
426+
}
427+
}
428+
if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
429+
// Special handling for boolean-returning operations
430+
assert_eq!(ty, this.tcx.types.bool);
431+
let val = val.to_bool().unwrap();
432+
bool_to_simd_element(val, dest.layout.size)
433+
} else {
434+
assert_ne!(ty, this.tcx.types.bool);
435+
assert_eq!(ty, dest.layout.ty);
436+
val
437+
}
438+
}
439+
Op::FMax => {
440+
fmax_op(&left, &right)?
416441
}
417-
}
418-
if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
419-
// Special handling for boolean-returning operations
420-
assert_eq!(ty, this.tcx.types.bool);
421-
let val = val.to_bool().unwrap();
422-
let val = bool_to_simd_element(val, dest.layout.size);
423-
this.write_scalar(val, &dest.into())?;
424-
} else {
425-
assert_ne!(ty, this.tcx.types.bool);
426-
assert_eq!(ty, dest.layout.ty);
427-
this.write_scalar(val, &dest.into())?;
428-
}
442+
Op::FMin => {
443+
fmin_op(&left, &right)?
444+
}
445+
};
446+
this.write_scalar(val, &dest.into())?;
429447
}
430448
}
431449
#[rustfmt::skip]
432450
| "simd_reduce_and"
433451
| "simd_reduce_or"
434452
| "simd_reduce_xor"
435453
| "simd_reduce_any"
436-
| "simd_reduce_all" => {
454+
| "simd_reduce_all"
455+
| "simd_reduce_max"
456+
| "simd_reduce_min" => {
437457
use mir::BinOp;
438458

439459
let &[ref op] = check_arg_count(args)?;
@@ -445,19 +465,27 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
445465
enum Op {
446466
MirOp(BinOp),
447467
MirOpBool(BinOp),
468+
Max,
469+
Min,
448470
}
449-
// The initial value is the neutral element.
450-
let (which, init) = match intrinsic_name {
451-
"simd_reduce_and" => (Op::MirOp(BinOp::BitAnd), ImmTy::from_int(-1, dest.layout)),
452-
"simd_reduce_or" => (Op::MirOp(BinOp::BitOr), ImmTy::from_int(0, dest.layout)),
453-
"simd_reduce_xor" => (Op::MirOp(BinOp::BitXor), ImmTy::from_int(0, dest.layout)),
454-
"simd_reduce_any" => (Op::MirOpBool(BinOp::BitOr), imm_from_bool(false)),
455-
"simd_reduce_all" => (Op::MirOpBool(BinOp::BitAnd), imm_from_bool(true)),
471+
let which = match intrinsic_name {
472+
"simd_reduce_and" => Op::MirOp(BinOp::BitAnd),
473+
"simd_reduce_or" => Op::MirOp(BinOp::BitOr),
474+
"simd_reduce_xor" => Op::MirOp(BinOp::BitXor),
475+
"simd_reduce_any" => Op::MirOpBool(BinOp::BitOr),
476+
"simd_reduce_all" => Op::MirOpBool(BinOp::BitAnd),
477+
"simd_reduce_max" => Op::Max,
478+
"simd_reduce_min" => Op::Min,
456479
_ => unreachable!(),
457480
};
458481

459-
let mut res = init;
460-
for i in 0..op_len {
482+
// Initialize with first lane, then proceed with the rest.
483+
let mut res = this.read_immediate(&this.mplace_index(&op, 0)?.into())?;
484+
if matches!(which, Op::MirOpBool(_)) {
485+
// Convert to `bool` scalar.
486+
res = imm_from_bool(simd_element_to_bool(res)?);
487+
}
488+
for i in 1..op_len {
461489
let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
462490
res = match which {
463491
Op::MirOp(mir_op) => {
@@ -467,6 +495,30 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
467495
let op = imm_from_bool(simd_element_to_bool(op)?);
468496
this.binary_op(mir_op, &res, &op)?
469497
}
498+
Op::Max => {
499+
if matches!(res.layout.ty.kind(), ty::Float(_)) {
500+
ImmTy::from_scalar(fmax_op(&res, &op)?, res.layout)
501+
} else {
502+
// Just boring integers, so NaNs to worry about
503+
if this.binary_op(BinOp::Ge, &res, &op)?.to_scalar()?.to_bool()? {
504+
res
505+
} else {
506+
op
507+
}
508+
}
509+
}
510+
Op::Min => {
511+
if matches!(res.layout.ty.kind(), ty::Float(_)) {
512+
ImmTy::from_scalar(fmin_op(&res, &op)?, res.layout)
513+
} else {
514+
// Just boring integers, so NaNs to worry about
515+
if this.binary_op(BinOp::Le, &res, &op)?.to_scalar()?.to_bool()? {
516+
res
517+
} else {
518+
op
519+
}
520+
}
521+
}
470522
};
471523
}
472524
this.write_immediate(*res, dest)?;
@@ -515,6 +567,45 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
515567
this.write_immediate(*val, &dest.into())?;
516568
}
517569
}
570+
#[rustfmt::skip]
571+
"simd_cast" | "simd_as" => {
572+
let &[ref op] = check_arg_count(args)?;
573+
let (op, op_len) = this.operand_to_simd(op)?;
574+
let (dest, dest_len) = this.place_to_simd(dest)?;
575+
576+
assert_eq!(dest_len, op_len);
577+
578+
let safe_cast = intrinsic_name == "simd_as";
579+
580+
for i in 0..dest_len {
581+
let op = this.read_immediate(&this.mplace_index(&op, i)?.into())?;
582+
let dest = this.mplace_index(&dest, i)?;
583+
584+
let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) {
585+
// Int-to-(int|float): always safe
586+
(ty::Int(_) | ty::Uint(_), ty::Int(_) | ty::Uint(_) | ty::Float(_)) =>
587+
this.misc_cast(&op, dest.layout.ty)?,
588+
// Float-to-float: always safe
589+
(ty::Float(_), ty::Float(_)) =>
590+
this.misc_cast(&op, dest.layout.ty)?,
591+
// Float-to-int in safe mode
592+
(ty::Float(_), ty::Int(_) | ty::Uint(_)) if safe_cast =>
593+
this.misc_cast(&op, dest.layout.ty)?,
594+
// Float-to-int in unchecked mode
595+
(ty::Float(FloatTy::F32), ty::Int(_) | ty::Uint(_)) if !safe_cast =>
596+
this.float_to_int_unchecked(op.to_scalar()?.to_f32()?, dest.layout.ty)?.into(),
597+
(ty::Float(FloatTy::F64), ty::Int(_) | ty::Uint(_)) if !safe_cast =>
598+
this.float_to_int_unchecked(op.to_scalar()?.to_f64()?, dest.layout.ty)?.into(),
599+
_ =>
600+
throw_unsup_format!(
601+
"Unsupported SIMD cast from element type {} to {}",
602+
op.layout.ty,
603+
dest.layout.ty
604+
),
605+
};
606+
this.write_immediate(val, &dest.into())?;
607+
}
608+
}
518609

519610
// Atomic operations
520611
"atomic_load" => this.atomic_load(args, dest, AtomicReadOp::SeqCst)?,
@@ -1003,3 +1094,35 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
10031094
})
10041095
}
10051096
}
1097+
1098+
fn fmax_op<'tcx>(
1099+
left: &ImmTy<'tcx, Tag>,
1100+
right: &ImmTy<'tcx, Tag>,
1101+
) -> InterpResult<'tcx, Scalar<Tag>> {
1102+
assert_eq!(left.layout.ty, right.layout.ty);
1103+
let ty::Float(float_ty) = left.layout.ty.kind() else {
1104+
bug!("fmax operand is not a float")
1105+
};
1106+
let left = left.to_scalar()?;
1107+
let right = right.to_scalar()?;
1108+
Ok(match float_ty {
1109+
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.max(right.to_f32()?)),
1110+
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.max(right.to_f64()?)),
1111+
})
1112+
}
1113+
1114+
fn fmin_op<'tcx>(
1115+
left: &ImmTy<'tcx, Tag>,
1116+
right: &ImmTy<'tcx, Tag>,
1117+
) -> InterpResult<'tcx, Scalar<Tag>> {
1118+
assert_eq!(left.layout.ty, right.layout.ty);
1119+
let ty::Float(float_ty) = left.layout.ty.kind() else {
1120+
bug!("fmin operand is not a float")
1121+
};
1122+
let left = left.to_scalar()?;
1123+
let right = right.to_scalar()?;
1124+
Ok(match float_ty {
1125+
FloatTy::F32 => Scalar::from_f32(left.to_f32()?.min(right.to_f32()?)),
1126+
FloatTy::F64 => Scalar::from_f64(left.to_f64()?.min(right.to_f64()?)),
1127+
})
1128+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// error-pattern: cannot be represented in target type `i32`
2+
#![feature(portable_simd)]
3+
use std::simd::*;
4+
5+
fn main() { unsafe {
6+
let _x : i32x2 = f32x2::from_array([f32::MAX, f32::MIN]).to_int_unchecked();
7+
} }

0 commit comments

Comments
 (0)