Skip to content

Commit ebceb73

Browse files
authored
[mlir][vector] Update the folder for vector.{insert|extract} (#136579)
This is a minor follow-up to #135498. It ensures that operations like the following are not treated as out-of-bounds accesses and can be folded correctly (*): ```mlir %c_neg_1 = arith.constant -1 : index %0 = vector.insert %value_to_store, %dest[%c_neg_1] : vector<5xf32> into vector<4x5xf32> %1 = vector.extract %src[%c_neg_1, 0] : f32 from vector<4x5xf32> ``` In addition to adding tests for the case above, this PR also relocates the tests from #135498 to be alongside existing tests for the `vector.{insert|extract}` folder, and reformats them to follow: * https://mlir.llvm.org/getting_started/TestingGuide/ For example: * The "no_fold" prefix is now used to label negative tests. * Redundant check lines have been removed (e.g., CHECK: vector.insert is sufficient to verify that folding did not occur). (*) As per https://mlir.llvm.org/docs/Dialects/Vector/#vectorinsert-vectorinsertop, these are poison values.
1 parent 88083a0 commit ebceb73

File tree

2 files changed

+58
-40
lines changed

2 files changed

+58
-40
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,8 +2045,9 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
20452045
Value position = dynamicPosition[index++];
20462046
if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
20472047
int64_t value = attr.getInt();
2048-
// Do not fold if the value is out of bounds.
2049-
if (value >= 0 && value < vectorShape[i]) {
2048+
// Do not fold if the value is out of bounds (-1 signifies a poison
2049+
// value rather than OOB index).
2050+
if (value >= -1 && value < vectorShape[i]) {
20502051
staticPosition[i] = attr.getInt();
20512052
opChange = true;
20522053
continue;

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,33 @@ func.func @extract_scalar_poison_idx(%a: vector<4x5xf32>) -> f32 {
165165
return %0 : f32
166166
}
167167

168+
// -----
169+
170+
// Similar to the test above, but the index is not a static constant.
171+
172+
// CHECK-LABEL: @extract_scalar_poison_idx_non_cst
173+
func.func @extract_scalar_poison_idx_non_cst(%a: vector<4x5xf32>) -> f32 {
174+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : f32
175+
// CHECK-NOT: vector.extract
176+
// CHECK-NEXT: return %[[UB]] : f32
177+
%c_neg_1 = arith.constant -1 : index
178+
%0 = vector.extract %a[%c_neg_1, 0] : f32 from vector<4x5xf32>
179+
return %0 : f32
180+
}
181+
182+
// -----
183+
184+
// Similar to test above, but now the index is out-of-bounds.
185+
186+
// CHECK-LABEL: @no_fold_extract_scalar_oob_idx
187+
func.func @no_fold_extract_scalar_oob_idx(%a: vector<4x5xf32>) -> f32 {
188+
// CHECK: vector.extract
189+
%c_neg_2 = arith.constant -2 : index
190+
%0 = vector.extract %a[%c_neg_2, 0] : f32 from vector<4x5xf32>
191+
return %0 : f32
192+
}
193+
194+
168195
// -----
169196

170197
// CHECK-LABEL: @extract_vector_poison_idx
@@ -3062,6 +3089,34 @@ func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
30623089

30633090
// -----
30643091

3092+
// Similar to the test above, but the index is not a static constant.
3093+
3094+
// CHECK-LABEL: @insert_vector_poison_idx_non_cst
3095+
func.func @insert_vector_poison_idx_non_cst(%a: vector<4x5xf32>, %b: vector<5xf32>)
3096+
-> vector<4x5xf32> {
3097+
// CHECK-NEXT: %[[UB:.*]] = ub.poison : vector<4x5xf32>
3098+
// CHECK-NOT: vector.insert
3099+
// CHECK-NEXT: return %[[UB]] : vector<4x5xf32>
3100+
%c_neg_1 = arith.constant -1 : index
3101+
%0 = vector.insert %b, %a[%c_neg_1] : vector<5xf32> into vector<4x5xf32>
3102+
return %0 : vector<4x5xf32>
3103+
}
3104+
3105+
// -----
3106+
3107+
// Similar to test above, but now the index is out-of-bounds.
3108+
3109+
// CHECK-LABEL: @no_fold_insert_scalar_idx_oob
3110+
func.func @no_fold_insert_scalar_idx_oob(%a: vector<4x5xf32>, %b: vector<5xf32>)
3111+
-> vector<4x5xf32> {
3112+
// CHECK: vector.insert
3113+
%c_neg_2 = arith.constant -2 : index
3114+
%0 = vector.insert %b, %a[%c_neg_2] : vector<5xf32> into vector<4x5xf32>
3115+
return %0 : vector<4x5xf32>
3116+
}
3117+
3118+
// -----
3119+
30653120
// CHECK-LABEL: @insert_multiple_poison_idx
30663121
func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>)
30673122
-> vector<4x5x8xf32> {
@@ -3311,41 +3366,3 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
33113366
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
33123367
return %res : vector<4x1xi32>
33133368
}
3314-
3315-
// -----
3316-
3317-
// Check that out of bounds indices are not folded for vector.insert.
3318-
3319-
// CHECK-LABEL: @fold_insert_oob
3320-
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> vector<4x1x2xi32> {
3321-
// CHECK: %[[OOB1:.*]] = arith.constant -2 : index
3322-
// CHECK: %[[OOB2:.*]] = arith.constant 2 : index
3323-
// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
3324-
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, %[[OOB1]], %[[OOB2]]] : i32 into vector<4x1x2xi32>
3325-
// CHECK: return %[[RES]] : vector<4x1x2xi32>
3326-
func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> {
3327-
%c0 = arith.constant 0 : index
3328-
%c-2 = arith.constant -2 : index
3329-
%c2 = arith.constant 2 : index
3330-
%c1 = arith.constant 1 : i32
3331-
%res = vector.insert %c1, %arg[%c0, %c-2, %c2] : i32 into vector<4x1x2xi32>
3332-
return %res : vector<4x1x2xi32>
3333-
}
3334-
3335-
// -----
3336-
3337-
// Check that out of bounds indices are not folded for vector.extract.
3338-
3339-
// CHECK-LABEL: @fold_extract_oob
3340-
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> i32 {
3341-
// CHECK: %[[OOB1:.*]] = arith.constant -2 : index
3342-
// CHECK: %[[OOB2:.*]] = arith.constant 2 : index
3343-
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, %[[OOB1]], %[[OOB2]]] : i32 from vector<4x1x2xi32>
3344-
// CHECK: return %[[RES]] : i32
3345-
func.func @fold_extract_oob(%arg : vector<4x1x2xi32>) -> i32 {
3346-
%c0 = arith.constant 0 : index
3347-
%c-2 = arith.constant -2 : index
3348-
%c2 = arith.constant 2 : index
3349-
%res = vector.extract %arg[%c0, %c-2, %c2] : i32 from vector<4x1x2xi32>
3350-
return %res : i32
3351-
}

0 commit comments

Comments
 (0)