@@ -162,60 +162,20 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
162
162
stridedMetadata.getConstifiedMixedStrides ();
163
163
SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes ();
164
164
OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset ();
165
+ memref::LinearizedMemRefInfo linearizedInfo;
165
166
OpFoldResult linearizedIndices;
166
- std::tie (std::ignore , linearizedIndices) =
167
+ std::tie (linearizedInfo , linearizedIndices) =
167
168
memref::getLinearizedMemRefOffsetAndSize (rewriter, loc, elementBitWidth,
168
169
elementBitWidth, offset, sizes,
169
170
strides, indices);
170
171
171
- // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
172
- // Note below doesn't give the correct result for the linearized size.
173
- // Value totalSize = getValueOrCreateConstantIndexOp(
174
- // rewriter, loc, linearizedInfo.linearizedSize);
175
- // It computes the multiplied sizes of all dimensions instead of taking
176
- // the maximum of each dimension size * stride.
177
- SmallVector<AffineExpr> productExpressions;
178
- unsigned sourceRank = cast<ShapedType>(src.getType ()).getRank ();
179
-
180
- SmallVector<AffineExpr> symbols (2 * sourceRank);
181
- SmallVector<Value> offsetValues;
182
- bindSymbolsList (rewriter.getContext (), MutableArrayRef{symbols});
183
-
184
- size_t symbolIndex = 0 ;
185
- for (size_t i = 0 ; i < sourceRank; ++i) {
186
- AffineExpr strideExpr, sizeExpr;
187
- OpFoldResult stride = strides[i];
188
- OpFoldResult size = sizes[i];
189
- if (auto constantStride = getConstantIntValue (stride)) {
190
- strideExpr = rewriter.getAffineConstantExpr (*constantStride);
191
- } else {
192
- strideExpr = symbols[symbolIndex++];
193
- offsetValues.push_back (
194
- getValueOrCreateConstantIndexOp (rewriter, loc, stride));
195
- }
196
-
197
- if (auto constantSize = getConstantIntValue (size)) {
198
- sizeExpr = rewriter.getAffineConstantExpr (*constantSize);
199
- } else {
200
- sizeExpr = symbols[symbolIndex++];
201
- offsetValues.push_back (
202
- getValueOrCreateConstantIndexOp (rewriter, loc, size));
203
- }
204
-
205
- productExpressions.push_back (strideExpr * sizeExpr);
206
- }
207
-
208
- AffineMap maxMap = AffineMap::get (
209
- /* dimCount=*/ 0 , /* symbolCount=*/ symbolIndex, productExpressions,
210
- rewriter.getContext ());
211
- Value totalSize =
212
- rewriter.create <affine::AffineMaxOp>(loc, maxMap, offsetValues);
213
-
214
172
// delta = bufferSize - linearizedOffset
215
173
Value vectorSizeOffset =
216
174
rewriter.create <arith::ConstantIndexOp>(loc, vectorSize);
217
175
Value linearIndex =
218
176
getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices);
177
+ Value totalSize = getValueOrCreateConstantIndexOp (
178
+ rewriter, loc, linearizedInfo.linearizedSize );
219
179
Value delta = rewriter.create <arith::SubIOp>(loc, totalSize, linearIndex);
220
180
221
181
// 1) check if delta < vectorSize
0 commit comments