Skip to content

Commit 26bb905

Browse files
committed
[mlir][ArmSME] Add mask operand to store_tile_slice
1 parent 1908f47 commit 26bb905

File tree

7 files changed

+155
-131
lines changed

7 files changed

+155
-131
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,15 @@ class HasMatchingMaskTypeConstraint<string vector, string mask> :
6666
vector, mask,
6767
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">;
6868

69+
class TileSliceMaskConstraint<string tile, string mask> :
70+
TypesMatchWith<
71+
"`" # mask # "` has i1 element type and the shape is a slice of `" # tile # "`",
72+
tile, mask,
73+
"VectorType("
74+
"VectorType::Builder("
75+
"::llvm::cast<mlir::VectorType>($_self)"
76+
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1)))">;
77+
6978
//===----------------------------------------------------------------------===//
7079
// ArmSME attr definitions
7180
//===----------------------------------------------------------------------===//
@@ -408,15 +417,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
408417
}
409418

410419
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
411-
AllTypesMatch<["tile", "result"]>,
412-
TypesMatchWith<
413-
"mask has i1 element type and is a slice of the result",
414-
"result", "mask",
415-
"VectorType("
416-
"VectorType::Builder("
417-
"::llvm::cast<mlir::VectorType>($_self)"
418-
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
419-
")">,
420+
AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
420421
]> {
421422
let summary = "Tile slice load and update operation";
422423
let description = [{
@@ -474,7 +475,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
474475
}];
475476
}
476477

477-
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
478+
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
479+
TileSliceMaskConstraint<"tile", "mask">
480+
]> {
478481
let summary = "Tile slice store operation";
479482
let description = [{
480483
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
@@ -489,22 +492,27 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
489492
dimensions since the operation is scalable, and the element type must be a
490493
scalar that matches the element type of the input tile.
491494

495+
An SSA value `mask` specifies to mask out elements written to the MemRef.
496+
The `mask` type is an `i1` vector with a shape that matches how elements
497+
are written to the MemRef.
498+
492499
Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
493500
```mlir
494-
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
501+
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : vector<[16]x[16]xi8>, vector<[16]xi1>, memref<?x?xi8>
495502
```
496503

497504
Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
498505
```mlir
499-
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
506+
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, vector<[4]xi1>, memref<?x?xf32>
500507
```
501508

502509
Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
503510
```mlir
504-
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
511+
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, vector<[1]xi1>, memref<?x?xi128>
505512
```
506513
}];
507-
let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
514+
let arguments = (ins
515+
SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask,
508516
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
509517
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
510518
);
@@ -518,8 +526,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
518526
}];
519527

520528
let assemblyFormat = [{
521-
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
522-
attr-dict `:` type($base) `,` type($tile)
529+
$tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)?
530+
attr-dict `:` type($base) `,` type($mask) `,` type($tile)
523531
}];
524532
}
525533

mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,14 +190,21 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
190190

191191
rewriter.setInsertionPointToStart(forOp.getBody());
192192

193+
// Create an 'all true' predicate for the tile slice.
194+
auto predicateType =
195+
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
196+
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
197+
loc, DenseElementsAttr::get(predicateType, true));
198+
193199
SmallVector<Value> memrefIndices;
194200
auto tileSliceIndex = forOp.getInductionVar();
195201
getMemrefIndices(tileStoreOp.getIndices(),
196202
tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
197203
numTileSlices, memrefIndices, loc, rewriter);
198204
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
199205
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
200-
tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
206+
allTruePredicate, tileStoreOp.getBase(), memrefIndices,
207+
tileStoreOp.getLayout());
201208

202209
return success();
203210
}

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,7 @@ struct StoreTileSliceToArmSMELowering
278278
auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
279279
loc, rewriter.getI32Type(), tileSlice);
280280

281-
// Create all active predicate mask.
282-
auto one = rewriter.create<arith::ConstantOp>(
283-
loc, rewriter.getI1Type(),
284-
rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
285-
auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
286-
/*scalableDims=*/{true});
287-
auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
281+
auto maskOp = storeTileSliceOp.getMask();
288282

289283
Value tileI32 = castTileIDToI32(tile, loc, rewriter);
290284
arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
@@ -295,23 +289,23 @@ struct StoreTileSliceToArmSMELowering
295289
llvm_unreachable("unexpected element type!");
296290
case 8:
297291
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
298-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
292+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
299293
break;
300294
case 16:
301295
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
302-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
296+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
303297
break;
304298
case 32:
305299
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
306-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
300+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
307301
break;
308302
case 64:
309303
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
310-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
304+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
311305
break;
312306
case 128:
313307
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
314-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
308+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
315309
break;
316310
}
317311
} else {
@@ -320,23 +314,23 @@ struct StoreTileSliceToArmSMELowering
320314
llvm_unreachable("unexpected element type!");
321315
case 8:
322316
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
323-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
317+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
324318
break;
325319
case 16:
326320
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
327-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
321+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
328322
break;
329323
case 32:
330324
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
331-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
325+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
332326
break;
333327
case 64:
334328
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
335-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
329+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
336330
break;
337331
case 128:
338332
rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
339-
storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
333+
storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
340334
break;
341335
}
342336
}

mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
4848
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
4949
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
5050
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
51+
// CHECK: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
5152
// CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
52-
// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
53+
// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
5354
func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
5455
%c0 = arith.constant 0 : index
5556
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

0 commit comments

Comments
 (0)