Skip to content

[mlir][ArmSME] Add support for lowering masked tile_store ops #71180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 43 additions & 22 deletions mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,38 +173,59 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
auto tileType = tileStoreOp.getVectorType();
auto tileElementType = tileType.getElementType();

// Create a loop that stores each ZA tile slice from memory.
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);

Value maskCols;
Value upperBound;
auto maskOp = tileStoreOp.getMask();
if (maskOp) {
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
"currently supported");

auto numRows = createMaskOp.getOperands()[0];
auto numCols = createMaskOp.getOperands()[1];

upperBound = numRows;
maskCols =
rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
} else {
// Store all tile slices if no mask.
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
// This describes both the number of ZA tile slices and the number of
// elements in a vector of SVL bits for a given element type (SVL_B,
// SVL_H, ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);

upperBound = numTileSlices;
// Create an 'all true' predicate for the tile slice.
maskCols = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predicateType, true));
}

// Create a loop that stores each (active) ZA tile slice to memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
auto vscale =
rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
// This describes both the number of ZA tile slices and the number of
// elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
// ..., SVL_Q).
auto numTileSlices =
rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
auto forOp =
rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);

rewriter.setInsertionPointToStart(forOp.getBody());

// Create an 'all true' predicate for the tile slice.
auto predicateType =
VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
auto allTruePredicate = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(predicateType, true));

SmallVector<Value> memrefIndices;
auto tileSliceIndex = forOp.getInductionVar();
getMemrefIndices(tileStoreOp.getIndices(),
tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
numTileSlices, memrefIndices, loc, rewriter);
upperBound, memrefIndices, loc, rewriter);
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
allTruePredicate, tileStoreOp.getBase(), memrefIndices,
tileStoreOp.getLayout());
tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());

return success();
}
Expand Down
25 changes: 23 additions & 2 deletions mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK-DAG: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
// CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
// CHECK: %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
// CHECK: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
Expand All @@ -67,6 +67,27 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
return
}

// -----

// CHECK-LABEL: func.func @arm_sme_tile_store_hor_with_mask(
// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[NUM_ROWS:.*]] = arith.constant 3 : index
// CHECK-DAG: %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK-NEXT: arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
// Input for the masked tile-store lowering: stores %tile to %dest[0, 0]
// under a 2-D mask.
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
// The mask operands are (3, 2). In the expected output above, 3 is the loop
// upper bound and 2 is the operand of the 1-D mask passed to each slice store.
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
// No layout attribute is given here.
arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}

//===----------------------------------------------------------------------===//
// vector.print
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// DEFINE: %{entry_point} = entry
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: -enable-arm-streaming="mode=locally enable-za" \
// DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
// DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils

// RUN: %{compile} | %{run} | FileCheck %s

// Unmasked vector store: writes an all-zero 2-D scalable tile into %A at
// [%base1, %base2]. Both dimensions are marked in-bounds.
func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  %zero_f32 = arith.constant 0.0 : f32
  %zero_tile = vector.splat %zero_f32 : vector<[4]x[4]xf32>
  vector.transfer_write %zero_tile, %A[%base1, %base2] {in_bounds=[true, true]} :
    vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}

// Masked vector store: writes an all-zero 2-D scalable tile into %A at
// [%base1, %base2], restricted by a mask created from the operands (2, 3).
func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  // Operands of the 2-D mask, in the order they are passed to create_mask.
  %mask_dim0 = arith.constant 2 : index
  %mask_dim1 = arith.constant 3 : index
  %store_mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[4]x[4]xi1>

  // Value to store: a tile filled with 0.0.
  %zero_f32 = arith.constant 0.0 : f32
  %zero_tile = vector.splat %zero_f32 : vector<[4]x[4]xf32>

  vector.transfer_write %zero_tile, %A[%base1, %base2], %store_mask {in_bounds=[true, true]} :
    vector<[4]x[4]xf32>, memref<?x?xf32>
  return
}

// Vector load + print: reads a 2-D scalable tile from %A at [%base1, %base2],
// then prints a marker string followed by the tile contents.
func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
  %tile = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>

  // The marker separates successive tile dumps in the program output.
  vector.print str "TILE BEGIN:"
  vector.print %tile : vector<[4]x[4]xf32>

  return
}

// Allocate heap memory of size 'd0' x 'd1' and initialize.
//
// Example:
//
// initialize_memory(%c4, %c5)
//
// 0, 1, 2, 3, 4
// 10, 11, 12, 13, 14
// 20, 21, 22, 23, 24
// 30, 31, 32, 33, 34
//
// Returns a dynamic memref. It is the caller's responsibility to free the
// returned memref.
func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
  // Loop lower bound and step.
  %lb = arith.constant 0 : index
  %step = arith.constant 1 : index
  // Per-element increments: +1.0 along dim 1, +10.0 along dim 0.
  %col_inc = arith.constant 1.0 : f32
  %row_inc = arith.constant 10.0 : f32
  // Value stored at [0, 0].
  %first = arith.constant 0.0 : f32

  %buffer = memref.alloc(%d0, %d1) : memref<?x?xf32>

  // The outer loop carries the value of the first element of the current row.
  scf.for %row = %lb to %d0 step %step iter_args(%row_start = %first) -> f32 {
    // The inner loop carries the value of the current element; its final
    // result is unused.
    scf.for %col = %lb to %d1 step %step iter_args(%elem = %row_start) -> f32 {
      memref.store %elem, %buffer[%row, %col] : memref<?x?xf32>
      %elem_next = arith.addf %elem, %col_inc : f32
      scf.yield %elem_next : f32
    }
    %row_start_next = arith.addf %row_start, %row_inc : f32
    scf.yield %row_start_next : f32
  }

  return %buffer : memref<?x?xf32>
}

// Test driver: allocates and initializes a buffer, then performs an unmasked
// and a masked tile write, printing the tile at [0, 0] after each step.
func.func @entry() {
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index

// Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
// non-zero offsets while remaining inbounds.
%vscale = vector.vscale
// 4 * vscale, plus 2 extra elements used for both buffer dimensions below.
%svl_s = arith.muli %c4, %vscale : index
%svl_s_plus_two = arith.addi %svl_s, %c2 : index

// 1. Initialize memory
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 1, 2, 3
// CHECK-NEXT: ( 10, 11, 12, 13
// CHECK-NEXT: ( 20, 21, 22, 23
// CHECK-NEXT: ( 30, 31, 32, 33
%A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// 2. Write 2-D vector of zeroes to 1. at offset [2, 2].
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 1, 2, 3
// CHECK-NEXT: ( 10, 11, 12, 13
// CHECK-NEXT: ( 20, 21, 0, 0
// CHECK-NEXT: ( 30, 31, 0, 0
call @transfer_write_2d(%A, %c2, %c2) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// 3. Write 2-D vector of zeroes to 2. but with mask (nrows=2, ncols=3).
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 0, 0, 3
// CHECK-NEXT: ( 0, 0, 0, 13
// CHECK-NEXT: ( 20, 21, 0, 0
// CHECK-NEXT: ( 30, 31, 0, 0
call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// Release the buffer returned by @initialize_memory.
memref.dealloc %A : memref<?x?xf32>

return
}