Skip to content

Commit cde70c4

Browse files
committed
Amend AMDGPU transfer-read to use the new linearized size
1 parent 2e39f5d commit cde70c4

File tree

2 files changed

+9
-49
lines changed

2 files changed

+9
-49
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 4 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -162,60 +162,20 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
162162
stridedMetadata.getConstifiedMixedStrides();
163163
SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
164164
OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
165+
memref::LinearizedMemRefInfo linearizedInfo;
165166
OpFoldResult linearizedIndices;
166-
std::tie(std::ignore, linearizedIndices) =
167+
std::tie(linearizedInfo, linearizedIndices) =
167168
memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
168169
elementBitWidth, offset, sizes,
169170
strides, indices);
170171

171-
// TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
172-
// Note below doesn't give the correct result for the linearized size.
173-
// Value totalSize = getValueOrCreateConstantIndexOp(
174-
// rewriter, loc, linearizedInfo.linearizedSize);
175-
// It computes the multiplied sizes of all dimensions instead of taking
176-
// the maximum of each dimension size * stride.
177-
SmallVector<AffineExpr> productExpressions;
178-
unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
179-
180-
SmallVector<AffineExpr> symbols(2 * sourceRank);
181-
SmallVector<Value> offsetValues;
182-
bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
183-
184-
size_t symbolIndex = 0;
185-
for (size_t i = 0; i < sourceRank; ++i) {
186-
AffineExpr strideExpr, sizeExpr;
187-
OpFoldResult stride = strides[i];
188-
OpFoldResult size = sizes[i];
189-
if (auto constantStride = getConstantIntValue(stride)) {
190-
strideExpr = rewriter.getAffineConstantExpr(*constantStride);
191-
} else {
192-
strideExpr = symbols[symbolIndex++];
193-
offsetValues.push_back(
194-
getValueOrCreateConstantIndexOp(rewriter, loc, stride));
195-
}
196-
197-
if (auto constantSize = getConstantIntValue(size)) {
198-
sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
199-
} else {
200-
sizeExpr = symbols[symbolIndex++];
201-
offsetValues.push_back(
202-
getValueOrCreateConstantIndexOp(rewriter, loc, size));
203-
}
204-
205-
productExpressions.push_back(strideExpr * sizeExpr);
206-
}
207-
208-
AffineMap maxMap = AffineMap::get(
209-
/*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
210-
rewriter.getContext());
211-
Value totalSize =
212-
rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
213-
214172
// delta = bufferSize - linearizedOffset
215173
Value vectorSizeOffset =
216174
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
217175
Value linearIndex =
218176
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
177+
Value totalSize = getValueOrCreateConstantIndexOp(
178+
rewriter, loc, linearizedInfo.linearizedSize);
219179
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
220180

221181
// 1) check if delta < vectorSize

mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,9 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
5252

5353
// -----
5454

55-
// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
56-
// CHECK: #map1 = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
57-
// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
55+
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
56+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
57+
// CHECK: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
5858
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
5959
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
6060
// CHECK-SAME: %[[ARG3:.*]]: vector<4xi1>
@@ -68,8 +68,8 @@ func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8,
6868
// CHECK: %[[C0:.*]] = arith.constant 0 : index
6969
// CHECK: %[[C4:.*]] = arith.constant 4 : index
7070
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
71-
// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
72-
// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
71+
// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
72+
// CHECK: %[[LINEAR:.*]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
7373
// CHECK: %[[IF:.*]] = scf.if
7474
// CHECK: return
7575

0 commit comments

Comments
 (0)