[mlir][vector] Add support for linearizing Insert VectorOp in VectorLinearize #92370
Conversation
@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes: Building on top of #88204, this PR adds support for converting `vector.insert` into an equivalent `vector.shuffle` operation that operates on linearized (1-D) vectors.

Full diff: https://github.com/llvm/llvm-project/pull/92370.diff — 2 files affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 802a64b0805ee..55d2903d8427d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -44,6 +44,22 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
return true;
}
+static bool isLessThanOrEqualTargetBitWidth(mlir::Type t,
+ unsigned targetBitWidth) {
+ VectorType vecType = dyn_cast<VectorType>(t);
+ // Reject index since getElementTypeBitWidth will abort for Index types.
+ if (!vecType || vecType.getElementType().isIndex())
+ return false;
+ // There are no dimensions to fold if it is a 0-D vector.
+ if (vecType.getRank() == 0)
+ return false;
+ unsigned trailingVecDimBitWidth =
+ vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ if (trailingVecDimBitWidth > targetBitWidth)
+ return false;
+ return true;
+}
+
namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
@@ -355,6 +371,88 @@ struct LinearizeVectorExtract final
return success();
}
+private:
+ unsigned targetVectorBitWidth;
+};
+
+/// This pattern converts the InsertOp to a ShuffleOp that works on a
+/// linearized vector.
+/// Following,
+///   vector.insert %source %destination [ position ]
+/// is converted to:
+///   %source_1d = vector.shape_cast %source
+///   %destination_1d = vector.shape_cast %destination
+///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d ]
+///   %out_nd = vector.shape_cast %out_1d
+/// `shuffle_indices_1d` is computed using the position of the original insert.
+struct LinearizeVectorInsert final
+ : public mlir::OpConversionPattern<mlir::vector::InsertOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorInsert(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+ mlir::LogicalResult
+ matchAndRewrite(mlir::vector::InsertOp insertOp, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const override {
+ Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
+ assert(!(insertOp.getDestVectorType().isScalable() ||
+ cast<VectorType>(dstTy).isScalable()) &&
+ "scalable vectors are not supported.");
+
+ if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
+ targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ insertOp, "Can't flatten since targetBitWidth < OpSize");
+
+ // dynamic position is not supported
+ if (insertOp.hasDynamicPosition())
+ return rewriter.notifyMatchFailure(insertOp,
+ "dynamic position is not supported.");
+ auto srcTy = insertOp.getSourceType();
+ auto srcAsVec = mlir::dyn_cast<mlir::VectorType>(srcTy);
+ uint64_t srcSize = 0;
+ if (srcAsVec) {
+ srcSize = srcAsVec.getNumElements();
+ } else {
+ return rewriter.notifyMatchFailure(insertOp,
+ "scalars are not supported.");
+ }
+
+ auto dstShape = insertOp.getDestVectorType().getShape();
+ const auto dstSize = insertOp.getDestVectorType().getNumElements();
+ auto dstSizeForOffsets = dstSize;
+
+ // compute linearized offset
+ int64_t linearizedOffset = 0;
+ auto offsetsNd = insertOp.getStaticPosition();
+ for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
+ dstSizeForOffsets /= dstShape[dim];
+ linearizedOffset += offset * dstSizeForOffsets;
+ }
+
+ llvm::SmallVector<int64_t, 2> indices(dstSize);
+ auto origValsUntil = indices.begin();
+ std::advance(origValsUntil, linearizedOffset);
+ std::iota(indices.begin(), origValsUntil,
+ 0); // original values that remain [0, offset)
+ auto newValsUntil = origValsUntil;
+ std::advance(newValsUntil, srcSize);
+ std::iota(origValsUntil, newValsUntil,
+ dstSize); // new values [offset, offset+srcNumElements)
+ std::iota(newValsUntil, indices.end(),
+ linearizedOffset + srcSize); // the rest of original values
+ // [offset+srcNumElements, end)
+
+ rewriter.replaceOpWithNewOp<mlir::vector::ShuffleOp>(
+ insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
+ rewriter.getI64ArrayAttr(indices));
+
+ return mlir::success();
+ }
+
private:
unsigned targetVectorBitWidth;
};
@@ -410,6 +508,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
: true;
});
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
- LinearizeVectorExtractStridedSlice>(
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index b29ceab5783d7..31a59b809a74b 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -245,3 +245,32 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
return %0 : vector<8x2xf32>
}
+
+// -----
+// ALL-LABEL: test_vector_insert
+// ALL-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
+func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
+ // DEFAULT: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+ // DEFAULT: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+ // DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+ // DEFAULT-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+ // DEFAULT-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+ // DEFAULT-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+ // DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+ // DEFAULT: return %[[RES]] : vector<2x8x4xf32>
+
+ // BW-128: %[[ARG_SRC:.*]] = vector.shape_cast %[[SRC]] : vector<8x4xf32> to vector<32xf32>
+ // BW-128: %[[ARG_DEST:.*]] = vector.shape_cast %[[DEST]] : vector<2x8x4xf32> to vector<64xf32>
+ // BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG_DEST]], %[[ARG_SRC]]
+ // BW-128-SAME: [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87,
+ // BW-128-SAME: 88, 89, 90, 91, 92, 93, 94, 95, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+ // BW-128-SAME: 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<64xf32>, vector<32xf32>
+ // BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<64xf32> to vector<2x8x4xf32>
+ // BW-128: return %[[RES]] : vector<2x8x4xf32>
+
+ // BW-0: %[[RES:.*]] = vector.insert %[[SRC]], %[[DEST]] [0] : vector<8x4xf32> into vector<2x8x4xf32>
+ // BW-0: return %[[RES]] : vector<2x8x4xf32>
+
+ %0 = vector.insert %arg1, %arg0[0]: vector<8x4xf32> into vector<2x8x4xf32>
+ return %0 : vector<2x8x4xf32>
+}
@llvm/pr-subscribers-mlir-vector received the same automated summary and diff as above.
✅ With the latest revision this PR passed the C/C++ code formatter.
…inearize Building on top of llvm#88204, this commit adds support for InsertOp.
LGTM, but please wait for other reviewers.
Let's merge it.
assert(!(insertOp.getDestVectorType().isScalable() ||
         cast<VectorType>(dstTy).isScalable()) &&
       "scalable vectors are not supported.");
Why is this an assert rather than `rewriter.notifyMatchFailure`?
Thanks for noting this, #93590 addresses it and also refactors other patterns in VectorLinearize.cpp to use `notifyMatchFailure`.
…inearize (llvm#92370) Building on top of [llvm#88204](llvm#88204), this PR adds support for converting `vector.insert` into an equivalent `vector.shuffle` operation that operates on linearized (1-D) vectors.
…arize (#93590) As it was [suggested](#92370 (comment)), the `assert` is replaced by `notifyMatchFailure` for improved consistency.
…arize (llvm#93590) As it was [suggested](llvm#92370 (comment)), the `assert` is replaced by `notifyMatchFailure` for improved consistency.
Building on top of #88204, this PR adds support for converting `vector.insert` into an equivalent `vector.shuffle` operation that operates on linearized (1-D) vectors.