Skip to content

[mlir][vector] Add ElementwiseToOuterproduct #93664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
/// into vector contract for the backends with native support.
void populateFoldArithExtensionPatterns(RewritePatternSet &patterns);

/// Collect a set of patterns that fold elementwise op on vectors to the vector
/// dialect.
void populateElementwiseToVectorOpsPatterns(RewritePatternSet &patterns);

/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,17 @@ def ApplyFoldArithExtensionPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

def ApplyFoldElementwiseToVectorPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.elementwise_to_vector",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collect a set of patterns that fold elementwise op on vectors to the vector
dialect.
}];

let assemblyFormat = "attr-dict";
}

def ApplyVectorReductionToContractPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.reduction_to_contract",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ void transform::ApplyFoldArithExtensionPatternsOp::populatePatterns(
vector::populateFoldArithExtensionPatterns(patterns);
}

void transform::ApplyFoldElementwiseToVectorPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateElementwiseToVectorOpsPatterns(patterns);
}

void transform::ApplyVectorReductionToContractPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorReductionToContractPatterns(patterns);
Expand Down
75 changes: 75 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,75 @@ struct BreakDownVectorReduction final : OpRewritePattern<vector::ReductionOp> {
unsigned maxNumElementsToExtract = 0;
};

/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
/// into vector.outerproduct(A, B) such as :
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Keep it shorter. In particular, "series" might suggest that more than one occurance of mulf(tr(broadcast(A)), broadcast(B)) is required:

Suggested change
/// Pattern aiming to fold a series of ops mulf(tr(broadcast(A)), broadcast(B))
/// into vector.outerproduct(A, B) such as :
/// Fold `mulf(tr(broadcast(A)), broadcast(B))` into `vector.outerproduct(A, B)`. Example:

/// ```mlir
/// %lhsBcast = vector.broadcast %lhs : vector<4xi32> to vector<4x4xi32>
/// %lhsT = vector.transpose %lhsBcast, [1, 0] : vector<4x4xi32> to
/// vector<4x4xi32> %rhsBcast = vector.broadcast %rhs : vector<4xi32> to
/// vector<4x4xi32> %mul = arith.muli %lhsT, %rhsBcast : vector<4x4xi32>
///```
/// Becomes :
///```mlir
/// %res = vector.outerproduct %lhs, %rhs : vector<4xi32>, vector<4xi32>
///```
/// Edge Cases where broadcast ops are not 1D to 2D as follow are not handled.
/// %ex1 = vector.broadcast %lhsCast : vector<1x4xf32> to vector<4x4xf32>
/// %ex2 = vector.broadcast %lhsCast : f32 to vector<4x4xf32>
/// %ex3 = vector.broadcast %lhsCast : vector<1x1xf32> to vector<4x4xf32>

template <typename MulOpType>
struct ElementwiseToOuterproduct : public OpRewritePattern<MulOpType> {
using OpRewritePattern<MulOpType>::OpRewritePattern;

LogicalResult matchAndRewrite(MulOpType mulOp,
PatternRewriter &rewriter) const override {
auto VT = llvm::cast<VectorType>(mulOp.getResult().getType());
if (!VT)
return failure();
if (VT.getRank() != 2)
return failure();

auto canonicalize = [&](Value OperandA,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I think that long lambdas (more than 2-3 lines) obscure readability. Why not create a helper hook outside this method? Others might disagree (looking at you @MacDue :) ).
  2. In any case, please document what this lambda does and use less generic name. "canonicalize" is a very powerful term and based on past experience I feel that it can mean different things depending on context. Also, we don't really want people hitting this when searching for "canonicalize" in MLIR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH, it was quite inspired by CanonicalizeContract pattern working the same way. Lambdas allow not to pass Location and rewriter. Which I find nice. Renamed and commented.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lambdas allow not to pass Location and rewriter.

FYI, there is ImplicitLocOpBuilder which allows you just passing a single builder (but not location). I don't find it useful, so I usually don't suggest it in my reviews.

https://github.com/llvm/llvm-project/blob/0ec567c370df86893a22bf59d2716f6e553ca63b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h#L20-L23C7

Value OperandB) -> vector::OuterProductOp {
vector::TransposeOp transposedLhs =
dyn_cast_or_null<vector::TransposeOp>(OperandA.getDefiningOp());
if (!transposedLhs)
return vector::OuterProductOp();
// Fail unless this is a true 2-D matrix transpose.
ArrayRef<int64_t> permutation = transposedLhs.getPermutation();
if (permutation[0] != 1 || permutation[1] != 0)
return vector::OuterProductOp();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that it is weird to have transpose with 0D or 1D vector types, but it is still a valid input IR. So I'd suggest to check permutation.size() == 2 as well. Otherwise, we could crash in accessing permutation[0] and permutation[1].

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We check rank == 2 a bit earlier.

if (resType.getRank() != 2)
      return failure();


// Fail in case it is not a 1-to-2 dimension to broadcast to avoid
// generating shape_casts/broadcasts which do not belong in this pattern.
vector::BroadcastOp broadcastedLhs = dyn_cast<vector::BroadcastOp>(
transposedLhs.getVector().getDefiningOp());
if (!broadcastedLhs ||
!broadcastedLhs.computeBroadcastedUnitDims().empty())
return vector::OuterProductOp();
// Avoid broadcast f32 or vector<f32> -> ResType
auto srcVT = dyn_cast<VectorType>(broadcastedLhs.getSourceType());
if (!srcVT || srcVT.getRank() != 1)
return vector::OuterProductOp();

vector::BroadcastOp broadcastedRhs =
dyn_cast<vector::BroadcastOp>(OperandB.getDefiningOp());
if (!broadcastedRhs || broadcastedRhs.getSourceType() != srcVT)
return vector::OuterProductOp();

return rewriter.replaceOpWithNewOp<vector::OuterProductOp>(
mulOp, VT, broadcastedLhs.getSource(), broadcastedRhs.getSource(),
Value(), vector::CombiningKind::ADD);
};
Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
vector::OuterProductOp outerP = canonicalize(a, b);
// Handle commutativity, the transposed op is the outerproduct LHS.
outerP = outerP ? outerP : canonicalize(b, a);
return outerP ? success() : failure();
}
};

} // namespace

void mlir::vector::populateFoldArithExtensionPatterns(
Expand Down Expand Up @@ -1882,6 +1951,12 @@ void mlir::vector::populateBreakDownVectorReductionPatterns(
maxNumElementsToExtract, benefit);
}

void mlir::vector::populateElementwiseToVectorOpsPatterns(
RewritePatternSet &patterns) {
patterns.add<ElementwiseToOuterproduct<arith::MulFOp>,
ElementwiseToOuterproduct<arith::MulIOp>>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// TableGen'd enum attribute definitions
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 38 additions & 0 deletions mlir/test/Dialect/Vector/transform-vector.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,41 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @ewise_outerproduct
// CHECK-SAME: %[[LHS:.*]]: vector<[4]xi32>,
// CHECK-SAME: %[[RHS:.*]]: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
// CHECK: %[[RES:.*]] = vector.outerproduct %[[LHS]], %[[RHS]] : vector<[4]xi32>, vector<[4]xi32>
// CHECK: return %[[RES]] : vector<[4]x[4]xi32>
func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to capture all unique features of a test in the corresponding function name.

Suggested change
func.func @ewise_outerproduct(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {
func.func @ewise_outerproduct_trans_lhs_i32(%lhs: vector<[4]xi32>, %rhs: vector<[4]xi32>) -> vector<[4]x[4]xi32> {

%lhsBcast = vector.broadcast %lhs : vector<[4]xi32> to vector<[4]x[4]xi32>
%lhsT = vector.transpose %lhsBcast, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
%rhsBcast = vector.broadcast %rhs : vector<[4]xi32> to vector<[4]x[4]xi32>
%mul = arith.muli %lhsT, %rhsBcast : vector<[4]x[4]xi32>
return %mul: vector<[4]x[4]xi32>
}

// CHECK-LABEL: func.func @ewise_outerproduct_transposed_rhs
// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
// CHECK-SAME: %[[RHS:.*]]: vector<16xf32>) -> vector<16x16xf32> {
// CHECK: %[[RES:.*]] = vector.outerproduct %[[RHS]], %[[LHS]] : vector<16xf32>, vector<16xf32>
// CHECK: return %[[RES]] : vector<16x16xf32>
func.func @ewise_outerproduct_transposed_rhs(%lhs: vector<16xf32>, %rhs: vector<16xf32>) -> vector<16x16xf32> {
%rhsBcast = vector.broadcast %rhs : vector<16xf32> to vector<16x16xf32>
%rhsT = vector.transpose %rhsBcast, [1, 0] : vector<16x16xf32> to vector<16x16xf32>
%lhsBcast = vector.broadcast %lhs : vector<16xf32> to vector<16x16xf32>
%mul = arith.mulf %lhsBcast, %rhsT : vector<16x16xf32>
return %mul: vector<16x16xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.vector.elementwise_to_vector
} : !transform.any_op
transform.yield
}
}
Loading