Skip to content

Commit d0da4cf

Browse files
committed
Implement comparison operators for int and uint SIMD vectors
1 parent 6dd7a56 commit d0da4cf

File tree

5 files changed

+135
-20
lines changed

5 files changed

+135
-20
lines changed

src/librustc/middle/trans/base.rs

+38
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,44 @@ pub fn compare_scalar_values<'a>(
619619
}
620620
}
621621

622+
pub fn compare_simd_types(
623+
cx: &Block,
624+
lhs: ValueRef,
625+
rhs: ValueRef,
626+
t: ty::t,
627+
size: uint,
628+
op: ast::BinOp)
629+
-> ValueRef {
630+
match ty::get(t).sty {
631+
ty::ty_float(_) => {
632+
// The comparison operators for floating point vectors are challenging.
633+
// LLVM outputs a `< size x i1 >`, but if we perform a sign extension
634+
// then bitcast to a floating point vector, the result will be `-NaN`
635+
// for each truth value. Because of this they are unsupported.
636+
cx.sess().bug("compare_simd_types: comparison operators \
637+
not supported for floating point SIMD types")
638+
},
639+
ty::ty_uint(_) | ty::ty_int(_) => {
640+
let cmp = match op {
641+
ast::BiEq => lib::llvm::IntEQ,
642+
ast::BiNe => lib::llvm::IntNE,
643+
ast::BiLt => lib::llvm::IntSLT,
644+
ast::BiLe => lib::llvm::IntSLE,
645+
ast::BiGt => lib::llvm::IntSGT,
646+
ast::BiGe => lib::llvm::IntSGE,
647+
_ => cx.sess().bug("compare_simd_types: must be a comparison operator"),
648+
};
649+
let return_ty = Type::vector(&type_of(cx.ccx(), t), size as u64);
650+
// LLVM outputs an `< size x i1 >`, so we need to perform a sign extension
651+
// to get the correctly sized type. This will compile to a single instruction
652+
// once the IR is converted to assembly if the SIMD instruction is supported
653+
// by the target architecture.
654+
SExt(cx, ICmp(cx, cmp, lhs, rhs), return_ty)
655+
},
656+
_ => cx.sess().bug("compare_simd_types: invalid SIMD type"),
657+
}
658+
}
659+
622660
pub type val_and_ty_fn<'r,'b> =
623661
|&'b Block<'b>, ValueRef, ty::t|: 'r -> &'b Block<'b>;
624662

src/librustc/middle/trans/expr.rs

+13-14
Original file line numberDiff line numberDiff line change
@@ -1259,16 +1259,15 @@ fn trans_eager_binop<'a>(
12591259
-> DatumBlock<'a, Expr> {
12601260
let _icx = push_ctxt("trans_eager_binop");
12611261

1262-
let mut intype = {
1262+
let tcx = bcx.tcx();
1263+
let is_simd = ty::type_is_simd(tcx, lhs_t);
1264+
let intype = {
12631265
if ty::type_is_bot(lhs_t) { rhs_t }
1266+
else if is_simd { ty::simd_type(tcx, lhs_t) }
12641267
else { lhs_t }
12651268
};
1266-
let tcx = bcx.tcx();
1267-
if ty::type_is_simd(tcx, intype) {
1268-
intype = ty::simd_type(tcx, intype);
1269-
}
12701269
let is_float = ty::type_is_fp(intype);
1271-
let signed = ty::type_is_signed(intype);
1270+
let is_signed = ty::type_is_signed(intype);
12721271

12731272
let rhs = base::cast_shift_expr_rhs(bcx, op, lhs, rhs);
12741273

@@ -1293,7 +1292,7 @@ fn trans_eager_binop<'a>(
12931292
// Only zero-check integers; fp /0 is NaN
12941293
bcx = base::fail_if_zero(bcx, binop_expr.span,
12951294
op, rhs, rhs_t);
1296-
if signed {
1295+
if is_signed {
12971296
SDiv(bcx, lhs, rhs)
12981297
} else {
12991298
UDiv(bcx, lhs, rhs)
@@ -1307,7 +1306,7 @@ fn trans_eager_binop<'a>(
13071306
// Only zero-check integers; fp %0 is NaN
13081307
bcx = base::fail_if_zero(bcx, binop_expr.span,
13091308
op, rhs, rhs_t);
1310-
if signed {
1309+
if is_signed {
13111310
SRem(bcx, lhs, rhs)
13121311
} else {
13131312
URem(bcx, lhs, rhs)
@@ -1319,21 +1318,21 @@ fn trans_eager_binop<'a>(
13191318
ast::BiBitXor => Xor(bcx, lhs, rhs),
13201319
ast::BiShl => Shl(bcx, lhs, rhs),
13211320
ast::BiShr => {
1322-
if signed {
1321+
if is_signed {
13231322
AShr(bcx, lhs, rhs)
13241323
} else { LShr(bcx, lhs, rhs) }
13251324
}
13261325
ast::BiEq | ast::BiNe | ast::BiLt | ast::BiGe | ast::BiLe | ast::BiGt => {
13271326
if ty::type_is_bot(rhs_t) {
13281327
C_bool(bcx.ccx(), false)
1329-
} else {
1330-
if !ty::type_is_scalar(rhs_t) {
1331-
bcx.tcx().sess.span_bug(binop_expr.span,
1332-
"non-scalar comparison");
1333-
}
1328+
} else if ty::type_is_scalar(rhs_t) {
13341329
let cmpr = base::compare_scalar_types(bcx, lhs, rhs, rhs_t, op);
13351330
bcx = cmpr.bcx;
13361331
ZExt(bcx, cmpr.val, Type::i8(bcx.ccx()))
1332+
} else if is_simd {
1333+
base::compare_simd_types(bcx, lhs, rhs, intype, ty::simd_size(tcx, lhs_t), op)
1334+
} else {
1335+
bcx.tcx().sess.span_bug(binop_expr.span, "comparison operator unsupported for type")
13371336
}
13381337
}
13391338
_ => {

src/librustc/middle/typeck/check/mod.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -2102,8 +2102,27 @@ fn check_expr_with_unifier(fcx: &FnCtxt,
21022102

21032103
let result_t = match op {
21042104
ast::BiEq | ast::BiNe | ast::BiLt | ast::BiLe | ast::BiGe |
2105-
ast::BiGt => ty::mk_bool(),
2106-
_ => lhs_t
2105+
ast::BiGt => {
2106+
if ty::type_is_simd(tcx, lhs_t) {
2107+
if ty::type_is_fp(ty::simd_type(tcx, lhs_t)) {
2108+
fcx.type_error_message(expr.span,
2109+
|actual| {
2110+
format!("binary comparison operation `{}` not supported \
2111+
for floating point SIMD vector `{}`",
2112+
ast_util::binop_to_str(op), actual)
2113+
},
2114+
lhs_t,
2115+
None
2116+
);
2117+
ty::mk_err()
2118+
} else {
2119+
lhs_t
2120+
}
2121+
} else {
2122+
ty::mk_bool()
2123+
}
2124+
},
2125+
_ => lhs_t,
21072126
};
21082127

21092128
fcx.write_ty(expr.id, result_t);

src/test/compile-fail/simd-binop.rs

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2+
// file at the top-level directory of this distribution and at
3+
// http://rust-lang.org/COPYRIGHT.
4+
//
5+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8+
// option. This file may not be copied, modified, or distributed
9+
// except according to those terms.
10+
11+
// ignore-tidy-linelength
12+
13+
#![allow(experimental)]
14+
15+
use std::unstable::simd::f32x4;
16+
17+
fn main() {
18+
19+
let _ = f32x4(0.0, 0.0, 0.0, 0.0) == f32x4(0.0, 0.0, 0.0, 0.0);
20+
//~^ ERROR binary comparison operation `==` not supported for floating point SIMD vector `std::unstable::simd::f32x4`
21+
22+
let _ = f32x4(0.0, 0.0, 0.0, 0.0) != f32x4(0.0, 0.0, 0.0, 0.0);
23+
//~^ ERROR binary comparison operation `!=` not supported for floating point SIMD vector `std::unstable::simd::f32x4`
24+
25+
let _ = f32x4(0.0, 0.0, 0.0, 0.0) < f32x4(0.0, 0.0, 0.0, 0.0);
26+
//~^ ERROR binary comparison operation `<` not supported for floating point SIMD vector `std::unstable::simd::f32x4`
27+
28+
let _ = f32x4(0.0, 0.0, 0.0, 0.0) <= f32x4(0.0, 0.0, 0.0, 0.0);
29+
//~^ ERROR binary comparison operation `<=` not supported for floating point SIMD vector `std::unstable::simd::f32x4`
30+
31+
let _ = f32x4(0.0, 0.0, 0.0, 0.0) >= f32x4(0.0, 0.0, 0.0, 0.0);
32+
//~^ ERROR binary comparison operation `>=` not supported for floating point SIMD vector `std::unstable::simd::f32x4`
33+
34+
let _ = f32x4(0.0, 0.0, 0.0, 0.0) > f32x4(0.0, 0.0, 0.0, 0.0);
35+
//~^ ERROR binary comparison operation `>` not supported for floating point SIMD vector `std::unstable::simd::f32x4`
36+
37+
}

src/test/run-pass/simd-binop.rs

+26-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ fn eq_i32x4(i32x4(x0, x1, x2, x3): i32x4, i32x4(y0, y1, y2, y3): i32x4) -> bool
2525
}
2626

2727
pub fn main() {
28+
// arithmetic operators
29+
2830
assert!(eq_u32x4(u32x4(1, 2, 3, 4) + u32x4(4, 3, 2, 1), u32x4(5, 5, 5, 5)));
2931
assert!(eq_u32x4(u32x4(4, 5, 6, 7) - u32x4(4, 3, 2, 1), u32x4(0, 2, 4, 6)));
3032
assert!(eq_u32x4(u32x4(1, 2, 3, 4) * u32x4(4, 3, 2, 1), u32x4(4, 6, 6, 4)));
@@ -43,8 +45,28 @@ pub fn main() {
4345
assert!(eq_i32x4(i32x4(1, 2, 3, 4) << i32x4(4, 3, 2, 1), i32x4(16, 16, 12, 8)));
4446
assert!(eq_i32x4(i32x4(1, 2, 3, 4) >> i32x4(4, 3, 2, 1), i32x4(0, 0, 0, 2)));
4547

46-
assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) + f32x4(4.0, 3.0, 2.0, 1.0), f32x4(5.0, 5.0, 5.0, 5.0)));
47-
assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) - f32x4(4.0, 3.0, 2.0, 1.0), f32x4(-3.0, -1.0, 1.0, 3.0)));
48-
assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) * f32x4(4.0, 3.0, 2.0, 1.0), f32x4(4.0, 6.0, 6.0, 4.0)));
49-
assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) / f32x4(4.0, 4.0, 2.0, 1.0), f32x4(0.25, 0.5, 1.5, 4.0)));
48+
assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) + f32x4(4.0, 3.0, 2.0, 1.0),
49+
f32x4(5.0, 5.0, 5.0, 5.0)));
50+
assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) - f32x4(4.0, 3.0, 2.0, 1.0),
51+
f32x4(-3.0, -1.0, 1.0, 3.0)));
52+
assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) * f32x4(4.0, 3.0, 2.0, 1.0),
53+
f32x4(4.0, 6.0, 6.0, 4.0)));
54+
assert!(eq_f32x4(f32x4(1.0, 2.0, 3.0, 4.0) / f32x4(4.0, 4.0, 2.0, 1.0),
55+
f32x4(0.25, 0.5, 1.5, 4.0)));
56+
57+
// comparison operators
58+
59+
assert!(eq_u32x4(u32x4(1, 2, 3, 4) == u32x4(3, 2, 1, 0), u32x4(0, !0, 0, 0)));
60+
assert!(eq_u32x4(u32x4(1, 2, 3, 4) != u32x4(3, 2, 1, 0), u32x4(!0, 0, !0, !0)));
61+
assert!(eq_u32x4(u32x4(1, 2, 3, 4) < u32x4(3, 2, 1, 0), u32x4(!0, 0, 0, 0)));
62+
assert!(eq_u32x4(u32x4(1, 2, 3, 4) <= u32x4(3, 2, 1, 0), u32x4(!0, !0, 0, 0)));
63+
assert!(eq_u32x4(u32x4(1, 2, 3, 4) >= u32x4(3, 2, 1, 0), u32x4(0, !0, !0, !0)));
64+
assert!(eq_u32x4(u32x4(1, 2, 3, 4) > u32x4(3, 2, 1, 0), u32x4(0, 0, !0, !0)));
65+
66+
assert!(eq_i32x4(i32x4(1, 2, 3, 4) == i32x4(3, 2, 1, 0), i32x4(0, !0, 0, 0)));
67+
assert!(eq_i32x4(i32x4(1, 2, 3, 4) != i32x4(3, 2, 1, 0), i32x4(!0, 0, !0, !0)));
68+
assert!(eq_i32x4(i32x4(1, 2, 3, 4) < i32x4(3, 2, 1, 0), i32x4(!0, 0, 0, 0)));
69+
assert!(eq_i32x4(i32x4(1, 2, 3, 4) <= i32x4(3, 2, 1, 0), i32x4(!0, !0, 0, 0)));
70+
assert!(eq_i32x4(i32x4(1, 2, 3, 4) >= i32x4(3, 2, 1, 0), i32x4(0, !0, !0, !0)));
71+
assert!(eq_i32x4(i32x4(1, 2, 3, 4) > i32x4(3, 2, 1, 0), i32x4(0, 0, !0, !0)));
5072
}

0 commit comments

Comments
 (0)