[MLIR] Fixing the memref linearization size computation for non-packed memref #138922
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-memref
Author: Zhuoran Yin (jerryyin)
Changes: see the full PR description at the end of this conversation.
Full diff: https://github.com/llvm/llvm-project/pull/138922.diff
3 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index ac397b597fd14..ae92f1bffee75 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -66,7 +66,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
SmallVector<AffineExpr> symbols(2 * sourceRank);
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
AffineExpr addMulMap = builder.getAffineConstantExpr(0);
- AffineExpr mulMap = builder.getAffineConstantExpr(1);
SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
@@ -75,18 +74,70 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
offsetValues[offsetIdx] = indicesVec[i];
offsetValues[offsetIdx + 1] = strides[i];
-
- mulMap = mulMap * symbols[i];
}
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
int64_t scaler = dstBits / srcBits;
- mulMap = mulMap.floorDiv(scaler);
+
+ // If all strides and sizes are constant, we can compute the result
+ // directly without creating the AffineMaxOp.
+ int64_t constResult = 0;
+ int64_t constStride = 0;
+ int64_t constSize = 0;
+ bool isAllConstant = true;
+ for (unsigned i = 0; i < sourceRank; ++i) {
+ if (auto constantStride = getConstantIntValue(strides[i])) {
+ constStride = *constantStride;
+ } else {
+ isAllConstant = false;
+ break;
+ }
+ if (auto constantSize = getConstantIntValue(sizes[i])) {
+ constSize = *constantSize;
+ } else {
+ isAllConstant = false;
+ break;
+ }
+ constResult = std::max(constResult, constStride * constSize / scaler);
+ }
+
+ size_t symbolIndex = 0;
+ SmallVector<Value> values;
+ SmallVector<AffineExpr> productExpressions;
+ for (unsigned i = 0; i < sourceRank; ++i) {
+ AffineExpr strideExpr, sizeExpr;
+ OpFoldResult stride = strides[i];
+ OpFoldResult size = sizes[i];
+ if (auto constantStride = getConstantIntValue(stride)) {
+ strideExpr = builder.getAffineConstantExpr(*constantStride);
+ } else {
+ strideExpr = symbols[symbolIndex++];
+ values.push_back(getValueOrCreateConstantIndexOp(builder, loc, stride));
+ }
+
+ if (auto constantSize = getConstantIntValue(size)) {
+ sizeExpr = builder.getAffineConstantExpr(*constantSize);
+ } else {
+ sizeExpr = symbols[symbolIndex++];
+ values.push_back(getValueOrCreateConstantIndexOp(builder, loc, size));
+ }
+
+ productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
+ }
+ AffineMap maxMap = AffineMap::get(
+ /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
+ builder.getContext());
+
+ OpFoldResult linearizedSize;
+ if (isAllConstant) {
+ linearizedSize = builder.getIndexAttr(constResult);
+ } else {
+ Value totalSize = builder.create<affine::AffineMaxOp>(loc, maxMap, values);
+ linearizedSize = totalSize;
+ }
OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
builder, loc, addMulMap.floorDiv(scaler), offsetValues);
- OpFoldResult linearizedSize =
- affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
// Adjust baseOffset by the scale factor (dstBits / srcBits).
AffineExpr s0;
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 1d6cbfa343ba5..f6740fae3046e 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -104,7 +104,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
%1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
return %1 : i4
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
// CHECK: func @memref_load_i4_dynamic(
@@ -112,7 +112,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -122,7 +122,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
// CHECK: return %[[TRUNC]]
-// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
// CHECK32: func @memref_load_i4_dynamic(
@@ -130,7 +130,7 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
-// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -399,7 +399,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
return
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
// CHECK: func @memref_store_i4_dynamic(
@@ -408,7 +408,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
-// CHECK-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
@@ -423,7 +423,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
// CHECK: return
-// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
// CHECK32: func @memref_store_i4_dynamic(
@@ -432,7 +432,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
-// CHECK32-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 9e2d131f421b7..faadf48ad3984 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -58,27 +58,27 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
%1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
return %1 : vector<8xi4>
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
// CHECK: func.func @vector_load_i4_dynamic(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
-// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
-// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
// CHECK32: func.func @vector_load_i4_dynamic(
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
-// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]], %[[ARG1]]]
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
@@ -450,7 +450,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
return
}
-// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 2, s2 floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
// CHECK: func @vector_store_i4_dynamic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
@@ -458,13 +458,13 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
-// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
// CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi8>, vector<4xi8>
-// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) floordiv 8, s2 floordiv 8)>
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
// CHECK32: func @vector_store_i4_dynamic
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
@@ -472,7 +472,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
-// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]], %[[ARG2]]]
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
// CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
Potential suggestion you might want to look at?
I'm happy with this as a correctness fix, though could we get a test with a column-major memref and a non-packed one? That is, for example, memref<4x8xi4, strided<[8, 1]>> and memref<8x4xi4, strided<[1, 8]>>?
This is mostly a note to myself that this could carry a std::optional<ArrayRef<int64_t>> permutation in the future such that memrefs with contiguous layouts could get simpler handling. But that doesn't block this PR.
@krzysz00 I was attempting to construct a minimal test case like:

func.func @memref_alloca_load_i4_column(%arg0: index, %arg1: index) -> i4 {
  %0 = memref.alloc() : memref<4x8xi4, strided<[8, 1]>>
  %1 = memref.load %0[%arg0, %arg1] : memref<4x8xi4, strided<[8, 1]>>
  return %1 : i4
}

However, this test case fails legalization, complaining that the memref.alloc is illegal. After some debugging, it looks like the failure is triggered from llvm-project/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp, lines 786 to 789 (at 9da103a).
It seems like non-identity memref layouts were never supported here before. I'm not sure I'm reading this right, but if that is the case, we can't construct a column-major or non-packed test case simply because the op associated with it can't be legalized.
For those test cases, I'd maybe pass the memrefs in as function arguments instead.
(But if we can't add tests, then we can't add tests, and that's fine.)
LGTM! If you can add the non-packed test (with the memref as an argument to the function), then that would be very nice.
I realize you are mainly just refactoring, so my nit is optional if you have time. I'll let Krzysztof give the final approval.
Re: column-major and non-packed test cases. I tried again and found these marked as illegal func ops.

func.func @memref_alloca_load_i4_row(%mem: memref<4x8xi4, strided<[8, 1]>>, %arg0: index, %arg1: index) {
  return
}

This one is a legal row-major memref, but it doesn't add any additional test coverage.

func.func @memref_alloca_load_i4_column(%mem: memref<8x4xi4, strided<[1, 8]>>, %arg0: index, %arg1: index) {
  return
}

This one is illegal, but that is already tracked per llvm-project/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir, lines 553 to 556 (at bbafa52). So it looks to me that non-conventional memrefs not being well supported is a known limitation.
Two simplifications that should also get rid of some of the new arith.constants; otherwise, approved.
Looks very clean now!
* main: (420 commits) [AArch64] Merge scaled and unscaled narrow zero stores (llvm#136705) [RISCV] One last migration to getInsertSubvector [nfc] [flang][OpenMP] Update `do concurrent` mapping pass to use `fir.do_concurrent` op (llvm#138489) [MLIR][LLVM] Fix llvm.mlir.global mismatching print and parser order (llvm#138986) [lld][NFC] Fix minor typo in docs (llvm#138898) [RISCV] Migrate getConstant indexed insert/extract subvector to new API (llvm#139111) GlobalISel: Translate minimumnum and maximumnum (llvm#139106) [MemProf] Simplify unittest save and restore of options (llvm#139117) [BOLT][AArch64] Patch functions targeted by optional relocs (llvm#138750) [Coverage] Support -fprofile-list for cold function coverage (llvm#136333) Remove unused forward decl (llvm#139108) [AMDGPU][NFC] Get rid of OPW constants. (llvm#139074) [CIR] Upstream extract op for VectorType (llvm#138413) [mlir][xegpu] Handle scalar uniform ops in SIMT distribution. (llvm#138593) [GlobalISel][AMDGPU] Fix handling of v2i128 type for AND, OR, XOR (llvm#138574) AMDGPU][True16][CodeGen] FP_Round f64 to f16 in true16 (llvm#128911) Reland [Clang] Deprecate `__is_trivially_relocatable` (llvm#139061) [HLSL][NFC] Stricter Overload Tests (clamp,max,min,pow) (llvm#138993) [MLIR] Fixing the memref linearization size computation for non-packed memref (llvm#138922) [TableGen][NFC] Use early exit to simplify large block in emitAction. (llvm#138220) ...
Credit to @krzysz00 who discovered this subtle bug in MemRefUtils. The problem is in the getLinearizedMemRefOffsetAndSize() utility: the way this subroutine computes the linearized size of a memref is incorrect when given a non-packed memref.

Background
As context, for a packed memref of memref<8x8xf32>, we'd compute the size by multiplying the sizes of the dimensions together. This is implemented by composing an affine_map of affine_map<()[s0, s1] -> (s0 * s1)> and then computing the size via %size = affine.apply #map()[%c8, %c8].
However, this is wrong for a non-packed memref of memref<8x8xf32, strided<[1024, 1]>>. Since the multiplication map only considers the dimension sizes, it would still conclude that the size of the non-packed memref is 64.

Solution
This PR comes up with a fix such that the linearized size computation takes strides into consideration: it computes the maximum of (dim size * dim stride) over the dimensions. We'd compute the size via an affine_map of affine_map<()[stride0, size0, size1] -> (stride0 * size0, 1 * size1)> and then compute the size via %size = affine.max #map()[%stride0, %size0, %size1]. In particular, for the new non-packed memref, the size is derived as max(1024 * 8, 1 * 8) = 8192 (rather than the wrong size of 64 computed by the packed-memref equation).
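To make the size rule concrete, here is a minimal standalone C++ sketch of the stride-aware computation described above (an illustration mirroring the constant-folding path of the patch, not the MLIR builder code itself): the linearized element count is the maximum over all dimensions of size * stride, with each term scaled down by dstBits / srcBits. The shapes and the scaler value of 1 are just the f32 examples from this description.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Maximum over all dimensions of size[d] * stride[d], scaled per term by
// scaler = dstBits / srcBits (1 here since no narrow-type packing is applied).
int64_t linearizedSize(const std::vector<int64_t> &sizes,
                       const std::vector<int64_t> &strides, int64_t scaler) {
  int64_t result = 0;
  for (size_t d = 0; d < sizes.size(); ++d)
    result = std::max(result, sizes[d] * strides[d] / scaler);
  return result;
}

int main() {
  // Packed memref<8x8xf32> (strides [8, 1]): the old product rule and the new
  // max rule agree: max(8 * 8, 8 * 1) = 64.
  std::cout << linearizedSize({8, 8}, {8, 1}, /*scaler=*/1) << "\n";    // 64
  // Non-packed memref<8x8xf32, strided<[1024, 1]>>: the old product rule still
  // gives 64, while the stride-aware rule gives max(8 * 1024, 8 * 1) = 8192.
  std::cout << linearizedSize({8, 8}, {1024, 1}, /*scaler=*/1) << "\n"; // 8192
  return 0;
}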