-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Add ElementwiseToOuterproduct #93664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
dab57eb
0f454b0
9889dc2
de63fd6
aa165d2
ec3cd83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1795,6 +1795,87 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> { | |
unsigned maxNumElementsToExtract = 0; | ||
}; | ||
|
||
/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, | ||
/// B)`. | ||
/// Example: | ||
/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> | ||
/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to | ||
/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to | ||
/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32> | ||
/// | ||
/// Becomes : | ||
/// | ||
/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> | ||
/// | ||
/// Supports only 1D-to-2D broadcasts. The following cases are not supported. | ||
/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32> | ||
/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32> | ||
/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32> | ||
template <typename MulOpType> | ||
struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> { | ||
using OpRewritePattern<MulOpType>::OpRewritePattern; | ||
// Returns whether a vector.broadcast matches requirements for an outerproduct | ||
// pattern. aka a 1D-to-2D broadcastOp without broadcasted unit dimension. | ||
bool isValidBroadcastSource(vector::BroadcastOp broadcastOp) const { | ||
// Fail if it is not a 1-to-2 dimension to broadcast to avoid generating | ||
// shape_casts/broadcasts which does not belong in this pattern. | ||
if (!broadcastOp.computeBroadcastedUnitDims().empty()) | ||
return false; | ||
// Avoid broadcast like f32 or vector<f32> -> ResType | ||
auto srcType = dyn_cast<VectorType>(broadcastOp.getSourceType()); | ||
if (!srcType || srcType.getRank() == 2) | ||
return false; | ||
return true; | ||
} | ||
|
||
LogicalResult matchAndRewrite(MulOpType mulOp, | ||
PatternRewriter &rewriter) const override { | ||
auto resType = llvm::cast<VectorType>(mulOp.getResult().getType()); | ||
if (!resType) | ||
return failure(); | ||
if (resType.getRank() != 2) | ||
return failure(); | ||
/// If operandA can be written as tr(broadcast(A)) and operandB as | ||
/// broadcast(B) where broadcasts are 1D-to-2D, create and return | ||
/// vector.outerproduct(A, B). Returns failure() otherwise. | ||
auto matchOuterProduct = | ||
[&](Value operandA, | ||
Value operandB) -> FailureOr<vector::OuterProductOp> { | ||
vector::TransposeOp transposedLhs = | ||
dyn_cast_or_null<vector::TransposeOp>(operandA.getDefiningOp()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: use auto because the casting already spells the type; we can use https://www.llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
|
||
if (!transposedLhs) | ||
return failure(); | ||
// Fail unless this is a true 2-D matrix transpose. | ||
ArrayRef<int64_t> permutation = transposedLhs.getPermutation(); | ||
if (permutation.size() != 2 || permutation[0] != 1 || permutation[1] != 0) | ||
return failure(); | ||
|
||
auto broadcastedLhs = | ||
transposedLhs.getVector().getDefiningOp<vector::BroadcastOp>(); | ||
if (!broadcastedLhs || !isValidBroadcastSource(broadcastedLhs)) | ||
return failure(); | ||
|
||
auto broadcastedRhs = operandB.getDefiningOp<vector::BroadcastOp>(); | ||
if (!broadcastedRhs || !isValidBroadcastSource(broadcastedRhs)) | ||
return failure(); | ||
|
||
return rewriter.create<vector::OuterProductOp>( | ||
mulOp->getLoc(), resType, broadcastedLhs.getSource(), | ||
broadcastedRhs.getSource(), Value(), vector::CombiningKind::ADD); | ||
}; | ||
|
||
Value lhs = mulOp->getOperand(0), rhs = mulOp->getOperand(1); | ||
auto maybeOuterP = matchOuterProduct(lhs, rhs); | ||
// Handle commutativity, the transposed op is the outerproduct LHS. | ||
if (failed(maybeOuterP)) | ||
maybeOuterP = matchOuterProduct(rhs, lhs); | ||
if (failed(maybeOuterP)) | ||
return failure(); | ||
rewriter.replaceOp(mulOp, maybeOuterP->getResult()); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::vector::populateFoldArithExtensionPatterns( | ||
|
@@ -1882,6 +1963,13 @@ void mlir::vector::populateBreakDownVectorReductionPatterns( | |
maxNumElementsToExtract, benefit); | ||
} | ||
|
||
void mlir::vector::populateElementwiseToVectorOpsPatterns( | ||
RewritePatternSet &patterns) { | ||
patterns.add<FoldArithToVectorOuterProduct<arith::MulFOp>, | ||
FoldArithToVectorOuterProduct<arith::MulIOp>>( | ||
patterns.getContext()); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// TableGen'd enum attribute definitions | ||
//===----------------------------------------------------------------------===// | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about collapsing them to
return srcType && srcType.getRank() == 2
?