@@ -420,38 +420,59 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
420
420
auto tileType = tileStoreOp.getVectorType ();
421
421
auto tileElementType = tileType.getElementType ();
422
422
423
- // Create a loop that stores each ZA tile slice from memory.
423
+ auto predicateType =
424
+ VectorType::get (tileType.getDimSize (1 ), rewriter.getI1Type (), true );
425
+
426
+ Value maskCols;
427
+ Value upperBound;
428
+ auto maskOp = tileStoreOp.getMask ();
429
+ if (maskOp) {
430
+ auto createMaskOp = maskOp.getDefiningOp <vector::CreateMaskOp>();
431
+ if (!createMaskOp)
432
+ return rewriter.notifyMatchFailure (
433
+ tileStoreOp, " unsupported mask op, only 'vector.create_mask' is "
434
+ " currently supported" );
435
+
436
+ auto numRows = createMaskOp.getOperands ()[0 ];
437
+ auto numCols = createMaskOp.getOperands ()[1 ];
438
+
439
+ upperBound = numRows;
440
+ maskCols =
441
+ rewriter.create <vector::CreateMaskOp>(loc, predicateType, numCols);
442
+ } else {
443
+ // Store all tile slices if no mask.
444
+ auto minTileSlices = rewriter.create <arith::ConstantIndexOp>(
445
+ loc, arm_sme::getSMETileSliceMinNumElts (tileElementType));
446
+ auto vscale =
447
+ rewriter.create <vector::VectorScaleOp>(loc, rewriter.getIndexType ());
448
+ // This describes both the number of ZA tile slices and the number of
449
+ // elements in a vector of SVL bits for a given element type (SVL_B,
450
+ // SVL_H,
451
+ // ..., SVL_Q).
452
+ auto numTileSlices =
453
+ rewriter.create <arith::MulIOp>(loc, minTileSlices, vscale);
454
+
455
+ upperBound = numTileSlices;
456
+ // Create an 'all true' predicate for the tile slice.
457
+ maskCols = rewriter.create <arith::ConstantOp>(
458
+ loc, DenseElementsAttr::get (predicateType, true ));
459
+ }
460
+
461
+ // Create a loop that stores each (active) active ZA tile slice from memory.
424
462
auto step = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
425
- auto minTileSlices = rewriter.create <arith::ConstantIndexOp>(
426
- loc, arm_sme::getSMETileSliceMinNumElts (tileElementType));
427
- auto vscale =
428
- rewriter.create <vector::VectorScaleOp>(loc, rewriter.getIndexType ());
429
463
auto lowerBound = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
430
- // This describes both the number of ZA tile slices and the number of
431
- // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
432
- // ..., SVL_Q).
433
- auto numTileSlices =
434
- rewriter.create <arith::MulIOp>(loc, minTileSlices, vscale);
435
- auto forOp =
436
- rewriter.create <scf::ForOp>(loc, lowerBound, numTileSlices, step);
464
+ auto forOp = rewriter.create <scf::ForOp>(loc, lowerBound, upperBound, step);
437
465
438
466
rewriter.setInsertionPointToStart (forOp.getBody ());
439
467
440
- // Create an 'all true' predicate for the tile slice.
441
- auto predicateType =
442
- VectorType::get (tileType.getDimSize (1 ), rewriter.getI1Type (), true );
443
- auto allTruePredicate = rewriter.create <arith::ConstantOp>(
444
- loc, DenseElementsAttr::get (predicateType, true ));
445
-
446
468
SmallVector<Value> memrefIndices;
447
469
auto tileSliceIndex = forOp.getInductionVar ();
448
470
getMemrefIndices (tileStoreOp.getIndices (),
449
471
tileStoreOp.getMemRefType ().getRank (), tileSliceIndex,
450
- numTileSlices , memrefIndices, loc, rewriter);
472
+ upperBound , memrefIndices, loc, rewriter);
451
473
rewriter.replaceOpWithNewOp <arm_sme::StoreTileSliceOp>(
452
- tileStoreOp, tileStoreOp.getValueToStore (), tileSliceIndex,
453
- allTruePredicate, tileStoreOp.getBase (), memrefIndices,
454
- tileStoreOp.getLayout ());
474
+ tileStoreOp, tileStoreOp.getValueToStore (), tileSliceIndex, maskCols,
475
+ tileStoreOp.getBase (), memrefIndices, tileStoreOp.getLayout ());
455
476
456
477
return success ();
457
478
}
0 commit comments