@@ -264,109 +264,172 @@ struct CombineContractResultTranspose final
264
264
// / iterator_types = ["parallel", "parallel", "reduction"],
265
265
// / kind = add} %arg0, %arg1, %cst_f0
266
266
// / : vector<32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
267
- // / ```
268
- struct CombineContractBroadcast
269
- : public OpRewritePattern<vector::ContractionOp> {
270
- using OpRewritePattern::OpRewritePattern;
271
-
272
- LogicalResult matchAndRewrite (vector::ContractionOp contractOp,
273
- PatternRewriter &rewriter) const override {
274
- SmallVector<AffineMap> maps =
275
- llvm::to_vector<4 >(contractOp.getIndexingMapsArray ());
276
- Value lhs = contractOp.getLhs ();
277
- Value rhs = contractOp.getRhs ();
278
- size_t index = 0 ;
279
- bool changed = false ;
280
- for (Value *operand : {&lhs, &rhs}) {
281
- AffineMap &map = maps[index++];
282
- auto broadcast = operand->getDefiningOp <vector::BroadcastOp>();
283
- if (!broadcast)
284
- continue ;
285
- // contractionOp can only take vector as operands.
286
- auto srcType = dyn_cast<VectorType>(broadcast.getSourceType ());
287
- if (!srcType ||
288
- srcType.getRank () == broadcast.getResultVectorType ().getRank ())
289
- continue ;
290
- int64_t rankDiff =
291
- broadcast.getResultVectorType ().getRank () - srcType.getRank ();
292
- bool innerDimBroadcast = false ;
293
- SmallVector<AffineExpr> originalDims;
294
- for (const auto &dim : llvm::enumerate (srcType.getShape ())) {
295
- if (dim.value () != broadcast.getResultVectorType ().getDimSize (
296
- rankDiff + dim.index ())) {
297
- innerDimBroadcast = true ;
298
- break ;
299
- }
300
- originalDims.push_back (
301
- rewriter.getAffineDimExpr (dim.index () + rankDiff));
267
+ // / ```
268
+ // /
269
+ // / For masked vector.contract, the mask requires updating when a dimension is
270
+ // / dropped. In such cases, the dropped dimensions must correspond to the mask's
271
+ // / leading unit dimensions. Supporting more generic cases (e.g. non-unit dims)
272
+ // / is not supported.
273
+ FailureOr<Value> combineContractAndBroadcast (vector::ContractionOp contractOp,
274
+ MaskingOpInterface maskingOp,
275
+ PatternRewriter &rewriter) {
276
+ SmallVector<AffineMap> maps =
277
+ llvm::to_vector<4 >(contractOp.getIndexingMapsArray ());
278
+ Value lhs = contractOp.getLhs ();
279
+ Value rhs = contractOp.getRhs ();
280
+ size_t index = 0 ;
281
+ bool changed = false ;
282
+ for (Value *operand : {&lhs, &rhs}) {
283
+ AffineMap &map = maps[index++];
284
+ auto broadcast = operand->getDefiningOp <vector::BroadcastOp>();
285
+ if (!broadcast)
286
+ continue ;
287
+ // contractionOp can only take vector as operands.
288
+ auto srcType = dyn_cast<VectorType>(broadcast.getSourceType ());
289
+ if (!srcType ||
290
+ srcType.getRank () == broadcast.getResultVectorType ().getRank ())
291
+ continue ;
292
+ int64_t rankDiff =
293
+ broadcast.getResultVectorType ().getRank () - srcType.getRank ();
294
+ bool innerDimBroadcast = false ;
295
+ SmallVector<AffineExpr> originalDims;
296
+ for (const auto &dim : llvm::enumerate (srcType.getShape ())) {
297
+ if (dim.value () !=
298
+ broadcast.getResultVectorType ().getDimSize (rankDiff + dim.index ())) {
299
+ innerDimBroadcast = true ;
300
+ break ;
302
301
}
303
- // Contract doesn't support inner dimension broadcast. Once this is
304
- // relaxed we can remove this case.
305
- if (innerDimBroadcast)
306
- continue ;
302
+ originalDims.push_back (rewriter.getAffineDimExpr (dim.index () + rankDiff));
303
+ }
304
+ // Contract doesn't support inner dimension broadcast. Once this is
305
+ // relaxed we can remove this case.
306
+ if (innerDimBroadcast)
307
+ continue ;
307
308
308
- // It would be incorrect to fold a broadcast onto a reduction dimension
309
- // of non-unit size.
310
- bool nonUnitDimReductionBroadcast = false ;
311
- for (int64_t i = 0 ; i < rankDiff; ++i) {
312
- if (broadcast.getResultVectorType ().getDimSize (i) != 1 &&
313
- isReductionIterator (contractOp.getIteratorTypes ()
314
- .getValue ()[map.getDimPosition (i)])) {
315
- nonUnitDimReductionBroadcast = true ;
316
- break ;
317
- }
309
+ // It would be incorrect to fold a broadcast onto a reduction dimension
310
+ // of non-unit size.
311
+ bool nonUnitDimReductionBroadcast = false ;
312
+ for (int64_t i = 0 ; i < rankDiff; ++i) {
313
+ if (broadcast.getResultVectorType ().getDimSize (i) != 1 &&
314
+ isReductionIterator (contractOp.getIteratorTypes ()
315
+ .getValue ()[map.getDimPosition (i)])) {
316
+ nonUnitDimReductionBroadcast = true ;
317
+ break ;
318
318
}
319
- if (nonUnitDimReductionBroadcast)
320
- continue ;
321
-
322
- AffineMap broadcastMap =
323
- AffineMap::get (broadcast.getResultVectorType ().getRank (), 0 ,
324
- originalDims, contractOp.getContext ());
325
- map = broadcastMap.compose (map);
326
- *operand = broadcast.getSource ();
327
- changed = true ;
328
319
}
320
+ if (nonUnitDimReductionBroadcast)
321
+ continue ;
329
322
330
- if (!changed)
331
- return failure ();
323
+ AffineMap broadcastMap =
324
+ AffineMap::get (broadcast.getResultVectorType ().getRank (), 0 ,
325
+ originalDims, contractOp.getContext ());
326
+ map = broadcastMap.compose (map);
327
+ *operand = broadcast.getSource ();
328
+ changed = true ;
329
+ }
332
330
333
- // Determine which dims are usused, now that the maps have been composed
334
- // with the broadcast maps.
335
- llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector (maps);
336
- // Compress unused dims.
337
- for (auto &m : maps)
338
- m = compressDims (m, unusedDimsBitVector);
339
- // Compute the combined iterators.
340
- SmallVector<Attribute> iterators;
341
- for (unsigned i = 0 ; i < unusedDimsBitVector.size (); ++i) {
342
- if (!unusedDimsBitVector.test (i))
343
- iterators.push_back (contractOp.getIteratorTypes ().getValue ()[i]);
344
- }
345
- // Check that compressing unused dims isn't removing all reduction dimension
346
- // pairs. For example, if the vector.contract had only one reduction
347
- // iterator and that was a unit-dimension created by a broadcast,
348
- // then we should bail here, otherwise we would create a contract without
349
- // a reduction dimension pair.
350
- bool hasReductionIteratorApplyingOnBothSides = false ;
351
- for (unsigned i = 0 ; i < iterators.size (); ++i) {
352
- if (!isReductionIterator (iterators[i]))
353
- continue ;
354
- if (getResultIndex (maps[0 ], i) && getResultIndex (maps[1 ], i)) {
355
- hasReductionIteratorApplyingOnBothSides = true ;
331
+ if (!changed)
332
+ return failure ();
333
+
334
+ // Determine which dims are usused, now that the maps have been composed
335
+ // with the broadcast maps.
336
+ llvm::SmallBitVector unusedDimsBitVector = getUnusedDimsBitVector (maps);
337
+ // Compress unused dims.
338
+ for (auto &m : maps)
339
+ m = compressDims (m, unusedDimsBitVector);
340
+ // Compute the combined iterators.
341
+ SmallVector<Attribute> iterators;
342
+ for (unsigned i = 0 , e = unusedDimsBitVector.size (); i < e; ++i) {
343
+ if (!unusedDimsBitVector.test (i))
344
+ iterators.push_back (contractOp.getIteratorTypes ().getValue ()[i]);
345
+ }
346
+
347
+ // Check whether any of the unused dims is non-unit, e.g.:
348
+ // * vector.broadcast %arg0 : vector<8x4xi32> to vector<2x8x4xi32>
349
+ // This is only required when collapsing a mask. If there is no mask, skip.
350
+ VectorType oldMaskType;
351
+ bool isAnyUnusedDimNonUnit = false ;
352
+ if (maskingOp) {
353
+ oldMaskType = cast<VectorType>(maskingOp.getMask ().getType ());
354
+ for (unsigned i = 0 , e = unusedDimsBitVector.size (); i < e; ++i) {
355
+ if (unusedDimsBitVector.test (i) && oldMaskType.getShape ()[i] != 1 ) {
356
+ isAnyUnusedDimNonUnit = true ;
356
357
break ;
357
358
}
358
359
}
359
- if (!hasReductionIteratorApplyingOnBothSides)
360
- return failure ();
360
+ }
361
361
362
- // If the compressed maps have a dimension that is not used by either LHS or
363
- // RHS then the ContractionOp verifier would fail.
364
- if (getUnusedDimsBitVector ({maps[0 ], maps[1 ]}).any ())
365
- return failure ();
366
- rewriter.replaceOpWithNewOp <vector::ContractionOp>(
367
- contractOp, lhs, rhs, contractOp.getAcc (),
368
- rewriter.getAffineMapArrayAttr (maps), rewriter.getArrayAttr (iterators));
369
- return success ();
362
+ // Check that compressing unused dims isn't removing all reduction dimension
363
+ // pairs. For example, if the vector.contract had only one reduction
364
+ // iterator and that was a unit-dimension created by a broadcast,
365
+ // then we should bail here, otherwise we would create a contract without
366
+ // a reduction dimension pair.
367
+ bool hasReductionIteratorApplyingOnBothSides = false ;
368
+ for (unsigned i = 0 ; i < iterators.size (); ++i) {
369
+ if (!isReductionIterator (iterators[i]))
370
+ continue ;
371
+ if (getResultIndex (maps[0 ], i) && getResultIndex (maps[1 ], i)) {
372
+ hasReductionIteratorApplyingOnBothSides = true ;
373
+ break ;
374
+ }
375
+ }
376
+ if (!hasReductionIteratorApplyingOnBothSides)
377
+ return failure ();
378
+
379
+ // If the compressed maps have a dimension that is not used by either LHS or
380
+ // RHS then the ContractionOp verifier would fail.
381
+ if (getUnusedDimsBitVector ({maps[0 ], maps[1 ]}).any ())
382
+ return failure ();
383
+
384
+ Operation *newOp = rewriter.create <vector::ContractionOp>(
385
+ contractOp.getLoc (), lhs, rhs, contractOp.getAcc (),
386
+ rewriter.getAffineMapArrayAttr (maps), rewriter.getArrayAttr (iterators));
387
+
388
+ // Handle the mask.
389
+ if (maskingOp) {
390
+ if (isAnyUnusedDimNonUnit)
391
+ return rewriter.notifyMatchFailure (contractOp,
392
+ " Cannont drop non-unit mask dim." );
393
+ assert (unusedDimsBitVector.size () ==
394
+ static_cast <size_t >(oldMaskType.getRank ()) &&
395
+ " The mask rank is incorrect!" );
396
+
397
+ // If a dimension has been dropped, update the mask accordingly. Otherwise,
398
+ // keep it as is.
399
+ Value mask = maskingOp.getMask ();
400
+ if (unusedDimsBitVector.count () != 0 ) {
401
+ // At this point, two assumptions are made:
402
+ // * The unused dimensions are the leading mask dimensions
403
+ // (vector.contract does not support inner dim broadcasting).
404
+ // * The unused dimensions are all unit.
405
+ // These conditions are effectively verified in the blocks preceeding this
406
+ // one.
407
+ auto newShape =
408
+ oldMaskType.getShape ().drop_front (unusedDimsBitVector.count ());
409
+ auto newShapeScalableDims =
410
+ oldMaskType.getScalableDims ().drop_front (unusedDimsBitVector.count ());
411
+ VectorType maskOpType =
412
+ VectorType::get (newShape, rewriter.getI1Type (), newShapeScalableDims);
413
+ mask = rewriter
414
+ .create <vector::ShapeCastOp>(contractOp.getLoc (), maskOpType,
415
+ maskingOp.getMask ())
416
+ .getResult ();
417
+ }
418
+
419
+ newOp = mlir::vector::maskOperation (rewriter, newOp, mask);
420
+ }
421
+ return newOp->getResult (0 );
422
+ }
423
+
424
+ struct CombineContractBroadcastMask
425
+ : public MaskableOpRewritePattern<vector::ContractionOp> {
426
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
427
+ FailureOr<Value>
428
+
429
+ matchAndRewriteMaskableOp (vector::ContractionOp contractOp,
430
+ MaskingOpInterface maskingOp,
431
+ PatternRewriter &rewriter) const override {
432
+ return combineContractAndBroadcast (contractOp, maskingOp, rewriter);
370
433
}
371
434
};
372
435
@@ -2237,7 +2300,7 @@ void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
2237
2300
2238
2301
void mlir::vector::populateVectorReductionToContractPatterns (
2239
2302
RewritePatternSet &patterns, PatternBenefit benefit) {
2240
- patterns.add <MultiReduceToContract, CombineContractBroadcast ,
2303
+ patterns.add <MultiReduceToContract, CombineContractBroadcastMask ,
2241
2304
CombineContractABTranspose, CombineContractResultTranspose>(
2242
2305
patterns.getContext (), benefit);
2243
2306
}
0 commit comments