-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Add ElementwiseToOuterproduct #93664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
dab57eb
0f454b0
9889dc2
de63fd6
aa165d2
ec3cd83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> { | |
unsigned maxNumElementsToExtract = 0; | ||
}; | ||
|
||
/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B)) | ||
/// into vector.outerproduct(A, B) such as : | ||
/// ```mlir | ||
/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32> | ||
/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to | ||
/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to | ||
/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32> | ||
///``` | ||
/// Becomes : | ||
///```mlir | ||
/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32> | ||
///``` | ||
/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled. | ||
/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32> | ||
/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32> | ||
/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32> | ||
|
||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
template <typename MulOpType> | ||
struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> { | ||
using OpRewritePattern<MulOpType>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(MulOpType mulOp, | ||
PatternRewriter &rewriter) const override { | ||
auto VT = llvm::cast<VectorType>(mulOp.getResult().getType()); | ||
if (!VT) | ||
return failure(); | ||
if (VT.getRank() != 2) | ||
return failure(); | ||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
auto canonicalize = [&](Value OperandA, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TBH, it was quite inspired by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
FYI, there is |
||
Value OperandB) -> vector::OuterProductOp { | ||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
vector::TransposeOp transposedLhs = | ||
dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp()); | ||
if (!transposedLhs) | ||
return vector::OuterProductOp(); | ||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Fail unless this is a true 2-D matrix transpose. | ||
ArrayRef<int64_t> permutation = transposedLhs.getPermutation(); | ||
if (permutation[0] != 1 || permutation[1] != 0) | ||
return vector::OuterProductOp(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I understand that it is weird to have transpose with 0D or 1D vector types, but it is still a valid input IR. So I'd suggest to check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We check rank == 2 a bit earlier.
|
||
|
||
// Fail in case it is not a 1-to-2 dimension to broadcast to avoid | ||
// generating shape_casts/broadcasts which do not belong in this pattern. | ||
vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>( | ||
transposedLhs.getVector().getDefiningOp()); | ||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (!broadcastedLhs || | ||
!broadcastedLhs.computeBroadcastedUnitDims().empty()) | ||
return vector::OuterProductOp(); | ||
// Avoid broadcast f32 or vector<f32> -> ResType | ||
auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType()); | ||
if (!srcVT || srcVT.getRank() != 1) | ||
return vector::OuterProductOp(); | ||
|
||
vector::BroadcastOp broadcastedRhs = | ||
dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp()); | ||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT) | ||
return vector::OuterProductOp(); | ||
|
||
return rewriter.replaceOpWithNewOp<vector::OuterProductOp>( | ||
mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(), | ||
Value(), vector::CombiningKind::ADD); | ||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}; | ||
Value a = mulOp->getOperand(0), b = mulOp->getOperand(1); | ||
vector::OuterProductOp outerP = canonicalize(a, b); | ||
// Handle commutativity, the transposed op is the outerproduct LHS. | ||
outerP = outerP ? outerP : canonicalize(b, a); | ||
return outerP ? success() : failure(); | ||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::vector::populateFoldArithExtensionPatterns( | ||
|
@@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns( | |
maxNumElementsToExtract, benefit); | ||
} | ||
|
||
/// Adds to `patterns` the rewrites that turn elementwise multiplications of
/// broadcasted vectors into `vector.outerproduct`, one instantiation per
/// supported multiplication op (`arith.mulf`, then `arith.muli`).
void mlir::vector::populateElementwiseToVectorOpsPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<ElementwiseToOuterproduct<arith::MulFOp>>(context);
  patterns.add<ElementwiseToOuterproduct<arith::MulIOp>>(context);
}
|
||
//===----------------------------------------------------------------------===// | ||
// TableGen'd enum attribute definitions | ||
//===----------------------------------------------------------------------===// | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} { | |||||
transform.yield | ||||||
} | ||||||
} | ||||||
|
||||||
// ----- | ||||||
|
||||||
// Scalable-vector i32 case: arith.muli of transpose(broadcast(%lhs)) and | ||||||
// broadcast(%rhs) is expected to become one vector.outerproduct %lhs, %rhs. | ||||||
// CHECK-LABEL: func.func @ewise_outerproduct | ||||||
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>, | ||||||
// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> { | ||||||
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32> | ||||||
// CHECK: return %[[RES]] : vector<[4]x[4]xi32> | ||||||
func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try to capture all unique features of a test in the corresponding function name.
Suggested change
|
||||||
%lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32> | ||||||
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> | ||||||
%rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32> | ||||||
%mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32> | ||||||
return %mul: vector<[4]x[4]xi32> | ||||||
} | ||||||
|
||||||
// Fixed-size f32 case where the transposed broadcast is the second operand | ||||||
// of arith.mulf: the expected outerproduct takes the transposed operand's | ||||||
// source (%rhs) first, then %lhs. | ||||||
// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs | ||||||
// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>, | ||||||
// CHECK-SAME: %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> { | ||||||
// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32> | ||||||
// CHECK: return %[[RES]] : vector<16x16xf32> | ||||||
func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> { | ||||||
%rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32> | ||||||
%rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32> | ||||||
%lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32> | ||||||
%mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32> | ||||||
return %mul: vector<16x16xf32> | ||||||
} | ||||||
|
||||||
nujaa marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
// Transform script for the tests above: matches every func.func in the | ||||||
// module and applies the vector.elementwise_to_vector pattern set to it. | ||||||
module attributes {transform.with_named_sequence} { | ||||||
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { | ||||||
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op | ||||||
transform.apply_patterns to %func { | ||||||
transform.apply_patterns.vector.elementwise_to_vector | ||||||
} : !transform.any_op | ||||||
transform.yield | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Keep it shorter. In particular, "series" might suggest that more than one occurrence of
mulf(tr(broadcast(A)), broadcast(B))
is required: