Skip to content

Commit fedd79b

Browse files
authored
[mlir][vector] Tighten the semantics of vector.{load|store} (#135151)
This change refines the verifier for `vector.load` and `vector.store` to disallow the use of vectors with higher rank than the source or destination memref. For example, the following is now rejected: ```mlir %0 = vector.load %src[%c0] : memref<?xi8>, vector<16x16xi8> vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8> ``` This pattern was previously used in SME end-to-end tests and "happened" to work by implicitly assuming row-major memory layout. However, there is no guarantee that such an assumption will always hold, and we should avoid relying on it unless it can be enforced deterministically. Notably, production ArmSME lowering pipelines do not rely on this behavior. Instead, the expected usage (illustrated here with scalable vector syntax) would be: ```mlir %0 = vector.load %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8> ``` This PR updates the verifier accordingly and adjusts all affected tests. These tests are either removed (if no longer relevant) or updated to use memrefs with appropriately matching rank.
1 parent 179d30f commit fedd79b

File tree

7 files changed

+130
-89
lines changed

7 files changed

+130
-89
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5100,6 +5100,10 @@ LogicalResult vector::LoadOp::verify() {
51005100
if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
51015101
return failure();
51025102

5103+
if (memRefTy.getRank() < resVecTy.getRank())
5104+
return emitOpError(
5105+
"destination memref has lower rank than the result vector");
5106+
51035107
// Checks for vector memrefs.
51045108
Type memElemTy = memRefTy.getElementType();
51055109
if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {
@@ -5132,6 +5136,9 @@ LogicalResult vector::StoreOp::verify() {
51325136
if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
51335137
return failure();
51345138

5139+
if (memRefTy.getRank() < valueVecTy.getRank())
5140+
return emitOpError("source memref has lower rank than the vector to store");
5141+
51355142
// Checks for vector memrefs.
51365143
Type memElemTy = memRefTy.getElementType();
51375144
if (auto memVecTy = llvm::dyn_cast<VectorType>(memElemTy)) {

mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -718,18 +718,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
718718

719719
// -----
720720

721-
// CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
722-
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xi8>)
723-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
724-
// CHECK: arm_sme.tile_load %[[MEMREF]][%[[C0]]] : memref<?xi8>, vector<[16]x[16]xi8>
725-
func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
726-
%c0 = arith.constant 0 : index
727-
%tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
728-
return %tile : vector<[16]x[16]xi8>
729-
}
730-
731-
// -----
732-
733721
// CHECK-LABEL: @vector_load_i16(
734722
// CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
735723
func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {

mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -819,18 +819,29 @@ func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: ind
819819

820820
// -----
821821

822-
func.func @fold_vector_load_subview(
823-
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index) -> vector<12x32xf32> {
824-
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
825-
%1 = vector.load %0[] : memref<f32, strided<[], offset: ?>>, vector<12x32xf32>
826-
return %1 : vector<12x32xf32>
822+
func.func @fold_vector_load_subview(%src : memref<24x64xf32>,
823+
%off1 : index,
824+
%off2 : index,
825+
%dim1 : index,
826+
%dim2 : index,
827+
%idx : index) -> vector<12x32xf32> {
828+
829+
%0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
830+
%1 = vector.load %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<12x32xf32>
831+
return %1 : vector<12x32xf32>
827832
}
828833

829-
// CHECK: func @fold_vector_load_subview
830-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
831-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
832-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
833-
// CHECK: vector.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<12x32xf32>
834+
// CHECK: #[[$ATTR_46:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
835+
// CHECK-LABEL: func.func @fold_vector_load_subview(
836+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
837+
// CHECK-SAME: %[[OFF_1:[a-zA-Z0-9$._-]*]]: index,
838+
// CHECK-SAME: %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
839+
// CHECK-SAME: %[[DIM_1:[a-zA-Z0-9$._-]*]]: index,
840+
// CHECK-SAME: %[[DIM_2:[a-zA-Z0-9$._-]*]]: index,
841+
// CHECK-SAME: %[[IDX:[a-zA-Z0-9$._-]*]]: index) -> vector<12x32xf32> {
842+
// CHECK: %[[VAL_6:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_1]], %[[IDX]]]
843+
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_46]](){{\[}}%[[OFF_2]], %[[IDX]]]
844+
// CHECK: %[[VAL_8:.*]] = vector.load %[[SRC]]{{\[}}%[[VAL_6]], %[[VAL_7]]] : memref<24x64xf32>, vector<12x32xf32>
834845

835846
// -----
836847

@@ -851,20 +862,32 @@ func.func @fold_vector_maskedload_subview(
851862

852863
// -----
853864

854-
func.func @fold_vector_store_subview(
855-
%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<2x32xf32>) -> () {
856-
%0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
857-
vector.store %arg3, %0[] : memref<f32, strided<[], offset: ?>>, vector<2x32xf32>
858-
return
865+
func.func @fold_vector_store_subview(%src : memref<24x64xf32>,
866+
%off1 : index,
867+
%off2 : index,
868+
%vec: vector<2x32xf32>,
869+
%idx : index,
870+
%dim1 : index,
871+
%dim2 : index) -> () {
872+
873+
%0 = memref.subview %src[%off1, %off2][%dim1, %dim2][1, 1] : memref<24x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
874+
vector.store %vec, %0[%idx, %idx] : memref<?x?xf32, strided<[64, 1], offset: ?>> , vector<2x32xf32>
875+
return
859876
}
860877

861-
// CHECK: func @fold_vector_store_subview
862-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<12x32xf32>
863-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
864-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
865-
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<2x32xf32>
866-
// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<12x32xf32>, vector<2x32xf32>
867-
// CHECK: return
878+
// CHECK: #[[$ATTR_47:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
879+
880+
// CHECK-LABEL: func.func @fold_vector_store_subview(
881+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9$._-]*]]: memref<24x64xf32>,
882+
// CHECK-SAME: %[[OFF1:[a-zA-Z0-9$._-]*]]: index,
883+
// CHECK-SAME: %[[OFF_2:[a-zA-Z0-9$._-]*]]: index,
884+
// CHECK-SAME: %[[VEC:[a-zA-Z0-9$._-]*]]: vector<2x32xf32>,
885+
// CHECK-SAME: %[[IDX:[a-zA-Z0-9$._-]*]]: index,
886+
// CHECK-SAME: %[[VAL_5:[a-zA-Z0-9$._-]*]]: index,
887+
// CHECK-SAME: %[[VAL_6:[a-zA-Z0-9$._-]*]]: index) {
888+
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF1]], %[[IDX]]]
889+
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_47]](){{\[}}%[[OFF_2]], %[[IDX]]]
890+
// CHECK: vector.store %[[VEC]], %[[SRC]]{{\[}}%[[VAL_7]], %[[VAL_8]]] : memref<24x64xf32>, vector<2x32xf32>
868891

869892
// -----
870893

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1743,13 +1743,11 @@ func.func @invalid_outerproduct(%src : memref<?xf32>) {
17431743

17441744
// -----
17451745

1746-
func.func @invalid_outerproduct1(%src : memref<?xf32>) {
1746+
func.func @invalid_outerproduct1(%src : memref<?xf32>, %lhs : vector<[4]x[4]xf32>, %rhs : vector<[4]xf32>) {
17471747
%idx = arith.constant 0 : index
1748-
%0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
1749-
%1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
17501748

17511749
// expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
1752-
%op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
1750+
%op = vector.outerproduct %lhs, %rhs : vector<[4]x[4]xf32>, vector<[4]xf32>
17531751
}
17541752

17551753
// -----
@@ -1870,3 +1868,29 @@ func.func @flat_transpose_scalable(%arg0: vector<[16]xf32>) -> vector<[16]xf32>
18701868
: vector<[16]xf32> -> vector<[16]xf32>
18711869
return %0 : vector<[16]xf32>
18721870
}
1871+
1872+
// -----
1873+
1874+
//===----------------------------------------------------------------------===//
1875+
// vector.load
1876+
//===----------------------------------------------------------------------===//
1877+
1878+
func.func @vector_load(%src : memref<?xi8>) {
1879+
%c0 = arith.constant 0 : index
1880+
// expected-error @+1 {{'vector.load' op destination memref has lower rank than the result vector}}
1881+
%0 = vector.load %src[%c0] : memref<?xi8>, vector<16x16xi8>
1882+
return
1883+
}
1884+
1885+
// -----
1886+
1887+
//===----------------------------------------------------------------------===//
1888+
// vector.store
1889+
//===----------------------------------------------------------------------===//
1890+
1891+
func.func @vector_store(%dest : memref<?xi8>, %vec : vector<16x16xi8>) {
1892+
%c0 = arith.constant 0 : index
1893+
// expected-error @+1 {{'vector.store' op source memref has lower rank than the vector to store}}
1894+
vector.store %vec, %dest[%c0] : memref<?xi8>, vector<16x16xi8>
1895+
return
1896+
}

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
// CHECK-LABEL: func @vector_transfer_ops_0d_memref(
44
// CHECK-SAME: %[[MEM:.*]]: memref<f32>
5-
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1xf32>
6-
func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
5+
// CHECK-SAME: %[[VEC:.*]]: vector<f32>
6+
func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<f32>) {
77
%f0 = arith.constant 0.0 : f32
88

99
// CHECK-NEXT: %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
@@ -12,8 +12,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
1212
// CHECK-NEXT: vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
1313
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
1414

15-
// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
16-
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
15+
// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<f32>
16+
vector.store %vec, %mem[] : memref<f32>, vector<f32>
1717

1818
return
1919
}

mlir/test/Integration/Dialect/Vector/CPU/ArmSME/transpose.mlir

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@ func.func @entry() {
1414

1515
// Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
1616
%svl_s = arm_sme.streaming_vl <word>
17-
%za_s_size = arith.muli %svl_s, %svl_s : index
1817

1918
// Allocate memory.
20-
%mem1 = memref.alloca(%za_s_size) : memref<?xi32>
19+
%mem1 = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
2120

2221
// Fill each "row" of "mem1" with row number.
2322
//
@@ -29,15 +28,15 @@ func.func @entry() {
2928
// 3, 3, 3, 3
3029
//
3130
%init_0 = arith.constant 0 : i32
32-
scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
31+
scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
3332
%splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
34-
vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
33+
vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
3534
%val_next = arith.addi %val, %c1_i32 : i32
3635
scf.yield %val_next : i32
3736
}
3837

3938
// Load tile from "mem1".
40-
%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
39+
%tile = vector.load %mem1[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
4140

4241
// Transpose tile.
4342
%transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>

0 commit comments

Comments
 (0)