Skip to content

[mlir][vector] Update CombineContractBroadcastMask #140050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 157 additions & 94 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
/// iterator_types = ["parallel", "parallel", "reduction"],
/// kind = add} %arg0, %arg1, %cst_f0
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
/// ```
struct CombineContractBroadcast
: public OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
SmallVector<AffineMap> maps =
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
Value lhs = contractOp.getLhs();
Value rhs = contractOp.getRhs();
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
AffineMap &map = maps[index++];
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
if (!broadcast)
continue;
// contractionOp can only take vector as operands.
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
if (!srcType ||
srcType.getRank() == broadcast.getResultVectorType().getRank())
continue;
int64_t rankDiff =
broadcast.getResultVectorType().getRank() - srcType.getRank();
bool innerDimBroadcast = false;
SmallVector<AffineExpr> originalDims;
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
if (dim.value() != broadcast.getResultVectorType().getDimSize(
rankDiff + dim.index())) {
innerDimBroadcast = true;
break;
}
originalDims.push_back(
rewriter.getAffineDimExpr(dim.index() + rankDiff));
/// ```
///
/// For masked vector.contract, the mask requires updating when a dimension is
/// dropped. In such cases, the dropped dimensions must correspond to the mask's
/// leading unit dimensions. Supporting more generic cases (e.g. non-unit dims)
/// is not supported.
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
PatternRewriter &rewriter) {
SmallVector<AffineMap> maps =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, I am not touching this code - this particular GitHub diff claims otherwise. However, if you change the view, things will be clearer 😅 Since I am not touching this code, I'd rather leave it as is.

llvm::to_vector<4>(contractOp.getIndexingMapsArray());
Value lhs = contractOp.getLhs();
Value rhs = contractOp.getRhs();
size_t index = 0;
bool changed = false;
for (Value *operand : {&lhs, &rhs}) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious... what is this doing and why are we using a Value *?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this doing
Iterates over &lhs and &rhs (which are Values).

why are we using a Value *

On L327

    *operand = broadcast.getSource();

Apologies if I am stating the obvious, I wasn't sure what specifically you are asking about. Also, note that I am not touching this code - this particular GitHub diff claims otherwise. However, if you change the view, things will be clearer 😅

AffineMap &map = maps[index++];
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
if (!broadcast)
continue;
// contractionOp can only take vector as operands.
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
if (!srcType ||
srcType.getRank() == broadcast.getResultVectorType().getRank())
continue;
int64_t rankDiff =
broadcast.getResultVectorType().getRank() - srcType.getRank();
bool innerDimBroadcast = false;
SmallVector<AffineExpr> originalDims;
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
if (dim.value() !=
broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
innerDimBroadcast = true;
break;
}
// Contract doesn't support inner dimension broadcast. Once this is
// relaxed we can remove this case.
if (innerDimBroadcast)
continue;
originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
}
// Contract doesn't support inner dimension broadcast. Once this is
// relaxed we can remove this case.
if (innerDimBroadcast)
continue;

// It would be incorrect to fold a broadcast onto a reduction dimension
// of non-unit size.
bool nonUnitDimReductionBroadcast = false;
for (int64_t i = 0; i < rankDiff; ++i) {
if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
isReductionIterator(contractOp.getIteratorTypes()
.getValue()[map.getDimPosition(i)])) {
nonUnitDimReductionBroadcast = true;
break;
}
// It would be incorrect to fold a broadcast onto a reduction dimension
// of non-unit size.
bool nonUnitDimReductionBroadcast = false;
for (int64_t i = 0; i < rankDiff; ++i) {
if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
isReductionIterator(contractOp.getIteratorTypes()
.getValue()[map.getDimPosition(i)])) {
nonUnitDimReductionBroadcast = true;
break;
}
if (nonUnitDimReductionBroadcast)
continue;

AffineMap broadcastMap =
AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
originalDims, contractOp.getContext());
map = broadcastMap.compose(map);
*operand = broadcast.getSource();
changed = true;
}
if (nonUnitDimReductionBroadcast)
continue;

if (!changed)
return failure();
AffineMap broadcastMap =
AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
originalDims, contractOp.getContext());
map = broadcastMap.compose(map);
*operand = broadcast.getSource();
changed = true;
}

// Determine which dims are usused, now that the maps have been composed
// with the broadcast maps.
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
// Compress unused dims.
for (auto &m : maps)
m = compressDims(m, unusedDimsBitVector);
// Compute the combined iterators.
SmallVector<Attribute> iterators;
for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
if (!unusedDimsBitVector.test(i))
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
}
// Check that compressing unused dims isn't removing all reduction dimension
// pairs. For example, if the vector.contract had only one reduction
// iterator and that was a unit-dimension created by a broadcast,
// then we should bail here, otherwise we would create a contract without
// a reduction dimension pair.
bool hasReductionIteratorApplyingOnBothSides = false;
for (unsigned i = 0; i < iterators.size(); ++i) {
if (!isReductionIterator(iterators[i]))
continue;
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
hasReductionIteratorApplyingOnBothSides = true;
if (!changed)
return failure();

// Determine which dims are usused, now that the maps have been composed
// with the broadcast maps.
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
// Compress unused dims.
for (auto &m : maps)
m = compressDims(m, unusedDimsBitVector);
// Compute the combined iterators.
SmallVector<Attribute> iterators;
for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
if (!unusedDimsBitVector.test(i))
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
}

// Check whether any of the unused dims is non-unit, e.g.:
// * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
// This is only required when collapsing a mask. If there is no mask, skip.
VectorType oldMaskType;
bool isAnyUnusedDimNonUnit = false;
if (maskingOp) {
oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
isAnyUnusedDimNonUnit = true;
break;
}
}
if (!hasReductionIteratorApplyingOnBothSides)
return failure();
}

// If the compressed maps have a dimension that is not used by either LHS or
// RHS then the ContractionOp verifier would fail.
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
return failure();
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
contractOp, lhs, rhs, contractOp.getAcc(),
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
return success();
// Check that compressing unused dims isn't removing all reduction dimension
// pairs. For example, if the vector.contract had only one reduction
// iterator and that was a unit-dimension created by a broadcast,
// then we should bail here, otherwise we would create a contract without
// a reduction dimension pair.
bool hasReductionIteratorApplyingOnBothSides = false;
for (unsigned i = 0; i < iterators.size(); ++i) {
if (!isReductionIterator(iterators[i]))
continue;
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
hasReductionIteratorApplyingOnBothSides = true;
break;
}
}
if (!hasReductionIteratorApplyingOnBothSides)
return failure();

// If the compressed maps have a dimension that is not used by either LHS or
// RHS then the ContractionOp verifier would fail.
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
return failure();

Operation *newOp = rewriter.create<vector::ContractionOp>(
contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));

// Handle the mask.
if (maskingOp) {
if (isAnyUnusedDimNonUnit)
return rewriter.notifyMatchFailure(contractOp,
"Cannont drop non-unit mask dim.");
assert(unusedDimsBitVector.size() ==
static_cast<size_t>(oldMaskType.getRank()) &&
"The mask rank is incorrect!");

// If a dimension has been dropped, update the mask accordingly. Otherwise,
// keep it as is.
Value mask = maskingOp.getMask();
if (unusedDimsBitVector.count() != 0) {
// At this point, two assumptions are made:
// * The unused dimensions are the leading mask dimensions
// (vector.contract does not support inner dim broadcasting).
// * The unused dimensions are all unit.
// These conditions are effectively verified in the blocks preceeding this
// one.
auto newShape =
oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
auto newShapeScalableDims =
oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
VectorType maskOpType =
VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
mask = rewriter
.create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,
maskingOp.getMask())
.getResult();
}

newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
}
return newOp->getResult(0);
}

struct CombineContractBroadcastMask
: public MaskableOpRewritePattern<vector::ContractionOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;
FailureOr<Value>

matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
}
};

Expand Down Expand Up @@ -2237,7 +2300,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(

void mlir::vector::populateVectorReductionToContractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<MultiReduceToContract, CombineContractBroadcast,
patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
CombineContractABTranspose, CombineContractResultTranspose>(
patterns.getContext(), benefit);
}
Expand Down
Loading