-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Update CombineContractBroadcastMask
#140050
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final | |
/// iterator_types = ["parallel", "parallel", "reduction"], | ||
/// kind = add} %arg0, %arg1, %cst_f0 | ||
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32> | ||
/// ``` | ||
struct CombineContractBroadcast | ||
: public OpRewritePattern<vector::ContractionOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, | ||
PatternRewriter &rewriter) const override { | ||
SmallVector<AffineMap> maps = | ||
llvm::to_vector<4>(contractOp.getIndexingMapsArray()); | ||
Value lhs = contractOp.getLhs(); | ||
Value rhs = contractOp.getRhs(); | ||
size_t index = 0; | ||
bool changed = false; | ||
for (Value *operand : {&lhs, &rhs}) { | ||
AffineMap &map = maps[index++]; | ||
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>(); | ||
if (!broadcast) | ||
continue; | ||
// contractionOp can only take vector as operands. | ||
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType()); | ||
if (!srcType || | ||
srcType.getRank() == broadcast.getResultVectorType().getRank()) | ||
continue; | ||
int64_t rankDiff = | ||
broadcast.getResultVectorType().getRank() - srcType.getRank(); | ||
bool innerDimBroadcast = false; | ||
SmallVector<AffineExpr> originalDims; | ||
for (const auto &dim : llvm::enumerate(srcType.getShape())) { | ||
if (dim.value() != broadcast.getResultVectorType().getDimSize( | ||
rankDiff + dim.index())) { | ||
innerDimBroadcast = true; | ||
break; | ||
} | ||
originalDims.push_back( | ||
rewriter.getAffineDimExpr(dim.index() + rankDiff)); | ||
/// ``` | ||
/// | ||
/// For masked vector.contract, the mask requires updating when a dimension is | ||
/// dropped. In such cases, the dropped dimensions must correspond to the mask's | ||
/// leading unit dimensions. More generic cases (e.g. non-unit dims) are not | |
/// supported. | |
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp, | ||
MaskingOpInterface maskingOp, | ||
PatternRewriter &rewriter) { | ||
SmallVector<AffineMap> maps = | ||
llvm::to_vector<4>(contractOp.getIndexingMapsArray()); | ||
Value lhs = contractOp.getLhs(); | ||
Value rhs = contractOp.getRhs(); | ||
size_t index = 0; | ||
bool changed = false; | ||
for (Value *operand : {&lhs, &rhs}) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious... what is this doing and why are we using a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
On L327 *operand = broadcast.getSource(); Apologies if I am stating the obvious, I wasn't sure what specifically you are asking about. Also, note that I am not touching this code - this particular GitHub diff claims otherwise. However, if you change the view, things will be clearer 😅 |
||
AffineMap &map = maps[index++]; | ||
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>(); | ||
if (!broadcast) | ||
continue; | ||
// contractionOp can only take vector as operands. | ||
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType()); | ||
if (!srcType || | ||
srcType.getRank() == broadcast.getResultVectorType().getRank()) | ||
continue; | ||
int64_t rankDiff = | ||
broadcast.getResultVectorType().getRank() - srcType.getRank(); | ||
bool innerDimBroadcast = false; | ||
SmallVector<AffineExpr> originalDims; | ||
for (const auto &dim : llvm::enumerate(srcType.getShape())) { | ||
if (dim.value() != | ||
broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) { | ||
innerDimBroadcast = true; | ||
break; | ||
} | ||
// Contract doesn't support inner dimension broadcast. Once this is | ||
// relaxed we can remove this case. | ||
if (innerDimBroadcast) | ||
continue; | ||
originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff)); | ||
} | ||
// Contract doesn't support inner dimension broadcast. Once this is | ||
// relaxed we can remove this case. | ||
if (innerDimBroadcast) | ||
continue; | ||
|
||
// It would be incorrect to fold a broadcast onto a reduction dimension | ||
// of non-unit size. | ||
bool nonUnitDimReductionBroadcast = false; | ||
for (int64_t i = 0; i < rankDiff; ++i) { | ||
if (broadcast.getResultVectorType().getDimSize(i) != 1 && | ||
isReductionIterator(contractOp.getIteratorTypes() | ||
.getValue()[map.getDimPosition(i)])) { | ||
nonUnitDimReductionBroadcast = true; | ||
break; | ||
} | ||
// It would be incorrect to fold a broadcast onto a reduction dimension | ||
// of non-unit size. | ||
bool nonUnitDimReductionBroadcast = false; | ||
for (int64_t i = 0; i < rankDiff; ++i) { | ||
if (broadcast.getResultVectorType().getDimSize(i) != 1 && | ||
isReductionIterator(contractOp.getIteratorTypes() | ||
.getValue()[map.getDimPosition(i)])) { | ||
nonUnitDimReductionBroadcast = true; | ||
break; | ||
} | ||
if (nonUnitDimReductionBroadcast) | ||
continue; | ||
|
||
AffineMap broadcastMap = | ||
AffineMap::get(broadcast.getResultVectorType().getRank(), 0, | ||
originalDims, contractOp.getContext()); | ||
map = broadcastMap.compose(map); | ||
*operand = broadcast.getSource(); | ||
changed = true; | ||
} | ||
if (nonUnitDimReductionBroadcast) | ||
continue; | ||
|
||
if (!changed) | ||
return failure(); | ||
AffineMap broadcastMap = | ||
AffineMap::get(broadcast.getResultVectorType().getRank(), 0, | ||
originalDims, contractOp.getContext()); | ||
map = broadcastMap.compose(map); | ||
*operand = broadcast.getSource(); | ||
changed = true; | ||
} | ||
|
||
// Determine which dims are unused, now that the maps have been composed | |
// with the broadcast maps. | ||
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps); | ||
// Compress unused dims. | ||
for (auto &m : maps) | ||
m = compressDims(m, unusedDimsBitVector); | ||
// Compute the combined iterators. | ||
SmallVector<Attribute> iterators; | ||
for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) { | ||
if (!unusedDimsBitVector.test(i)) | ||
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); | ||
} | ||
// Check that compressing unused dims isn't removing all reduction dimension | ||
// pairs. For example, if the vector.contract had only one reduction | ||
// iterator and that was a unit-dimension created by a broadcast, | ||
// then we should bail here, otherwise we would create a contract without | ||
// a reduction dimension pair. | ||
bool hasReductionIteratorApplyingOnBothSides = false; | ||
for (unsigned i = 0; i < iterators.size(); ++i) { | ||
if (!isReductionIterator(iterators[i])) | ||
continue; | ||
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) { | ||
hasReductionIteratorApplyingOnBothSides = true; | ||
if (!changed) | ||
return failure(); | ||
|
||
// Determine which dims are unused, now that the maps have been composed | |
// with the broadcast maps. | ||
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps); | ||
// Compress unused dims. | ||
for (auto &m : maps) | ||
m = compressDims(m, unusedDimsBitVector); | ||
// Compute the combined iterators. | ||
SmallVector<Attribute> iterators; | ||
for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) { | ||
if (!unusedDimsBitVector.test(i)) | ||
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]); | ||
} | ||
|
||
// Check whether any of the unused dims is non-unit, e.g.: | ||
// * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32> | ||
// This is only required when collapsing a mask. If there is no mask, skip. | ||
VectorType oldMaskType; | ||
bool isAnyUnusedDimNonUnit = false; | ||
if (maskingOp) { | ||
oldMaskType = cast<VectorType>(maskingOp.getMask().getType()); | ||
for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) { | ||
if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) { | ||
isAnyUnusedDimNonUnit = true; | ||
break; | ||
} | ||
} | ||
if (!hasReductionIteratorApplyingOnBothSides) | ||
return failure(); | ||
} | ||
|
||
// If the compressed maps have a dimension that is not used by either LHS or | ||
// RHS then the ContractionOp verifier would fail. | ||
if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) | ||
return failure(); | ||
rewriter.replaceOpWithNewOp<vector::ContractionOp>( | ||
contractOp, lhs, rhs, contractOp.getAcc(), | ||
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); | ||
return success(); | ||
// Check that compressing unused dims isn't removing all reduction dimension | ||
// pairs. For example, if the vector.contract had only one reduction | ||
// iterator and that was a unit-dimension created by a broadcast, | ||
// then we should bail here, otherwise we would create a contract without | ||
// a reduction dimension pair. | ||
bool hasReductionIteratorApplyingOnBothSides = false; | ||
for (unsigned i = 0; i < iterators.size(); ++i) { | ||
if (!isReductionIterator(iterators[i])) | ||
continue; | ||
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) { | ||
hasReductionIteratorApplyingOnBothSides = true; | ||
break; | ||
} | ||
} | ||
if (!hasReductionIteratorApplyingOnBothSides) | ||
return failure(); | ||
|
||
// If the compressed maps have a dimension that is not used by either LHS or | ||
// RHS then the ContractionOp verifier would fail. | ||
if (getUnusedDimsBitVector({maps[0], maps[1]}).any()) | ||
return failure(); | ||
|
||
Operation *newOp = rewriter.create<vector::ContractionOp>( | ||
contractOp.getLoc(), lhs, rhs, contractOp.getAcc(), | ||
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators)); | ||
|
||
// Handle the mask. | ||
if (maskingOp) { | ||
if (isAnyUnusedDimNonUnit) | ||
return rewriter.notifyMatchFailure(contractOp, | ||
"Cannont drop non-unit mask dim."); | ||
assert(unusedDimsBitVector.size() == | ||
static_cast<size_t>(oldMaskType.getRank()) && | ||
"The mask rank is incorrect!"); | ||
|
||
// If a dimension has been dropped, update the mask accordingly. Otherwise, | ||
// keep it as is. | ||
Value mask = maskingOp.getMask(); | ||
if (unusedDimsBitVector.count() != 0) { | ||
// At this point, two assumptions are made: | ||
// * The unused dimensions are the leading mask dimensions | ||
// (vector.contract does not support inner dim broadcasting). | ||
// * The unused dimensions are all unit. | ||
// These conditions are effectively verified in the blocks preceding this | |
// one. | ||
auto newShape = | ||
oldMaskType.getShape().drop_front(unusedDimsBitVector.count()); | ||
auto newShapeScalableDims = | ||
oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count()); | ||
VectorType maskOpType = | ||
VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims); | ||
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mask = rewriter | ||
.create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType, | ||
maskingOp.getMask()) | ||
.getResult(); | ||
} | ||
|
||
newOp = mlir::vector::maskOperation(rewriter, newOp, mask); | ||
} | ||
return newOp->getResult(0); | ||
} | ||
|
||
/// Rewrite pattern over `vector::ContractionOp`, built on | |
/// `MaskableOpRewritePattern`. All matching and rewriting is delegated to | |
/// `combineContractAndBroadcast`, which is handed the contraction, its | |
/// `maskingOp` (that helper checks `if (maskingOp)` before touching the mask) | |
/// and the rewriter. The helper's `FailureOr<Value>` is returned unchanged. | |
struct CombineContractBroadcastMask | |
: public MaskableOpRewritePattern<vector::ContractionOp> { | |
using MaskableOpRewritePattern::MaskableOpRewritePattern; | |
FailureOr<Value> | |
|
||
matchAndRewriteMaskableOp(vector::ContractionOp contractOp, | |
MaskingOpInterface maskingOp, | |
PatternRewriter &rewriter) const override { | |
// Thin forwarder: the free function implements the whole rewrite. | |
return combineContractAndBroadcast(contractOp, maskingOp, rewriter); | |
} | |
}; | |
|
||
|
@@ -2237,7 +2300,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT( | |
|
||
void mlir::vector::populateVectorReductionToContractPatterns( | ||
RewritePatternSet &patterns, PatternBenefit benefit) { | ||
patterns.add<MultiReduceToContract, CombineContractBroadcast, | ||
patterns.add<MultiReduceToContract, CombineContractBroadcastMask, | ||
CombineContractABTranspose, CombineContractResultTranspose>( | ||
patterns.getContext(), benefit); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note, I am not touching this code - this particular GitHub diff claims otherwise. However, if you change the view, things will be clearer 😅 Since I am not touching this code, I'd rather leave it as is.