Skip to content

Commit e22508e

Browse files
authored
[mlir][vector] Update CombineContractBroadcastMask (#140050)
This patch updates `CombineContractBroadcastMask` to inherit from `MaskableOpRewritePattern`, enabling it to handle masked `vector.contract` operations. The pattern rewrites: ```mlir %a = vector.broadcast %a_bc %res vector.contract %a_bc, %b, ... ``` into: ```mlir // Move the broadcast into vector.contract (by updating the indexing // maps) %res vector.contract %a, %b, ... ``` The main challenge is supporting cases where the pattern drops a leading unit dimension. For example: ```mlir func.func @contract_broadcast_unit_dim_reduction_masked( %arg0 : vector<8x4xi32>, %arg1 : vector<8x4xi32>, %arg2 : vector<8x8xi32>, %mask: vector<1x8x8x4xi1>) -> vector<8x8xi32> { %0 = vector.broadcast %arg0 : vector<8x4xi32> to vector<1x8x4xi32> %1 = vector.broadcast %arg1 : vector<8x4xi32> to vector<1x8x4xi32> %result = vector.mask %mask { vector.contract { indexing_maps = [#map0, #map1, #map2], iterator_types = ["reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add> } %0, %1, %arg2 : vector<1x8x4xi32>, vector<1x8x4xi32> into vector<8x8xi32> } : vector<1x8x8x4xi1> -> vector<8x8xi32> return %result : vector<8x8xi32> } ``` Here, the leading unit dimension is dropped. To handle this, the mask is cast to the correct shape using a `vector.shape_cast`: ```mlir func.func @contract_broadcast_unit_dim_reduction_masked( %arg0: vector<8x4xi32>, %arg1: vector<8x4xi32>, %arg2: vector<8x8xi32>, %arg3: vector<1x8x8x4xi1>) -> vector<8x8xi32> { %mask_sc = vector.shape_cast %arg3 : vector<1x8x8x4xi1> to vector<8x8x4xi1> %res = vector.mask %mask_sc { vector.contract { indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add> } %arg0, %arg1, %mask_sc : vector<8x4xi32>, vector<8x4xi32> into vector<8x8xi32> } : vector<8x8x4xi1> -> vector<8x8xi32> return %res : vector<8x8xi32> } ``` While this isn't ideal - since it introduces a `vector.shape_cast` that must be cleaned up later - it reflects the best we can do once the input reaches `CombineContractBroadcastMask`. A more robust solution may involve simplifying the input earlier. I am leaving that as a TODO for myself to explore this further. Posting this now to unblock downstream work. LIMITATIONS Currently, this pattern assumes: * Only leading dimensions are dropped in the mask. * All dropped dimensions must be unit-sized.
1 parent e3e5bd1 commit e22508e

File tree

2 files changed

+344
-127
lines changed

2 files changed

+344
-127
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 157 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
264264
/// iterator_types = ["parallel", "parallel", "reduction"],
265265
/// kind = add} %arg0, %arg1, %cst_f0
266266
/// : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
267-
/// ```
268-
struct CombineContractBroadcast
269-
: public OpRewritePattern<vector::ContractionOp> {
270-
using OpRewritePattern::OpRewritePattern;
271-
272-
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
273-
PatternRewriter &rewriter) const override {
274-
SmallVector<AffineMap> maps =
275-
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
276-
Value lhs = contractOp.getLhs();
277-
Value rhs = contractOp.getRhs();
278-
size_t index = 0;
279-
bool changed = false;
280-
for (Value *operand : {&lhs, &rhs}) {
281-
AffineMap &map = maps[index++];
282-
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
283-
if (!broadcast)
284-
continue;
285-
// contractionOp can only take vector as operands.
286-
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
287-
if (!srcType ||
288-
srcType.getRank() == broadcast.getResultVectorType().getRank())
289-
continue;
290-
int64_t rankDiff =
291-
broadcast.getResultVectorType().getRank() - srcType.getRank();
292-
bool innerDimBroadcast = false;
293-
SmallVector<AffineExpr> originalDims;
294-
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
295-
if (dim.value() != broadcast.getResultVectorType().getDimSize(
296-
rankDiff + dim.index())) {
297-
innerDimBroadcast = true;
298-
break;
299-
}
300-
originalDims.push_back(
301-
rewriter.getAffineDimExpr(dim.index() + rankDiff));
267+
/// ```
268+
///
269+
/// For masked vector.contract, the mask requires updating when a dimension is
270+
/// dropped. In such cases, the dropped dimensions must correspond to the mask's
271+
/// leading unit dimensions. Supporting more generic cases (e.g. non-unit dims)
272+
/// is not supported.
273+
FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp,
274+
MaskingOpInterface maskingOp,
275+
PatternRewriter &rewriter) {
276+
SmallVector<AffineMap> maps =
277+
llvm::to_vector<4>(contractOp.getIndexingMapsArray());
278+
Value lhs = contractOp.getLhs();
279+
Value rhs = contractOp.getRhs();
280+
size_t index = 0;
281+
bool changed = false;
282+
for (Value *operand : {&lhs, &rhs}) {
283+
AffineMap &map = maps[index++];
284+
auto broadcast = operand->getDefiningOp<vector::BroadcastOp>();
285+
if (!broadcast)
286+
continue;
287+
// contractionOp can only take vector as operands.
288+
auto srcType = dyn_cast<VectorType>(broadcast.getSourceType());
289+
if (!srcType ||
290+
srcType.getRank() == broadcast.getResultVectorType().getRank())
291+
continue;
292+
int64_t rankDiff =
293+
broadcast.getResultVectorType().getRank() - srcType.getRank();
294+
bool innerDimBroadcast = false;
295+
SmallVector<AffineExpr> originalDims;
296+
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
297+
if (dim.value() !=
298+
broadcast.getResultVectorType().getDimSize(rankDiff + dim.index())) {
299+
innerDimBroadcast = true;
300+
break;
302301
}
303-
// Contract doesn't support inner dimension broadcast. Once this is
304-
// relaxed we can remove this case.
305-
if (innerDimBroadcast)
306-
continue;
302+
originalDims.push_back(rewriter.getAffineDimExpr(dim.index() + rankDiff));
303+
}
304+
// Contract doesn't support inner dimension broadcast. Once this is
305+
// relaxed we can remove this case.
306+
if (innerDimBroadcast)
307+
continue;
307308

308-
// It would be incorrect to fold a broadcast onto a reduction dimension
309-
// of non-unit size.
310-
bool nonUnitDimReductionBroadcast = false;
311-
for (int64_t i = 0; i < rankDiff; ++i) {
312-
if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
313-
isReductionIterator(contractOp.getIteratorTypes()
314-
.getValue()[map.getDimPosition(i)])) {
315-
nonUnitDimReductionBroadcast = true;
316-
break;
317-
}
309+
// It would be incorrect to fold a broadcast onto a reduction dimension
310+
// of non-unit size.
311+
bool nonUnitDimReductionBroadcast = false;
312+
for (int64_t i = 0; i < rankDiff; ++i) {
313+
if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
314+
isReductionIterator(contractOp.getIteratorTypes()
315+
.getValue()[map.getDimPosition(i)])) {
316+
nonUnitDimReductionBroadcast = true;
317+
break;
318318
}
319-
if (nonUnitDimReductionBroadcast)
320-
continue;
321-
322-
AffineMap broadcastMap =
323-
AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
324-
originalDims, contractOp.getContext());
325-
map = broadcastMap.compose(map);
326-
*operand = broadcast.getSource();
327-
changed = true;
328319
}
320+
if (nonUnitDimReductionBroadcast)
321+
continue;
329322

330-
if (!changed)
331-
return failure();
323+
AffineMap broadcastMap =
324+
AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
325+
originalDims, contractOp.getContext());
326+
map = broadcastMap.compose(map);
327+
*operand = broadcast.getSource();
328+
changed = true;
329+
}
332330

333-
// Determine which dims are usused, now that the maps have been composed
334-
// with the broadcast maps.
335-
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
336-
// Compress unused dims.
337-
for (auto &m : maps)
338-
m = compressDims(m, unusedDimsBitVector);
339-
// Compute the combined iterators.
340-
SmallVector<Attribute> iterators;
341-
for (unsigned i = 0; i < unusedDimsBitVector.size(); ++i) {
342-
if (!unusedDimsBitVector.test(i))
343-
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
344-
}
345-
// Check that compressing unused dims isn't removing all reduction dimension
346-
// pairs. For example, if the vector.contract had only one reduction
347-
// iterator and that was a unit-dimension created by a broadcast,
348-
// then we should bail here, otherwise we would create a contract without
349-
// a reduction dimension pair.
350-
bool hasReductionIteratorApplyingOnBothSides = false;
351-
for (unsigned i = 0; i < iterators.size(); ++i) {
352-
if (!isReductionIterator(iterators[i]))
353-
continue;
354-
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
355-
hasReductionIteratorApplyingOnBothSides = true;
331+
if (!changed)
332+
return failure();
333+
334+
// Determine which dims are usused, now that the maps have been composed
335+
// with the broadcast maps.
336+
llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector(maps);
337+
// Compress unused dims.
338+
for (auto &m : maps)
339+
m = compressDims(m, unusedDimsBitVector);
340+
// Compute the combined iterators.
341+
SmallVector<Attribute> iterators;
342+
for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
343+
if (!unusedDimsBitVector.test(i))
344+
iterators.push_back(contractOp.getIteratorTypes().getValue()[i]);
345+
}
346+
347+
// Check whether any of the unused dims is non-unit, e.g.:
348+
// * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
349+
// This is only required when collapsing a mask. If there is no mask, skip.
350+
VectorType oldMaskType;
351+
bool isAnyUnusedDimNonUnit = false;
352+
if (maskingOp) {
353+
oldMaskType = cast<VectorType>(maskingOp.getMask().getType());
354+
for (unsigned i = 0, e = unusedDimsBitVector.size(); i < e; ++i) {
355+
if (unusedDimsBitVector.test(i) && oldMaskType.getShape()[i] != 1) {
356+
isAnyUnusedDimNonUnit = true;
356357
break;
357358
}
358359
}
359-
if (!hasReductionIteratorApplyingOnBothSides)
360-
return failure();
360+
}
361361

362-
// If the compressed maps have a dimension that is not used by either LHS or
363-
// RHS then the ContractionOp verifier would fail.
364-
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
365-
return failure();
366-
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
367-
contractOp, lhs, rhs, contractOp.getAcc(),
368-
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
369-
return success();
362+
// Check that compressing unused dims isn't removing all reduction dimension
363+
// pairs. For example, if the vector.contract had only one reduction
364+
// iterator and that was a unit-dimension created by a broadcast,
365+
// then we should bail here, otherwise we would create a contract without
366+
// a reduction dimension pair.
367+
bool hasReductionIteratorApplyingOnBothSides = false;
368+
for (unsigned i = 0; i < iterators.size(); ++i) {
369+
if (!isReductionIterator(iterators[i]))
370+
continue;
371+
if (getResultIndex(maps[0], i) && getResultIndex(maps[1], i)) {
372+
hasReductionIteratorApplyingOnBothSides = true;
373+
break;
374+
}
375+
}
376+
if (!hasReductionIteratorApplyingOnBothSides)
377+
return failure();
378+
379+
// If the compressed maps have a dimension that is not used by either LHS or
380+
// RHS then the ContractionOp verifier would fail.
381+
if (getUnusedDimsBitVector({maps[0], maps[1]}).any())
382+
return failure();
383+
384+
Operation *newOp = rewriter.create<vector::ContractionOp>(
385+
contractOp.getLoc(), lhs, rhs, contractOp.getAcc(),
386+
rewriter.getAffineMapArrayAttr(maps), rewriter.getArrayAttr(iterators));
387+
388+
// Handle the mask.
389+
if (maskingOp) {
390+
if (isAnyUnusedDimNonUnit)
391+
return rewriter.notifyMatchFailure(contractOp,
392+
"Cannont drop non-unit mask dim.");
393+
assert(unusedDimsBitVector.size() ==
394+
static_cast<size_t>(oldMaskType.getRank()) &&
395+
"The mask rank is incorrect!");
396+
397+
// If a dimension has been dropped, update the mask accordingly. Otherwise,
398+
// keep it as is.
399+
Value mask = maskingOp.getMask();
400+
if (unusedDimsBitVector.count() != 0) {
401+
// At this point, two assumptions are made:
402+
// * The unused dimensions are the leading mask dimensions
403+
// (vector.contract does not support inner dim broadcasting).
404+
// * The unused dimensions are all unit.
405+
// These conditions are effectively verified in the blocks preceeding this
406+
// one.
407+
auto newShape =
408+
oldMaskType.getShape().drop_front(unusedDimsBitVector.count());
409+
auto newShapeScalableDims =
410+
oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count());
411+
VectorType maskOpType =
412+
VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims);
413+
mask = rewriter
414+
.create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType,
415+
maskingOp.getMask())
416+
.getResult();
417+
}
418+
419+
newOp = mlir::vector::maskOperation(rewriter, newOp, mask);
420+
}
421+
return newOp->getResult(0);
422+
}
423+
424+
struct CombineContractBroadcastMask
425+
: public MaskableOpRewritePattern<vector::ContractionOp> {
426+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
427+
FailureOr<Value>
428+
429+
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
430+
MaskingOpInterface maskingOp,
431+
PatternRewriter &rewriter) const override {
432+
return combineContractAndBroadcast(contractOp, maskingOp, rewriter);
370433
}
371434
};
372435

@@ -2237,7 +2300,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
22372300

22382301
void mlir::vector::populateVectorReductionToContractPatterns(
22392302
RewritePatternSet &patterns, PatternBenefit benefit) {
2240-
patterns.add<MultiReduceToContract, CombineContractBroadcast,
2303+
patterns.add<MultiReduceToContract, CombineContractBroadcastMask,
22412304
CombineContractABTranspose, CombineContractResultTranspose>(
22422305
patterns.getContext(), benefit);
22432306
}

0 commit comments

Comments
 (0)