-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][ArmSME] Add mask operand to store_tile_slice #70838
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,6 +66,15 @@ class HasMatchingMaskTypeConstraint<string vector, string mask> : | |
vector, mask, | ||
"::llvm::cast<mlir::VectorType>($_self).cloneWith({}, IntegerType::get($_ctxt, 1))">; | ||
|
||
class TileSliceMaskConstraint<string tile, string mask> : | ||
TypesMatchWith< | ||
"`" # mask # "` has i1 element type and the shape is a slice of `" # tile # "`", | ||
tile, mask, | ||
"VectorType(" | ||
"VectorType::Builder(" | ||
"::llvm::cast<mlir::VectorType>($_self)" | ||
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1)))">; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ArmSME attr definitions | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -408,15 +417,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [ | |
} | ||
|
||
def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ | ||
AllTypesMatch<["tile", "result"]>, | ||
TypesMatchWith< | ||
"mask has i1 element type and is a slice of the result", | ||
"result", "mask", | ||
"VectorType(" | ||
"VectorType::Builder(" | ||
"::llvm::cast<mlir::VectorType>($_self)" | ||
").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))" | ||
")">, | ||
AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask"> | ||
]> { | ||
let summary = "Tile slice load and update operation"; | ||
let description = [{ | ||
|
@@ -474,7 +475,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ | |
}]; | ||
} | ||
|
||
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { | ||
def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [ | ||
TileSliceMaskConstraint<"tile", "mask"> | ||
]> { | ||
let summary = "Tile slice store operation"; | ||
let description = [{ | ||
Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile | ||
|
@@ -489,22 +492,27 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { | |
dimensions since the operation is scalable, and the element type must be a | ||
scalar that matches the element type of the input tile. | ||
|
||
An SSA value `mask` specifies to mask out elements written to the MemRef. | ||
The `mask` type is an `i1` vector with a shape that matches how elements | ||
are written to the MemRef. | ||
|
||
Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory. | ||
```mlir | ||
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8> | ||
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : vector<[16]x[16]xi8>, vector<[16]xi1>, memref<?x?xi8> | ||
``` | ||
|
||
Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory. | ||
```mlir | ||
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32> | ||
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, vector<[4]xi1>, memref<?x?xf32> | ||
``` | ||
|
||
Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory. | ||
```mlir | ||
arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128> | ||
arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, vector<[1]xi1>, memref<?x?xi128> | ||
``` | ||
}]; | ||
let arguments = (ins SMETile:$tile, Index:$tile_slice_index, | ||
let arguments = (ins | ||
SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask, | ||
c-rhodes marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base, | ||
Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout | ||
); | ||
|
@@ -518,8 +526,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> { | |
}]; | ||
|
||
let assemblyFormat = [{ | ||
$tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)? | ||
attr-dict `:` type($base) `,` type($tile) | ||
$tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)? | ||
attr-dict `:` type($base) `,` type($mask) `,` type($tile) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No it can be inferred as you say, but I decided to add the type anyway since this maps 1-1 with intrinsic. The type is also present in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I prefer omitting the types that are very clear from context (i.e. the mask size must match the slice), but this is fine for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok cool cheers, I'll post a follow-up :) |
||
}]; | ||
} | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.