|
13 | 13 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
14 | 14 | #include "mlir/Dialect/Vector/IR/VectorOps.h"
|
15 | 15 | #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
| 16 | +#include "mlir/IR/Attributes.h" |
| 17 | +#include "mlir/IR/BuiltinAttributes.h" |
| 18 | +#include "mlir/IR/Operation.h" |
16 | 19 | #include "mlir/IR/PatternMatch.h"
|
17 | 20 | #include "mlir/IR/TypeUtilities.h"
|
| 21 | +#include "mlir/Support/LogicalResult.h" |
18 | 22 | #include "mlir/Transforms/DialectConversion.h"
|
| 23 | +#include "llvm/ADT/ArrayRef.h" |
| 24 | +#include <cstdint> |
| 25 | +#include <numeric> |
19 | 26 |
|
20 | 27 | using namespace mlir;
|
21 | 28 |
|
@@ -103,6 +110,251 @@ struct LinearizeVectorizable final
|
103 | 110 | return success();
|
104 | 111 | }
|
105 | 112 |
|
| 113 | +private: |
| 114 | + unsigned targetVectorBitWidth; |
| 115 | +}; |
| 116 | + |
| 117 | +/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works |
| 118 | +/// on a linearized vector. |
| 119 | +/// For example, the following: |
| 120 | +/// vector.extract_strided_slice %source |
| 121 | +/// { offsets = [..], strides = [..], sizes = [..] } |
| 122 | +/// is converted to: |
| 123 | +/// %source_1d = vector.shape_cast %source |
| 124 | +/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] |
| 125 | +/// %out_nd = vector.shape_cast %out_1d |
| 126 | +/// `shuffle_indices_1d` is computed using the offsets and sizes of the |
| 127 | +/// extraction. |
| 128 | +struct LinearizeVectorExtractStridedSlice final |
| 129 | + : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { |
| 130 | + using OpConversionPattern::OpConversionPattern; |
| 131 | + LinearizeVectorExtractStridedSlice( |
| 132 | + const TypeConverter &typeConverter, MLIRContext *context, |
| 133 | + unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), |
| 134 | + PatternBenefit benefit = 1) |
| 135 | + : OpConversionPattern(typeConverter, context, benefit), |
| 136 | + targetVectorBitWidth(targetVectBitWidth) {} |
| 137 | + |
| 138 | + LogicalResult |
| 139 | + matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, |
| 140 | + ConversionPatternRewriter &rewriter) const override { |
| 141 | + Type dstType = getTypeConverter()->convertType(extractOp.getType()); |
| 142 | + assert(!(extractOp.getVector().getType().isScalable() || |
| 143 | + dstType.cast<VectorType>().isScalable()) && |
| 144 | + "scalable vectors are not supported."); |
| 145 | + if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) |
| 146 | + return rewriter.notifyMatchFailure( |
| 147 | + extractOp, "Can't flatten since targetBitWidth <= OpSize"); |
| 148 | + |
| 149 | + ArrayAttr offsets = extractOp.getOffsets(); |
| 150 | + ArrayAttr sizes = extractOp.getSizes(); |
| 151 | + ArrayAttr strides = extractOp.getStrides(); |
| 152 | + for (Attribute stride : strides) |
| 153 | + if (!isConstantIntValue(stride, 1)) |
| 154 | + return rewriter.notifyMatchFailure( |
| 155 | + extractOp, "Strided slice with stride != 1 is not supported."); |
| 155 | + Value srcVector = adaptor.getVector(); |
| 156 | + // If kD offsets are specified for an nD source vector (n > k), the |
| 157 | + // granularity of the extraction is greater than 1. In this case, the |
| 158 | + // last (n - k) dimensions form the extraction granularity. |
| 159 | + // Example : |
| 160 | + // vector.extract_strided_slice %src { |
| 161 | + // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : |
| 162 | + // vector<4x8x8xf32> to vector<2x2x8xf32> |
| 163 | + // Here, extraction granularity is 8. |
| 164 | + int64_t extractGranularitySize = 1; |
| 165 | + int64_t nD = extractOp.getSourceVectorType().getRank(); |
| 166 | + int64_t kD = (int64_t)offsets.size(); |
| 167 | + for (int64_t k = kD; k < nD; ++k) |
| 168 | + extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k]; |
| 172 | + // Get total number of extracted slices. |
| 173 | + int64_t nExtractedSlices = 1; |
| 174 | + for (Attribute size : sizes) { |
| 175 | + nExtractedSlices *= size.cast<IntegerAttr>().getInt(); |
| 176 | + } |
| 177 | + // Compute the strides of the source vector considering first k dimensions. |
| 178 | + llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize); |
| 179 | + for (int i = kD - 2; i >= 0; --i) { |
| 180 | + sourceStrides[i] = sourceStrides[i + 1] * |
| 181 | + extractOp.getSourceVectorType().getShape()[i + 1]; |
| 182 | + } |
| 183 | + // The final shuffle indices have nExtractedSlices * |
| 184 | + // extractGranularitySize elements. |
| 185 | + llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * |
| 186 | + extractGranularitySize); |
| 187 | + // Compute the strides of the extracted kD vector. |
| 188 | + llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1); |
| 190 | + for (int i = kD - 2; i >= 0; --i) { |
| 191 | + extractedStrides[i] = |
| 192 | + extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt(); |
| 193 | + } |
| 194 | + // Iterate over all extracted slices from 0 to nExtractedSlices - 1 |
| 195 | + // and compute the multi-dimensional index and the corresponding linearized |
| 196 | + // index within the source vector. |
| 197 | + for (int64_t i = 0; i < nExtractedSlices; ++i) { |
| 198 | + int64_t index = i; |
| 199 | + // Compute the corresponding multi-dimensional index. |
| 200 | + llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0); |
| 201 | + for (int64_t j = 0; j < kD; ++j) { |
| 202 | + multiDimIndex[j] = (index / extractedStrides[j]); |
| 203 | + index -= multiDimIndex[j] * extractedStrides[j]; |
| 204 | + } |
| 205 | + // Compute the corresponding linearized index in the source vector, |
| 206 | + // i.e., shift the multiDimIndex by the offsets. |
| 207 | + int64_t linearizedIndex = 0; |
| 208 | + for (int64_t j = 0; j < kD; ++j) { |
| 209 | + linearizedIndex += |
| 210 | + (offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) * |
| 211 | + sourceStrides[j]; |
| 212 | + } |
| 213 | + // Fill the indices array from linearizedIndex to linearizedIndex + |
| 214 | + // extractGranularitySize. |
| 215 | + for (int64_t j = 0; j < extractGranularitySize; ++j) { |
| 216 | + indices[i * extractGranularitySize + j] = linearizedIndex + j; |
| 217 | + } |
| 218 | + } |
| 219 | + // Perform a shuffle to extract the kD vector. |
| 220 | + rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
| 221 | + extractOp, dstType, srcVector, srcVector, |
| 222 | + rewriter.getI64ArrayAttr(indices)); |
| 223 | + return success(); |
| 224 | + } |
| 225 | + |
| 226 | +private: |
| 227 | + unsigned targetVectorBitWidth; |
| 228 | +}; |
| 229 | + |
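
For reference, the index arithmetic in this pattern can be reproduced standalone. The sketch below is a minimal restatement using plain STL types in place of MLIR attributes; the helper name `computeStridedSliceShuffleIndices` is hypothetical and not part of this patch. It works through the vector<4x8x8xf32> example from the comment above.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Computes the 1-D shuffle indices for a strided-slice extraction with unit
// strides. `srcShape` is the n-D source shape; `offsets`/`sizes` cover the
// leading k dims, and the trailing (n - k) dims form the granularity.
std::vector<int64_t>
computeStridedSliceShuffleIndices(const std::vector<int64_t> &srcShape,
                                  const std::vector<int64_t> &offsets,
                                  const std::vector<int64_t> &sizes) {
  int64_t nD = srcShape.size(), kD = offsets.size();
  int64_t granularity = 1;
  for (int64_t k = kD; k < nD; ++k)
    granularity *= srcShape[k];
  int64_t nSlices = 1;
  for (int64_t s : sizes)
    nSlices *= s;
  // Row-major strides over the leading k dims of the source and the slice.
  std::vector<int64_t> srcStrides(kD, granularity), extStrides(kD, 1);
  for (int64_t i = kD - 2; i >= 0; --i) {
    srcStrides[i] = srcStrides[i + 1] * srcShape[i + 1];
    extStrides[i] = extStrides[i + 1] * sizes[i + 1];
  }
  std::vector<int64_t> indices(nSlices * granularity);
  for (int64_t i = 0; i < nSlices; ++i) {
    // Delinearize i into the slice shape, shift by the offsets, and
    // relinearize into the source vector.
    int64_t rem = i, base = 0;
    for (int64_t j = 0; j < kD; ++j) {
      base += (offsets[j] + rem / extStrides[j]) * srcStrides[j];
      rem %= extStrides[j];
    }
    for (int64_t j = 0; j < granularity; ++j)
      indices[i * granularity + j] = base + j;
  }
  return indices;
}

int main() {
  // The example from the comment: vector<4x8x8xf32> -> vector<2x2x8xf32>,
  // offsets [0, 0], sizes [2, 2]. The four slices start at 0, 8, 64, and 72.
  auto idx = computeStridedSliceShuffleIndices({4, 8, 8}, {0, 0}, {2, 2});
  assert(idx.size() == 32);
  assert(idx[0] == 0 && idx[8] == 8 && idx[16] == 64 && idx[24] == 72);
}
```
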
| 230 | +/// This pattern converts the ShuffleOp that works on nD (n > 1) |
| 231 | +/// vectors to a ShuffleOp that works on linearized vectors. |
| 232 | +/// For example, the following: |
| 233 | +/// vector.shuffle %v1, %v2 [ shuffle_indices ] |
| 234 | +/// is converted to: |
| 235 | +/// %v1_1d = vector.shape_cast %v1 |
| 236 | +/// %v2_1d = vector.shape_cast %v2 |
| 237 | +/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ] |
| 238 | +/// %out_nd = vector.shape_cast %out_1d |
| 239 | +/// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices` |
| 240 | +/// of the original shuffle operation. |
| 241 | +struct LinearizeVectorShuffle final |
| 242 | + : public OpConversionPattern<vector::ShuffleOp> { |
| 243 | + using OpConversionPattern::OpConversionPattern; |
| 244 | + LinearizeVectorShuffle( |
| 245 | + const TypeConverter &typeConverter, MLIRContext *context, |
| 246 | + unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), |
| 247 | + PatternBenefit benefit = 1) |
| 248 | + : OpConversionPattern(typeConverter, context, benefit), |
| 249 | + targetVectorBitWidth(targetVectBitWidth) {} |
| 250 | + |
| 251 | + LogicalResult |
| 252 | + matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, |
| 253 | + ConversionPatternRewriter &rewriter) const override { |
| 254 | + Type dstType = getTypeConverter()->convertType(shuffleOp.getType()); |
| 255 | + assert(!(shuffleOp.getV1VectorType().isScalable() || |
| 256 | + shuffleOp.getV2VectorType().isScalable() || |
| 257 | + dstType.cast<VectorType>().isScalable()) && |
| 258 | + "scalable vectors are not supported."); |
| 259 | + if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) |
| 260 | + return rewriter.notifyMatchFailure( |
| 261 | + shuffleOp, "Can't flatten since targetBitWidth <= OpSize"); |
| 262 | + |
| 263 | + Value vec1 = adaptor.getV1(); |
| 264 | + Value vec2 = adaptor.getV2(); |
| 265 | + int shuffleSliceLen = 1; |
| 266 | + int rank = shuffleOp.getV1().getType().getRank(); |
| 267 | + |
| 268 | + // If rank > 1, the shuffle must operate at the granularity of slices |
| 269 | + // rather than scalars, where each slice spans the (rank - 1) innermost |
| 270 | + // dimensions. The mask of the shuffle op then selects which slice to |
| 271 | + // take along the outermost dimension. |
| 272 | + if (rank > 1) { |
| 273 | + llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape(); |
| 274 | + for (unsigned i = 1; i < shape.size(); ++i) { |
| 275 | + shuffleSliceLen *= shape[i]; |
| 276 | + } |
| 277 | + } |
| 278 | + |
| 279 | + // For each value in the mask, generate the indices of the source vector |
| 280 | + // elements that need to be shuffled to the destination vector. If |
| 281 | + // shuffleSliceLen > 1, shuffle whole slices (runs of shuffleSliceLen |
| 282 | + // consecutive elements) instead of scalars. |
| 283 | + ArrayAttr mask = shuffleOp.getMask(); |
| 284 | + int64_t numShuffledElements = mask.size() * shuffleSliceLen; |
| 285 | + llvm::SmallVector<int64_t, 2> indices(numShuffledElements); |
| 286 | + for (auto [i, value] : |
| 287 | + llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) { |
| 288 | + |
| 289 | + int64_t v = value.getZExtValue(); |
| 290 | + std::iota(indices.begin() + shuffleSliceLen * i, |
| 291 | + indices.begin() + shuffleSliceLen * (i + 1), |
| 292 | + shuffleSliceLen * v); |
| 293 | + } |
| 294 | + |
| 295 | + rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
| 296 | + shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices)); |
| 297 | + return success(); |
| 298 | + } |
| 299 | + |
| 300 | +private: |
| 301 | + unsigned targetVectorBitWidth; |
| 302 | +}; |
| 303 | + |
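
The mask expansion is easiest to see in isolation. Below is a minimal sketch with plain STL types; the helper name `expandShuffleMask` is hypothetical and not part of this patch.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Expands an n-D shuffle mask into 1-D element indices: each mask entry
// selects one outermost-dim slice, and std::iota fills in the sliceLen
// consecutive element indices of that slice.
std::vector<int64_t> expandShuffleMask(const std::vector<int64_t> &mask,
                                       int64_t sliceLen) {
  std::vector<int64_t> indices(mask.size() * sliceLen);
  for (size_t i = 0; i < mask.size(); ++i)
    std::iota(indices.begin() + sliceLen * i,
              indices.begin() + sliceLen * (i + 1), sliceLen * mask[i]);
  return indices;
}

int main() {
  // Shuffling two vector<2x4xf32> values: sliceLen = 4; %v1 owns outer
  // indices 0-1 and %v2 owns 2-3, so mask entry 2 is the first slice of %v2.
  auto idx = expandShuffleMask({1, 2}, 4);
  assert((idx == std::vector<int64_t>{4, 5, 6, 7, 8, 9, 10, 11}));
}
```
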
| 304 | +/// This pattern converts the ExtractOp to a ShuffleOp that works on a |
| 305 | +/// linearized vector. |
| 306 | +/// For example, the following: |
| 307 | +/// vector.extract %source [ position ] |
| 308 | +/// is converted to: |
| 309 | +/// %source_1d = vector.shape_cast %source |
| 310 | +/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] |
| 311 | +/// %out_nd = vector.shape_cast %out_1d |
| 312 | +/// `shuffle_indices_1d` is computed using the position of the original extract. |
| 313 | +struct LinearizeVectorExtract final |
| 314 | + : public OpConversionPattern<vector::ExtractOp> { |
| 315 | + using OpConversionPattern::OpConversionPattern; |
| 316 | + LinearizeVectorExtract( |
| 317 | + const TypeConverter &typeConverter, MLIRContext *context, |
| 318 | + unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(), |
| 319 | + PatternBenefit benefit = 1) |
| 320 | + : OpConversionPattern(typeConverter, context, benefit), |
| 321 | + targetVectorBitWidth(targetVectBitWidth) {} |
| 322 | + LogicalResult |
| 323 | + matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, |
| 324 | + ConversionPatternRewriter &rewriter) const override { |
| 325 | + Type dstTy = getTypeConverter()->convertType(extractOp.getType()); |
| 326 | + assert(!(extractOp.getVector().getType().isScalable() || |
| 327 | + dstTy.cast<VectorType>().isScalable()) && |
| 328 | + "scalable vectors are not supported."); |
| 329 | + if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) |
| 330 | + return rewriter.notifyMatchFailure( |
| 331 | + extractOp, "Can't flatten since targetBitWidth <= OpSize"); |
| 332 | + |
| 333 | + // Dynamic position is not supported. |
| 334 | + if (extractOp.hasDynamicPosition()) |
| 335 | + return rewriter.notifyMatchFailure(extractOp, |
| 336 | + "dynamic position is not supported."); |
| 337 | + |
| 338 | + llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape(); |
| 339 | + int64_t size = extractOp.getVector().getType().getNumElements(); |
| 340 | + |
| 341 | + // Compute linearized offset. |
| 342 | + int64_t linearizedOffset = 0; |
| 343 | + llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition(); |
| 344 | + for (auto [i, off] : llvm::enumerate(offsets)) { |
| 345 | + size /= shape[i]; |
| 346 | + linearizedOffset += off * size; |
| 347 | + } |
| 348 | + |
| 349 | + llvm::SmallVector<int64_t, 2> indices(size); |
| 350 | + std::iota(indices.begin(), indices.end(), linearizedOffset); |
| 351 | + rewriter.replaceOpWithNewOp<vector::ShuffleOp>( |
| 352 | + extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), |
| 353 | + rewriter.getI64ArrayAttr(indices)); |
| 354 | + |
| 355 | + return success(); |
| 356 | + } |
| 357 | + |
106 | 358 | private:
|
107 | 359 | unsigned targetVectorBitWidth;
|
108 | 360 | };
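
The offset linearization in this pattern is the usual row-major delinearization trick; the sketch below restates it with plain STL types. The helper name `linearizeExtractPosition` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Returns {linearized start offset, slice element count}. `size` shrinks to
// the number of elements covered by one step of the current position dim, so
// after the loop it equals the extracted slice's element count and `offset`
// is where that slice begins in the flattened source.
std::pair<int64_t, int64_t>
linearizeExtractPosition(const std::vector<int64_t> &shape,
                         const std::vector<int64_t> &position) {
  int64_t size = 1;
  for (int64_t dim : shape)
    size *= dim;
  int64_t offset = 0;
  for (size_t i = 0; i < position.size(); ++i) {
    size /= shape[i];
    offset += position[i] * size;
  }
  return {offset, size};
}

int main() {
  // vector.extract %src[1, 2] : vector<2x8x2xf32> -> vector<2xf32> starts at
  // flattened element 1*16 + 2*2 = 20, so the shuffle indices are {20, 21}.
  auto [offset, size] = linearizeExtractPosition({2, 8, 2}, {1, 2});
  assert(offset == 20 && size == 2);
}
```
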
|
@@ -145,3 +397,21 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
|
145 | 397 | patterns.add<LinearizeConstant, LinearizeVectorizable>(
|
146 | 398 | typeConverter, patterns.getContext(), targetBitWidth);
|
147 | 399 | }
|
| 400 | + |
| 401 | +void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( |
| 402 | + TypeConverter &typeConverter, RewritePatternSet &patterns, |
| 403 | + ConversionTarget &target, unsigned int targetBitWidth) { |
| 404 | + target.addDynamicallyLegalOp<vector::ShuffleOp>( |
| 405 | + [=](vector::ShuffleOp shuffleOp) -> bool { |
| 406 | + return isLessThanTargetBitWidth(shuffleOp, targetBitWidth) |
| 407 | + ? (typeConverter.isLegal(shuffleOp) && |
| 408 | + shuffleOp.getResult() |
| 409 | + .getType() |
| 410 | + .cast<mlir::VectorType>() |
| 411 | + .getRank() == 1) |
| 412 | + : true; |
| 413 | + }); |
| 414 | + patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract, |
| 415 | + LinearizeVectorExtractStridedSlice>( |
| 416 | + typeConverter, patterns.getContext(), targetBitWidth); |
| 417 | +} |
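
To see how the two populate functions fit together, here is a hedged sketch of a driver, modeled loosely on the in-tree test pass; the function name `linearizeVectorOps` and its boilerplate are illustrative, not part of this patch. It assumes this file's existing includes (DialectConversion.h provides applyPartialConversion).

```cpp
// A minimal sketch of driving both pattern sets from a pass.
static LogicalResult linearizeVectorOps(Operation *op,
                                        unsigned targetBitWidth) {
  MLIRContext *ctx = op->getContext();
  TypeConverter typeConverter;
  RewritePatternSet patterns(ctx);
  ConversionTarget target(*ctx);
  // Registers the type conversion, the Constant/Vectorizable patterns, and
  // the matching legality rules.
  vector::populateVectorLinearizeTypeConversionsAndLegality(
      typeConverter, patterns, target, targetBitWidth);
  // Adds the shuffle/extract patterns introduced by this patch.
  vector::populateVectorLinearizeShuffleLikeOpsPatterns(
      typeConverter, patterns, target, targetBitWidth);
  return applyPartialConversion(op, target, std::move(patterns));
}
```
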