Skip to content

[mlir][Vector] Support mixed mode vector.contract lowering #117753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 38 additions & 25 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,28 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

/// Widens `v` so that its element type becomes `dstElementType`.
///
/// `v` may be a scalar or a vector. If its element type already equals
/// `dstElementType`, `v` is returned unchanged and no op is created.
/// Otherwise an extension op is created at `loc` through `rewriter`:
/// `arith::ExtFOp` when `dstElementType` is a float type, `arith::ExtSIOp`
/// in every other case. Narrowing is not supported: the source bit width
/// must not exceed the destination bit width (enforced by an assert).
static Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,
Type dstElementType) {
Type elementType = getElementTypeOrSelf(v.getType());
if (elementType == dstElementType)
return v;

// vector.contract only allows extension on operands.
assert(elementType.getIntOrFloatBitWidth() <=
dstElementType.getIntOrFloatBitWidth() &&
"vector.contract does not allow truncation of operands");

// When `v` is a vector, keep its shape and only swap the element type;
// a scalar is promoted to the bare destination element type.
Type promotedType = dstElementType;
if (auto vecType = dyn_cast<VectorType>(v.getType()))
promotedType = vecType.clone(promotedType);

if (isa<FloatType>(dstElementType))
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
// For integer types, vector.contract only supports signless integer types
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd add a new line before this comment

// and promotion happens via sign extension.
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}

// Helper method to possibly drop a dimension in a load.
// TODO
static Value reshapeLoad(Location loc, Value val, VectorType type,
Expand Down Expand Up @@ -136,6 +158,11 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
using vector::CombiningKind;
Value mul;

if (acc) {
x = promoteToElementType(loc, rewriter, x, getElementTypeOrSelf(acc));
y = promoteToElementType(loc, rewriter, y, getElementTypeOrSelf(acc));
}

if (isInt) {
if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
Expand Down Expand Up @@ -413,21 +440,6 @@ struct UnrolledOuterProductGenerator
return rewriter.create<vector::TransposeOp>(loc, v, perm);
}

/// Extends `v` (scalar or vector) so that its element type is
/// `dstElementType`. Returns `v` untouched when the element types already
/// match; otherwise emits `arith::ExtFOp` for a float destination and
/// `arith::ExtSIOp` for anything else, keeping the vector shape if there is
/// one.
Value promote(Value v, Type dstElementType) {
  auto srcVecType = dyn_cast<VectorType>(v.getType());
  Type srcElementType = srcVecType ? srcVecType.getElementType() : v.getType();

  // Element types already agree: nothing to extend.
  if (srcElementType == dstElementType)
    return v;

  // Preserve the source vector shape (if any) with the new element type.
  Type widenedType =
      srcVecType ? Type(srcVecType.clone(dstElementType)) : dstElementType;

  if (!isa<FloatType>(dstElementType))
    return rewriter.create<arith::ExtSIOp>(loc, widenedType, v);
  return rewriter.create<arith::ExtFOp>(loc, widenedType, v);
}

FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
VectorType lhsType, int reductionSize,
std::optional<Value> maybeMask = std::nullopt) {
Expand All @@ -439,8 +451,8 @@ struct UnrolledOuterProductGenerator
for (int64_t k = 0; k < reductionSize; ++k) {
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
extractA = promote(extractA, resElementType);
extractB = promote(extractB, resElementType);
extractA = promoteToElementType(loc, rewriter, extractA, resElementType);
extractB = promoteToElementType(loc, rewriter, extractB, resElementType);
Value extractMask;
if (maybeMask.has_value() && maybeMask.value())
extractMask =
Expand Down Expand Up @@ -764,6 +776,8 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
Value b = rank == 1
? rhs
: rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, c);
a = promoteToElementType(loc, rewriter, a, getElementTypeOrSelf(dstType));
b = promoteToElementType(loc, rewriter, b, getElementTypeOrSelf(dstType));
Value m = createMul(op.getLoc(), a, b, isInt, rewriter);
Value reduced = rewriter.create<vector::ReductionOp>(
op.getLoc(), vector::CombiningKind::ADD, m);
Expand Down Expand Up @@ -925,12 +939,6 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
if (failed(filter(op)))
return failure();

// TODO: support mixed mode contract lowering.
if (op.getLhsType().getElementType() !=
getElementTypeOrSelf(op.getAccType()) ||
op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
return failure();

// TODO: the code below assumes the default contraction, make sure it supports
// other kinds before enabling this lowering.
if (op.getKind() != vector::CombiningKind::ADD) {
Expand Down Expand Up @@ -1149,10 +1157,15 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
if (rhsType.getRank() != 1)
return rewriter.notifyMatchFailure(
op, "When LHS has rank 1, expected also RHS to have rank 1");
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
auto kind = vector::CombiningKind::ADD;

Value acc = op.getAcc();
Value lhs = promoteToElementType(loc, rewriter, op.getLhs(),
getElementTypeOrSelf(acc));
Value rhs = promoteToElementType(loc, rewriter, op.getRhs(),
getElementTypeOrSelf(acc));
Value m = createMul(loc, lhs, rhs, isInt, rewriter);
auto kind = vector::CombiningKind::ADD;

Operation *reductionOp =
acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
: rewriter.create<vector::ReductionOp>(loc, kind, m);
Expand Down
27 changes: 27 additions & 0 deletions mlir/test/Dialect/Vector/vector-contract-to-dot-transforms.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,33 @@ func.func @contract_one_sided_unit_reduction_dim(%arg0 : vector<1x2xi32>, %arg1
return %res : vector<2xi32>
}

// Mixed-precision contraction test: both operands are vector<2x2xf16> and the
// accumulator/result is vector<2x2xf32>. The expected output widens each
// vector<2xf16> operand slice to vector<2xf32> with arith.extf before the
// arith.mulf and the vector.reduction <add>. Four multiply/reduce groups are
// matched; the second and fourth reuse the LHS extension captured by the
// group before them (EXT00 and EXT20 respectively).
// NOTE(review): #matmat_trait is defined outside this excerpt; presumably the
// standard matmul indexing maps and iterator types. Confirm against the file.
// CHECK-LABEL: @matmul_mixed
// CHECK: %[[EXT00:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
// CHECK: %[[EXT01:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
// CHECK: %[[MUL1:.+]] = arith.mulf %[[EXT00]], %[[EXT01]] : vector<2xf32>
// CHECK: vector.reduction <add>, %[[MUL1]] : vector<2xf32> into f32

// CHECK: %[[EXT11:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
// CHECK: %[[MUL2:.+]] = arith.mulf %[[EXT00]], %[[EXT11]] : vector<2xf32>
// CHECK: vector.reduction <add>, %[[MUL2]] : vector<2xf32> into f32

// CHECK: %[[EXT20:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
// CHECK: %[[EXT21:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
// CHECK: %[[MUL3:.+]] = arith.mulf %[[EXT20]], %[[EXT21]] : vector<2xf32>
// CHECK: vector.reduction <add>, %[[MUL3]] : vector<2xf32> into f32

// CHECK: %[[EXT31:.+]] = arith.extf %{{.*}} : vector<2xf16> to vector<2xf32>
// CHECK: %[[MUL4:.+]] = arith.mulf %[[EXT20]], %[[EXT31]] : vector<2xf32>
// CHECK: vector.reduction <add>, %[[MUL4]] : vector<2xf32> into f32

func.func @matmul_mixed(%arg0: vector<2x2xf16>,
%arg1: vector<2x2xf16>,
%arg2: vector<2x2xf32>) -> vector<2x2xf32> {
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
: vector<2x2xf16>, vector<2x2xf16> into vector<2x2xf32>
return %0 : vector<2x2xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ func.func @parallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vect
return %0 : f32
}

// Mixed-precision contraction with a scalar accumulator: both operands are
// vector<1x1xf16>, the accumulator/result is f32, and both iterators are
// reductions. The expected output extracts element [0, 0] from each operand,
// widens each extracted f16 to f32 with arith.extf, multiplies them with
// arith.mulf, and adds the product to another f32 value with arith.addf,
// whose result is returned.
// CHECK-LABEL: func @parallel_contract_lowering_mixed_types
// CHECK: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
// CHECK: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : f16 from vector<1x1xf16>
// CHECK: %[[EXT0:.+]] = arith.extf %[[E0]] : f16 to f32
// CHECK: %[[EXT1:.+]] = arith.extf %[[E1]] : f16 to f32
// CHECK: %[[M:.*]] = arith.mulf %[[EXT0]], %[[EXT1]] : f32
// CHECK: %[[A:.*]] = arith.addf %[[M]], %{{.*}} : f32
// CHECK: return %[[A]] : f32
func.func @parallel_contract_lowering_mixed_types(%arg0: vector<1x1xf16>, %arg1: vector<1x1xf16>, %arg2: f32) -> f32 {
%0 = vector.contract {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> ()>],
iterator_types = ["reduction", "reduction"], kind = #vector.kind<add>}
%arg0, %arg1, %arg2 : vector<1x1xf16>, vector<1x1xf16> into f32
return %0 : f32
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
Expand Down
Loading