[MLIR] Fixing the memref linearization size computation for non-packed memref #138922

Merged
merged 5 commits into main from users/zyin/fix-memref-linearization-size-maxMap on
May 8, 2025

Conversation

jerryyin
Member

@jerryyin jerryyin commented May 7, 2025

Credit to @krzysz00, who discovered this subtle bug in MemRefUtils. The problem is in the getLinearizedMemRefOffsetAndSize() utility: it computes the linearized size of a memref incorrectly when given a non-packed memref.

Background

As context, for a packed memref such as memref<8x8xf32>, we compute the size by multiplying the dimension sizes together. This is implemented by composing the affine map affine_map<()[s0, s1] -> (s0 * s1)> and then computing the size via %size = affine.apply #map()[%c8, %c8].

However, this is wrong for a non-packed memref such as memref<8x8xf32, strided<[1024, 1]>>. Since that multiplication map only considers the dimension sizes, it would still conclude that the size of the non-packed memref is 64.
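
For illustration, here is a hand-written MLIR sketch of this old computation (the function and map names are illustrative only, not taken from the pass output); note that it never consults the strides:

// Old size computation: multiply only the dimension sizes.
#size_map = affine_map<()[s0, s1] -> (s0 * s1)>

func.func @old_linearized_size() -> index {
  %c8 = arith.constant 8 : index
  // Yields 64 for memref<8x8xf32>, and incorrectly also 64 for
  // memref<8x8xf32, strided<[1024, 1]>>.
  %size = affine.apply #size_map()[%c8, %c8]
  return %size : index
}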

Solution

This PR fixes the linearized size computation to take strides into consideration: it computes the maximum of (dim size * dim stride) over all dimensions. The size is expressed via the affine map affine_map<()[stride0, size0, size1] -> (stride0 * size0, 1 * size1)> and then computed via %size = affine.max #map()[%stride0, %size0, %size1]. In particular, for the non-packed memref above, the size is derived as max(1024*8, 1*8) = 8192 (rather than the wrong value 64 produced by the packed-memref equation).
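
For comparison, a hand-written sketch of the new computation in the fully dynamic case (again with illustrative names); per the diff below, the utility only emits an affine.max like this when some stride or size is dynamic, folds the result to a constant index attribute when everything is static, and additionally floor-divides each product by the dstBits/srcBits scale factor:

// New size computation: max over dimensions of (stride * size).
#max_map = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>

func.func @new_linearized_size(%stride0: index, %size0: index, %size1: index) -> index {
  // For strides [1024, 1] and sizes [8, 8]: max(1024 * 8, 1 * 8) = 8192.
  %size = affine.max #max_map()[%stride0, %size0, %size1]
  return %size : index
}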

@llvmbot
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-memref

Author: Zhuoran Yin (jerryyin)

Changes

Credit to @krzysz00, who discovered this subtle bug in MemRefUtils. The problem is in the getLinearizedMemRefOffsetAndSize() utility: it computes the linearized size of a memref incorrectly when given a non-packed memref.

Background

As context, for a packed memref such as memref<8x8xf32>, we compute the size by multiplying the dimension sizes together. This is implemented by composing the affine map affine_map<()[s0, s1] -> (s0 * s1)> and then computing the size via %size = affine.apply #map()[%c8, %c8].

However, this is wrong for a non-packed memref such as memref<8x8xf32, strided<[1024, 1]>>. Since that multiplication map only considers the dimension sizes, it would still conclude that the size of the non-packed memref is 64.

Solution

This PR fixes the linearized size computation to take strides into consideration: it computes the maximum of (dim size * dim stride) over all dimensions. The size is expressed via the affine map affine_map<()[stride0, size0, size1] -> (stride0 * size0, 1 * size1)> and then computed via %size = affine.max #map()[%stride0, %size0, %size1]. In particular, for the non-packed memref above, the size is derived as max(1024*8, 1*8) = 8192 (rather than the wrong value 64 produced by the packed-memref equation).


Full diff: https://github.com/llvm/llvm-project/pull/138922.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (+57-6)
  • (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+8-8)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+8-8)
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index ac397b597fd14..ae92f1bffee75 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -66,7 +66,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   SmallVector<AffineExpr> symbols(2 * sourceRank);
   bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
   AffineExpr addMulMap = builder.getAffineConstantExpr(0);
-  AffineExpr mulMap = builder.getAffineConstantExpr(1);
 
   SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
 
@@ -75,18 +74,70 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
     offsetValues[offsetIdx] = indicesVec[i];
     offsetValues[offsetIdx + 1] = strides[i];
-
-    mulMap = mulMap * symbols[i];
   }
 
   // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
-  mulMap = mulMap.floorDiv(scaler);
+
+  // If all strides and sizes are constant, we can compute the result
+  // directly without creating the AffineMaxOp.
+  int64_t constResult = 0;
+  int64_t constStride = 0;
+  int64_t constSize = 0;
+  bool isAllConstant = true;
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    if (auto constantStride = getConstantIntValue(strides[i])) {
+      constStride = *constantStride;
+    } else {
+      isAllConstant = false;
+      break;
+    }
+    if (auto constantSize = getConstantIntValue(sizes[i])) {
+      constSize = *constantSize;
+    } else {
+      isAllConstant = false;
+      break;
+    }
+    constResult = std::max(constResult, constStride * constSize / scaler);
+  }
+
+  size_t symbolIndex = 0;
+  SmallVector<Value> values;
+  SmallVector<AffineExpr> productExpressions;
+  for (unsigned i = 0; i < sourceRank; ++i) {
+    AffineExpr strideExpr, sizeExpr;
+    OpFoldResult stride = strides[i];
+    OpFoldResult size = sizes[i];
+    if (auto constantStride = getConstantIntValue(stride)) {
+      strideExpr = builder.getAffineConstantExpr(*constantStride);
+    } else {
+      strideExpr = symbols[symbolIndex++];
+      values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
+    }
+
+    if (auto constantSize = getConstantIntValue(size)) {
+      sizeExpr = builder.getAffineConstantExpr(*constantSize);
+    } else {
+      sizeExpr = symbols[symbolIndex++];
+      values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
+    }
+
+    productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
+  }
+  AffineMap maxMap = AffineMap::get(
+      /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
+      builder.getContext());
+
+  OpFoldResult linearizedSize;
+  if (isAllConstant) {
+    linearizedSize = builder.getIndexAttr(constResult);
+  } else {
+    Value totalSize = builder.create<affine::AffineMaxOp>(loc, maxMap, values);
+    linearizedSize = totalSize;
+  }
 
   OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
       builder, loc, addMulMap.floorDiv(scaler), offsetValues);
-  OpFoldResult linearizedSize =
-      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
 
   // Adjust baseOffset by the scale factor (dstBits / srcBits).
   AffineExpr s0;
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 1d6cbfa343ba5..f6740fae3046e 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -104,7 +104,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
   %1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
   return %1 : i4
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
 //      CHECK: func @memref_load_i4_dynamic(
@@ -112,7 +112,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
 //      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -122,7 +122,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 //      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
 //      CHECK:   return %[[TRUNC]]
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
 //      CHECK32: func @memref_load_i4_dynamic(
@@ -130,7 +130,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
-//      CHECK32:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK32:   %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -399,7 +399,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
   memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
   return
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //  CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
 //      CHECK: func @memref_store_i4_dynamic(
@@ -408,7 +408,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
-//  CHECK-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK-DAG:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
 //  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
@@ -423,7 +423,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 //      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
 //      CHECK:   return
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //  CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
 //      CHECK32: func @memref_store_i4_dynamic(
@@ -432,7 +432,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: i4
-//  CHECK32-DAG:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//  CHECK32-DAG:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
 //  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 9e2d131f421b7..faadf48ad3984 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -58,27 +58,27 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
   %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
   return %1 : vector<8xi4>
 }
-//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //      CHECK: func.func @vector_load_i4_dynamic(
 // CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
-//      CHECK:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //      CHECK:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
 //      CHECK:   %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
 
-//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //      CHECK32: func.func @vector_load_i4_dynamic(
 // CHECK32-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: index
 // CHECK32-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: index
-//      CHECK32:   %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
 //      CHECK32:   %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 //      CHECK32:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
@@ -450,7 +450,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
     return
 }
 
-//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
 //  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
 //      CHECK: func @vector_store_i4_dynamic
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
@@ -458,13 +458,13 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
-//      CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
 //      CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
 //      CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
 //      CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
 //      CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi8>, vector<4xi8>
 
-//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
 //  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
 //      CHECK32: func @vector_store_i4_dynamic
 // CHECK32-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
@@ -472,7 +472,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 // CHECK32-SAME:   %[[ARG2:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:   %[[ARG3:[a-zA-Z0-9]+]]: index
 // CHECK32-SAME:   %[[ARG4:[a-zA-Z0-9]+]]: index
-//      CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
 //      CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
 //      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
 //      CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>

@llvmbot
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir

@jerryyin jerryyin requested review from krzysz00, hanhanW and Max191 May 7, 2025 17:46
Contributor

@krzysz00 krzysz00 left a comment

Potential suggestion you might want to look at?

Contributor

@krzysz00 krzysz00 left a comment

I'm happy with this as a correctness fix, though could we get a test with a column-major memref and a non-packed one?

That is, for example, memref<4x8xi4, strided<[8, 1]>> and memref<8x4xi4, strided<[1, 8]>>?

}

// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
int64_t scaler = dstBits / srcBits;
mulMap = mulMap.floorDiv(scaler);
size_t symbolIndex = 0;
Contributor

This is mostly a note to myself that this could carry a std::optional<ArrayRef<int64_t>> permutation in the future, so that memrefs with contiguous layouts could get simpler handling. But that doesn't block this PR.

@jerryyin
Member Author

jerryyin commented May 7, 2025

@krzysz00 I was attempting to construct a minimal test case like:

func.func @memref_alloca_load_i4_column(%arg0: index, %arg1: index) -> i4 {
  %0 = memref.alloc() : memref<4x8xi4, strided<[8, 1]>>
  %1 = memref.load %0[%arg0, %arg1] : memref<4x8xi4, strided<[8, 1]>>
  return %1 : i4
}

But this test case fails legalization with a complaint that the alloc is illegal:

** Failure : alloc-like operations should have been normalized
"(anonymous namespace)::ExtractStridedMetadataOpAllocFolder<mlir::memref::AllocOp>" result 0

After some debugging, it looks like the failure is triggered by the code below:

auto memRefType = cast<MemRefType>(allocLikeOp.getResult().getType());
if (!memRefType.getLayout().isIdentity())
  return rewriter.notifyMatchFailure(
      allocLikeOp, "alloc-like operations should have been normalized");

It seems like non-identity memrefs were never supported here (though I'm not sure I'm reading this right). If that is the case, then we can't construct a column-major or non-packed test case, simply because the op associated with it can't be legalized.

@krzysz00
Contributor

krzysz00 commented May 7, 2025

For those testcases, I'd maybe pass them in as function arguments instead of alloc()ing them?

@krzysz00
Contributor

krzysz00 commented May 7, 2025

(But if we can't add tests, then we can't add tests, and that's fine)

Contributor

@Max191 Max191 left a comment

LGTM! If you can add the non-packed test (with the memref as an argument to the function), then that would be very nice.

I realize you are mainly just refactoring, so my nit is optional if you have time. I'll let Krzysztof give the final approval.

@jerryyin jerryyin force-pushed the users/zyin/fix-memref-linearization-size-maxMap branch from 4942b1a to b82ca5f on May 8, 2025 14:39
@jerryyin jerryyin force-pushed the users/zyin/fix-memref-linearization-size-maxMap branch from b82ca5f to 08978ab on May 8, 2025 14:43
@jerryyin
Member Author

jerryyin commented May 8, 2025

Re: the column-major and non-packed test cases. I tried again and checked which func ops get marked illegal.

func.func @memref_alloca_load_i4_row(%mem: memref<4x8xi4, strided<[8, 1]>>, %arg0: index, %arg1: index) {
  return
}

This one is a legal row-major memref, but it doesn't add any additional test coverage.

func.func @memref_alloca_load_i4_column(%mem: memref<8x4xi4, strided<[1, 8]>>, %arg0: index, %arg1: index) {
  return
}

This one is illegal, but that is already tracked by the existing test:

// expected-error @+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
func.func @argument_non_contiguous(%arg0 : memref<8x8xi4, strided<[1, 8]>>) {
  return
}

So it looks to me that non-conventional memrefs aren't well supported, which seems to be a known limitation.

@jerryyin jerryyin changed the title [MLIR] Fixing the memref linearization size computation [MLIR] Fixing the memref linearization size computation for non-packed memref May 8, 2025
Contributor

@krzysz00 krzysz00 left a comment

Two simplifications that should also get rid of some of the new arith.constants

otherwise, approved

Contributor

@Max191 Max191 left a comment

Looks very clean now!

@jerryyin jerryyin merged commit 53e8ff1 into main May 8, 2025
9 of 10 checks passed
@jerryyin jerryyin deleted the users/zyin/fix-memref-linearization-size-maxMap branch May 8, 2025 17:14
lenary added a commit to lenary/llvm-project that referenced this pull request May 9, 2025
* main: (420 commits)
  [AArch64] Merge scaled and unscaled narrow zero stores (llvm#136705)
  [RISCV] One last migration to getInsertSubvector [nfc]
  [flang][OpenMP] Update `do concurrent` mapping pass to use `fir.do_concurrent` op (llvm#138489)
  [MLIR][LLVM] Fix llvm.mlir.global mismatching print and parser order (llvm#138986)
  [lld][NFC] Fix minor typo in docs (llvm#138898)
  [RISCV] Migrate getConstant indexed insert/extract subvector to new API (llvm#139111)
  GlobalISel: Translate minimumnum and maximumnum (llvm#139106)
  [MemProf] Simplify unittest save and restore of options (llvm#139117)
  [BOLT][AArch64] Patch functions targeted by optional relocs (llvm#138750)
  [Coverage] Support -fprofile-list for cold function coverage (llvm#136333)
  Remove unused forward decl (llvm#139108)
  [AMDGPU][NFC] Get rid of OPW constants. (llvm#139074)
  [CIR] Upstream extract op for VectorType (llvm#138413)
  [mlir][xegpu] Handle scalar uniform ops in SIMT distribution.  (llvm#138593)
  [GlobalISel][AMDGPU] Fix handling of v2i128 type for AND, OR, XOR (llvm#138574)
  AMDGPU][True16][CodeGen] FP_Round f64 to f16 in true16 (llvm#128911)
  Reland [Clang] Deprecate `__is_trivially_relocatable` (llvm#139061)
  [HLSL][NFC] Stricter Overload Tests (clamp,max,min,pow) (llvm#138993)
  [MLIR] Fixing the memref linearization size computation for non-packed memref (llvm#138922)
  [TableGen][NFC] Use early exit to simplify large block in emitAction. (llvm#138220)
  ...