[mlir][ArmSME] Lower transfer_write + transpose to vertical store #71181

Merged
43 changes: 40 additions & 3 deletions mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering

/// Conversion pattern for vector.transfer_write.
///
/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
/// memref<?x?xi8>
/// ---
///
/// Example 1: op with identity permutation map to horizontal
/// arm_sme.tile_store:
///
/// vector.transfer_write %vector, %source[%c0, %c0]
/// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
///                                                   vector<[16]x[16]xi8>

Review thread on the arm_sme.tile_store line above:

Contributor:
Given that this is now always passing the layout to the builder, shouldn't this be printed as:

    arm_sme.tile_store %vector, %source[%c0, %c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>

Collaborator (Author):
The horizontal layout is the default and isn't printed, but I could add it to be explicit if you think that's clearer?

Contributor:
I made an invalid assumption (I was expecting it to be printed since it was being passed to the builder). But, from https://mlir.llvm.org/docs/DefiningDialects/AttributesAndTypes/#optional-and-default-valued-parameters:

> An optional parameter is omitted when it is equal to its default value.

So this is correct.

> I could add it to be explicit if you think that's clearer

That would be my personal preference, but then what's the point of "default" values? No strong opinion; to me it's a bit confusing, but we can revisit some other time.

Member:
My preference is to omit expected default values (and I think horizontal would be expected) :)

/// ---
///
/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
/// (in-flight transpose):
///
/// vector.transfer_write %vector, %source[%c0, %c0]
/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
/// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
///
/// is converted to:
///
/// arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
/// : memref<?x?xi8>, vector<[16]x[16]xi8>
struct TransferWriteToArmSMELowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -156,9 +174,28 @@ struct TransferWriteToArmSMELowering
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
return failure();

// Out-of-bounds dims are not supported.
if (writeOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(writeOp,
"not inbounds transfer write");

AffineExpr d0, d1;
bindDims(writeOp.getContext(), d0, d1);
AffineMap map = writeOp.getPermutationMap();
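    // The only non-identity permutation map supported here is the plain 2-D
    // transpose, (d0, d1) -> (d1, d0); it is lowered to a vertical tile store.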
bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
writeOp.getContext()));

if (!map.isIdentity() && !isTranspose)
return rewriter.notifyMatchFailure(writeOp,
"unsupported permutation map");

arm_sme::TileSliceLayout layout =
isTranspose ? arm_sme::TileSliceLayout::Vertical
: arm_sme::TileSliceLayout::Horizontal;

rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
-        writeOp.getMask());
+        writeOp.getMask(), layout);
return success();
}
};
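To make the outcome of the review thread above concrete: the layout attribute defaults to horizontal, and the printer omits an optional parameter that equals its default, so only vertical layouts ever appear in printed IR. A minimal sketch follows (the function name @tile_store_layout_example is made up for illustration; the first two stores parse to the same op):

    func.func @tile_store_layout_example(%vector : vector<[16]x[16]xi8>, %source : memref<?x?xi8>) {
      %c0 = arith.constant 0 : index
      // Horizontal is the default; this form is what the printer emits.
      arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
      // Spelling the default explicitly is accepted, but it is elided when printed.
      arm_sme.tile_store %vector, %source[%c0, %c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
      // Only the non-default layout shows up in the printed output.
      arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
      return
    }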
42 changes: 42 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest

// -----

/// in-flight transpose via vertical store.

// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xi64>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
return
}

// -----

/// in-flight transpose via vertical store with mask.

// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xbf16>,
// CHECK-SAME: %[[MASK:.*]]: vector<[8]x[8]xi1>) {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
return
}

// -----

// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
// lowering only occurs for vector types of correct rank, shape, element size
// and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
return
}

// -----

// CHECK-LABEL: @transfer_write_2d__out_of_bounds
// CHECK: vector.transfer_write
// CHECK-NOT: arm_sme.tile_store
func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [false, false]} : vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

//===----------------------------------------------------------------------===//
// vector.broadcast
//===----------------------------------------------------------------------===//
@@ -32,6 +32,25 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
return
}

// Vector transpose + store.
func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

// Vector transpose + masked store.
func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
vector<[4]x[4]xf32>, memref<?x?xf32>
return
}

// Vector load + print.
func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
%0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -116,6 +135,26 @@ func.func @entry() {
call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// 4. Reload 3. + transpose + store.
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 0, 20, 30
// CHECK-NEXT: ( 0, 0, 21, 31
// CHECK-NEXT: ( 0, 0, 0, 0
// CHECK-NEXT: ( 3, 13, 0, 0
call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

// 5. Reload 4. + transpose + masked store (nrows=4, ncols=2).
// The mask applies after permutation. Columns 2 and 3 (from 4.) are
// preserved.
// CHECK-LABEL: TILE BEGIN:
// CHECK-NEXT: ( 0, 0, 20, 30
// CHECK-NEXT: ( 0, 0, 21, 31
// CHECK-NEXT: ( 20, 21, 0, 0
// CHECK-NEXT: ( 30, 31, 0, 0
call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()

memref.dealloc %A : memref<?x?xf32>

return