Skip to content

Commit dbe159b

Browse files
authored
[mlir] [IR] Allow zero strides in StridedLayoutAttr (#116463)
Disabling memrefs with a stride of 0 was intended to prevent internal aliasing, but this does not address all cases : internal aliasing can still occur when the stride is less than the shape. On the other hand, a stride of 0 can be very useful in certain scenarios. For example, in architectures that support multi-dimensional DMA, we can use memref::copy with a stride of 0 to achieve a broadcast effect. This commit removes the restriction that strides in memrefs cannot be 0.
1 parent 42775a4 commit dbe159b

File tree

5 files changed

+2
-35
lines changed

5 files changed

+2
-35
lines changed

mlir/lib/IR/BuiltinAttributes.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,6 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
245245
LogicalResult
246246
StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
247247
int64_t offset, ArrayRef<int64_t> strides) {
248-
if (llvm::is_contained(strides, 0))
249-
return emitError() << "strides must not be zero";
250-
251248
return success();
252249
}
253250

@@ -1815,7 +1812,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
18151812
for (const auto &en : llvm::enumerate(strides)) {
18161813
auto dim = en.index();
18171814
auto stride = en.value();
1818-
assert(stride != 0 && "Invalid stride specification");
18191815
auto d = getAffineDimExpr(dim, context);
18201816
AffineExpr mult;
18211817
// Static case.

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -798,20 +798,6 @@ static LogicalResult getStridesAndOffset(MemRefType t,
798798
for (auto &stride : strides)
799799
stride = simplifyAffineExpr(stride, numDims, numSymbols);
800800

801-
// In practice, a strided memref must be internally non-aliasing. Test
802-
// against 0 as a proxy.
803-
// TODO: static cases can have more advanced checks.
804-
// TODO: dynamic cases would require a way to compare symbolic
805-
// expressions and would probably need an affine set context propagated
806-
// everywhere.
807-
if (llvm::any_of(strides, [](AffineExpr e) {
808-
return e == getAffineConstantExpr(0, e.getContext());
809-
})) {
810-
offset = AffineExpr();
811-
strides.clear();
812-
return failure();
813-
}
814-
815801
return success();
816802
}
817803

mlir/test/Dialect/Affine/memref-stride-calculation.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ func.func @f(%0: index) {
5151
%26 = memref.alloc(%0)[] : memref<?xf32, affine_map<(i)[M]->(i)>>
5252
// CHECK: MemRefType offset: 0 strides: 1
5353
%27 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(M)>>
54-
// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (s0)>> cannot be converted to strided form
54+
// CHECK: MemRefType offset: ? strides: 0
5555
%28 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(123)>>
56-
// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (123)>> cannot be converted to strided form
56+
// CHECK: MemRefType offset: 123 strides: 0
5757
%29 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(M)>>
5858
// CHECK: MemRefType offset: ? strides:
5959
%30 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(123)>>

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -245,16 +245,6 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {
245245

246246
// -----
247247

248-
func.func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
249-
// expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
250-
%out = memref.reinterpret_cast %in to
251-
offset: [0], sizes: [9, 10], strides: [42, 1]
252-
: memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
253-
return
254-
}
255-
256-
// -----
257-
258248
func.func @memref_reshape_element_type_mismatch(
259249
%buf: memref<*xf32>, %shape: memref<1xi32>) {
260250
// expected-error @+1 {{element types of source and destination memref types should be the same}}

mlir/test/IR/invalid-builtin-types.mlir

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,6 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<
9999

100100
// -----
101101

102-
// expected-error @below {{strides must not be zero}}
103-
func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>
104-
105-
// -----
106-
107102
// expected-error @below {{expected the number of strides to match the rank}}
108103
func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>
109104

0 commit comments

Comments
 (0)