Skip to content

Commit aab7589

Browse files
committed
Fix string pattern matching in mir interpreter
1 parent dfaca93 commit aab7589

File tree

6 files changed

+84
-12
lines changed

6 files changed

+84
-12
lines changed

crates/hir-ty/src/chalk_ext.rs

+5
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::{
2222
pub trait TyExt {
2323
fn is_unit(&self) -> bool;
2424
fn is_integral(&self) -> bool;
25+
fn is_scalar(&self) -> bool;
2526
fn is_floating_point(&self) -> bool;
2627
fn is_never(&self) -> bool;
2728
fn is_unknown(&self) -> bool;
@@ -68,6 +69,10 @@ impl TyExt for Ty {
6869
)
6970
}
7071

72+
fn is_scalar(&self) -> bool {
73+
matches!(self.kind(Interner), TyKind::Scalar(_))
74+
}
75+
7176
fn is_floating_point(&self) -> bool {
7277
matches!(
7378
self.kind(Interner),

crates/hir-ty/src/consteval/tests.rs

+9-6
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ fn casts() {
179179
"#,
180180
4,
181181
);
182+
check_number(r#"const GOAL: i32 = -12i8 as i32"#, -12);
182183
}
183184

184185
#[test]
@@ -1034,16 +1035,18 @@ fn pattern_matching_literal() {
10341035
);
10351036
check_number(
10361037
r#"
1037-
const fn f(x: &str) -> u8 {
1038+
const fn f(x: &str) -> i32 {
10381039
match x {
1039-
"foo" => 1,
1040-
"bar" => 10,
1041-
_ => 100,
1040+
"f" => 1,
1041+
"foo" => 10,
1042+
"" => 100,
1043+
"bar" => 1000,
1044+
_ => 10000,
10421045
}
10431046
}
1044-
const GOAL: u8 = f("foo") + f("bar");
1047+
const GOAL: i32 = f("f") + f("foo") * 2 + f("") * 3 + f("bar") * 4;
10451048
"#,
1046-
11,
1049+
4321,
10471050
);
10481051
}
10491052

crates/hir-ty/src/mir/eval.rs

+20-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use hir_expand::InFile;
1717
use intern::Interned;
1818
use la_arena::ArenaMap;
1919
use rustc_hash::{FxHashMap, FxHashSet};
20+
use stdx::never;
2021
use syntax::{SyntaxNodePtr, TextRange};
2122
use triomphe::Arc;
2223

@@ -896,7 +897,7 @@ impl Evaluator<'_> {
896897
Owned(c)
897898
}
898899
}
899-
Rvalue::CheckedBinaryOp(op, lhs, rhs) => {
900+
Rvalue::CheckedBinaryOp(op, lhs, rhs) => 'binary_op: {
900901
let lc = self.eval_operand(lhs, locals)?;
901902
let rc = self.eval_operand(rhs, locals)?;
902903
let mut lc = lc.get(&self)?;
@@ -905,10 +906,17 @@ impl Evaluator<'_> {
905906
while let TyKind::Ref(_, _, z) = ty.kind(Interner) {
906907
ty = z.clone();
907908
let size = if ty.kind(Interner) == &TyKind::Str {
908-
let ns = from_bytes!(usize, &lc[self.ptr_size()..self.ptr_size() * 2]);
909+
if *op != BinOp::Eq {
910+
never!("Only eq is builtin for `str`");
911+
}
912+
let ls = from_bytes!(usize, &lc[self.ptr_size()..self.ptr_size() * 2]);
913+
let rs = from_bytes!(usize, &rc[self.ptr_size()..self.ptr_size() * 2]);
914+
if ls != rs {
915+
break 'binary_op Owned(vec![0]);
916+
}
909917
lc = &lc[..self.ptr_size()];
910918
rc = &rc[..self.ptr_size()];
911-
ns
919+
ls
912920
} else {
913921
self.size_of_sized(&ty, locals, "operand of binary op")?
914922
};
@@ -1200,8 +1208,15 @@ impl Evaluator<'_> {
12001208
CastKind::IntToInt
12011209
| CastKind::PointerExposeAddress
12021210
| CastKind::PointerFromExposedAddress => {
1203-
// FIXME: handle signed cast
1204-
let current = pad16(self.eval_operand(operand, locals)?.get(&self)?, false);
1211+
let current_ty = self.operand_ty(operand, locals)?;
1212+
let is_signed = match current_ty.kind(Interner) {
1213+
TyKind::Scalar(s) => match s {
1214+
chalk_ir::Scalar::Int(_) => true,
1215+
_ => false,
1216+
},
1217+
_ => false,
1218+
};
1219+
let current = pad16(self.eval_operand(operand, locals)?.get(&self)?, is_signed);
12051220
let dest_size =
12061221
self.size_of_sized(target_ty, locals, "destination of int to int cast")?;
12071222
Owned(current[0..dest_size].to_vec())

crates/hir-ty/src/mir/eval/shim.rs

+16
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,22 @@ impl Evaluator<'_> {
238238
_span: MirSpan,
239239
) -> Result<()> {
240240
match as_str {
241+
"memcmp" => {
242+
let [ptr1, ptr2, size] = args else {
243+
return Err(MirEvalError::TypeError("memcmp args are not provided"));
244+
};
245+
let addr1 = Address::from_bytes(ptr1.get(self)?)?;
246+
let addr2 = Address::from_bytes(ptr2.get(self)?)?;
247+
let size = from_bytes!(usize, size.get(self)?);
248+
let slice1 = self.read_memory(addr1, size)?;
249+
let slice2 = self.read_memory(addr2, size)?;
250+
let r: i128 = match slice1.cmp(slice2) {
251+
cmp::Ordering::Less => -1,
252+
cmp::Ordering::Equal => 0,
253+
cmp::Ordering::Greater => 1,
254+
};
255+
destination.write_from_bytes(self, &r.to_le_bytes()[..destination.size])
256+
}
241257
"write" => {
242258
let [fd, ptr, len] = args else {
243259
return Err(MirEvalError::TypeError("libc::write args are not provided"));

crates/hir-ty/src/mir/eval/tests.rs

+33
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,39 @@ fn main() {
228228
);
229229
}
230230

231+
#[test]
232+
fn memcmp() {
233+
check_pass(
234+
r#"
235+
//- minicore: slice, coerce_unsized, index
236+
237+
fn should_not_reach() -> bool {
238+
_ // FIXME: replace this function with panic when that works
239+
}
240+
241+
extern "C" {
242+
fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32;
243+
}
244+
245+
fn my_cmp(x: &[u8], y: &[u8]) -> i32 {
246+
memcmp(x as *const u8, y as *const u8, x.len())
247+
}
248+
249+
fn main() {
250+
if my_cmp(&[1, 2, 3], &[1, 2, 3]) != 0 {
251+
should_not_reach();
252+
}
253+
if my_cmp(&[1, 20, 3], &[1, 2, 3]) <= 0 {
254+
should_not_reach();
255+
}
256+
if my_cmp(&[1, 2, 3], &[1, 20, 3]) >= 0 {
257+
should_not_reach();
258+
}
259+
}
260+
"#,
261+
);
262+
}
263+
231264
#[test]
232265
fn unix_write_stdout() {
233266
check_pass_and_stdio(

crates/hir-ty/src/mir/lower.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ impl<'ctx> MirLowerCtx<'ctx> {
829829
op,
830830
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) | BinaryOp::Assignment { op: Some(ArithOp::Shl | ArithOp::Shr) }
831831
);
832-
lhs_ty.as_builtin().is_some() && rhs_ty.as_builtin().is_some() && (lhs_ty == rhs_ty || builtin_inequal_impls)
832+
lhs_ty.is_scalar() && rhs_ty.is_scalar() && (lhs_ty == rhs_ty || builtin_inequal_impls)
833833
};
834834
if !is_builtin {
835835
if let Some((func_id, generic_args)) = self.infer.method_resolution(expr_id) {

0 commit comments

Comments
 (0)