[mlir][ArmSME] Support vertical layout in load and store ops #66758

Merged

Conversation

c-rhodes
Collaborator

In SME, a ZA tile slice is a one-dimensional set of horizontally or vertically contiguous elements within a ZA tile. Currently, the load and store ops only support horizontal tile slices. This patch adds a tile slice layout attribute to the load and store ops to support both horizontal and vertical tile slices.

When lowering from the Vector dialect, the horizontal layout is the default.
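
A minimal sketch of the two layouts, using the syntax introduced in this revision of the patch (the spelling was later changed; see the discussion below):

```mlir
// Load a 32-bit element ZA tile with horizontal (row) slices.
%tile_h = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
// Load the same tile shape with vertical (column) slices.
%tile_v = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```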

@llvmbot
Member

llvmbot commented Sep 19, 2023

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Changes

Patch is 98.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66758.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+5)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td (+77-39)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt (+7)
  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+9-9)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+8-6)
  • (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+12)
  • (modified) mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+71-22)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+31-11)
  • (modified) mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir (+3-3)
  • (added) mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir (+401)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+450-90)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+8-8)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir (+110)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index d1ed02abfd5c552..f947fc8fe1631b8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -21,6 +21,11 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
+
 #include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..1a4984f3bd6ba27 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
 #ifndef ARMSME_OPS
 #define ARMSME_OPS
 
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
   let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+  let useDefaultAttributePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
                   "::llvm::cast<VectorType>($_self).getElementType())"
                   ".getWidth())">;
 
+//===----------------------------------------------------------------------===//
+// ArmSME attr definitions
+//===----------------------------------------------------------------------===//
+
+def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
+  I32EnumAttrCase<"Horizontal", 0, "hor">,
+  I32EnumAttrCase<"Vertical", 1, "ver">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies the layout of a tile slice in a tile.
+def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
+                                          "layout"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -239,27 +259,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
-    with the shape defined by the 2D scalable vector type of the result tile.
-    The slice of memory must be contiguous. The memref must be either rank 1 or
-    rank 2 with dynamic dimensions, since the operation is scalable, and the
-    element type must be a scalar that matches the element type of the result.
+    with the shape defined by the 2D scalable vector type of the result tile. A
+    tile slice layout attribute specifies whether the slices of the tile being
+    loaded are horizontal or vertical. The slice of memory must be contiguous.
+    The memref must be either rank 1 or rank 2 with dynamic dimensions, since
+    the operation is scalable, and the element type must be a scalar that
+    matches the element type of the result.
+
+    The default tile slice layout when lowering from higher-level dialects is
+    horizontal.
 
-    Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+    Example 1: Load an 8-bit element ZA tile with horizontal layout from memory (ZA0.B).
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Load a FP 32-bit element ZA tile from memory.
+    Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Load a 128-bit element ZA tile from memory.
+    Example 3: Load a 128-bit element ZA tile with horizontal layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
       Variadic<Index>:$indices);
   let results = (outs SMETile:$result);
@@ -274,7 +300,8 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   }];
 
   let assemblyFormat =
-      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+      "$layout `,` $base `[` $indices `]` attr-dict "
+        "`:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
@@ -282,27 +309,32 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
   let description = [{
     Stores a 2D SME "virtual tile" to memory defined by a base and indices,
     with the shape defined by the 2D scalable vector type of the tile being
-    stored. The slice of memory must be contiguous. The memref must be either
-    rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
-    and the element type must be a scalar that matches the element type of the
-    result.
+    stored. A tile slice layout attribute specifies whether the slices of the
+    tile being stored are horizontal or vertical. The slice of memory must be
+    contiguous.  The memref must be either rank 1 or rank 2 with dynamic
+    dimensions, since the operation is scalable, and the element type must be a
+    scalar that matches the element type of the result.
+
+    The default tile slice layout when lowering from higher-level dialects is
+    horizontal.
 
-    Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+    Example 1: Store an 8-bit element ZA tile with horizontal layout to memory (ZA0.B).
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
     ```
 
-    Example 2: Store a FP 32-bit element ZA tile to memory.
+    Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.tile_store %tile, <ver>, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
-    Example 3: Store a 128-bit element ZA tile to memory.
+    Example 3: Store a 128-bit element ZA tile with horizontal layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -314,8 +346,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     }
   }];
 
-  let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
-                       "`:` type($base) `,` type($valueToStore)";
+  let assemblyFormat =
+    "$valueToStore `,` $layout `,` $base `[` $indices `]` attr-dict "
+      "`:` type($base) `,` type($valueToStore)";
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,29 +359,32 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
     slice is defined by the dimension of the 2D scalable vector type pointed by
     the index. A tile slice index describes where in the input tile the tile
-    slice is loaded to. The updated tile is returned as the result.
+    slice is loaded to. A tile slice layout attribute specifies whether the
+    tile slice being loaded at the given index is horizontal or vertical. The
+    updated tile is returned as the result.
 
     The slice of memory read is defined by a base and indices and must be
     contiguous. The memref must be either rank 1 or rank 2, have dynamic
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
-    Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+    Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile_update = arm_sme.load_tile_slice <hor>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+    Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+    Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to load from">:$base,
       SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
   let results = (outs SMETile:$result);
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+    $layout `,` $base `[` $indices `]` `,` $tile `,` $tile_slice_index
       attr-dict `:` type($base) `,` type($result)
   }];
 }
@@ -374,29 +410,31 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
     Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
     slice is defined by the dimension of the 2D scalable vector type pointed by
     the index. A tile slice index describes where in the input tile the tile
-    slice is stored from.
+    slice is stored from. A tile slice layout attribute specifies whether the
+    tile slice being stored from the given index is horizontal or vertical.
 
     The slice of memory written is defined by a base and indices and must be
     contiguous. The memref must be either rank 1 or rank 2, have dynamic
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the input tile.
 
-    Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
+    Example 1: Store vector<[16]xi8> horizontal tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
     ```
 
-    Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
+    Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
-    Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+    Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -409,7 +447,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
   }];
 
   let assemblyFormat = [{
-    $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+    $tile `,` $tile_slice_index `,` $layout `,` $base `[` $indices `]`
       attr-dict `:` type($base) `,` type($tile)
   }];
 }
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index d20ee65e62e7dc0..7afd0d014541687 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
 set(LLVM_TARGET_DEFINITIONS ArmSME.td)
 mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+
+mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
+mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
+add_public_tablegen_target(MLIRArmSMEEnumsIncGen)
+add_dependencies(mlir-headers MLIRArmSMEEnumsIncGen)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..86cabe67f2695f1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -54,7 +54,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///
 ///  BEFORE:
 ///  ```mlir
-///  %tile = arm_sme.tile_load %src[%c0, %c0] :
+///  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] :
 ///    memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  ```
 ///
@@ -68,7 +68,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+///    %tile_update = arm_sme.load_tile_slice <hor>, %src[%tile_slice_idx],
 ///      %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
-                                              tileLoadOp.getBase(), tile,
-                                              memrefIndices, tileSliceIndex);
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getLayout(), tileLoadOp.getBase(), tile,
+        memrefIndices, tileSliceIndex);
 
     rewriter.setInsertionPointAfter(forOp);
 
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///
 ///  BEFORE:
 ///  ```mlir
-///  arm_sme.tile_store %tile, %dest[%c0, %c0]
+///  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0]
 ///    : memref<?x?xi32>, vector<[4]x[4]xi32
 ///  ```
 ///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-///    arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
-///      : memref<?x?xi32>, vector<[4]x[4]xi32>
+///    arm_sme.store_tile_slice %tile, %tile_slice_idx, <ver>,
+///      %dest[%tile_slice_idx] : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
         tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        tileStoreOp.getBase(), memrefIndices);
+        tileStoreOp.getLayout(), tileStoreOp.getBase(), memrefIndices);
 
     return success();
   }
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..feaec0e035ed9fd 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -65,8 +65,8 @@ namespace {
 ///
 /// is converted to:
 ///
-///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
-///                                                   vector<[16]x[16]xi8>
+///   arm_sme.tile_store %vector, <hor>, %source[%c0, %c0]
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -81,8 +81,8 @@ struct TransferWriteToArmSMELowering
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, writeOp.getVector(), writeOp.getSource(),
-        writeOp.getIndices());
+        writeOp, writeOp.getVector(), arm_sme::TileSliceLayout::Horizontal,
+        writeOp.getSource(), writeOp.getIndices());
     return success();
   }
 };
@@ -97,7 +97,8 @@ struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
-        load, load.getVectorType(), load.getBase(), load.getIndices());
+        load, load.getVectorType(), arm_sme::TileSliceLayout::Horizontal,
+        load.getBase(), load.getIndices());
 
     return success();
   }
@@ -113,7 +114,8 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        store, store.getValueToStore(), store.getBase(), store.getIndices());
+        store, store.getValueToStore(), arm_sme::TileSliceLayout::Horizontal,
+        store.getBase(), store.getIndices());
 
     return success();
   }
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..92fb146691a0beb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::arm_sme;
@@ -22,13 +24,23 @@ using namespace mlir::arm_sme;
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
 
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+
 void ArmSMEDialect::initialize() {
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+      >();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 9b6332a478ad...
[truncated]

@banach-space
Contributor

banach-space commented Sep 20, 2023

Quite a large change, but very well executed and it makes a lot of sense. Btw, what logic would decide to generate vertical instead of horizontal loads/stores?

@c-rhodes
Collaborator Author

> Quite a large change, but very well executed and it makes a lot of sense. Btw, what logic would decide to generate vertical instead of horizontal loads/stores?

Thanks for reviewing. For vector.load and vector.store I don't believe there is any way to express this, but perhaps a canonicalization could be added that replaces a transpose of a load/store with a load/store in the opposite direction.
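
Something along these lines, as a hypothetical rewrite (not part of this patch):

```mlir
// A transpose of a horizontally-loaded tile...
%tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
%t = vector.transpose %tile, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
// ...could be rewritten to a single load in the opposite direction:
%t2 = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```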

vector.transfer_read and vector.transfer_write, however, take an affine_map that can express a transpose. Here's an example I just created, based on mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir:

```mlir
#transpose_map = affine_map<(d0, d1) -> (d1, d0)>

func.func private @printMemrefF32(memref<*xf32>)

func.func @alloc_2d_filled_f32(%arg0: index, %arg1: index) -> memref<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
  scf.for %arg5 = %c0 to %arg0 step %c1 {
    scf.for %arg6 = %c0 to %arg1 step %c1 {
      %tmp2 = arith.index_cast %arg6: index to i32
      %tmp3 = arith.sitofp %tmp2 : i32 to f32
      memref.store %tmp3, %0[%arg5, %arg6] : memref<?x?xf32>
    }
  }
  return %0 : memref<?x?xf32>
}

func.func @main() {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %cst = arith.constant -4.2e+01 : f32

  %0 = call @alloc_2d_filled_f32(%c4, %c4) : (index, index) -> memref<?x?xf32>
  %converted = memref.cast %0 : memref<?x?xf32> to memref<*xf32>
  call @printMemrefF32(%converted): (memref<*xf32>) -> ()

  %1 = vector.transfer_read %0[%c0, %c0], %cst {permutation_map = #transpose_map} : memref<?x?xf32>, vector<4x4xf32>
  vector.transfer_write %1, %0[%c0, %c0] : vector<4x4xf32>, memref<?x?xf32>
  call @printMemrefF32(%converted): (memref<*xf32>) -> ()

  memref.dealloc %0 : memref<?x?xf32>
  return
}
```

Output:

```
Unranked Memref base@ = 0x216f7190 rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data =
[[0,   1,   2,   3],
 [0,   1,   2,   3],
 [0,   1,   2,   3],
 [0,   1,   2,   3]]
Unranked Memref base@ = 0x216f7190 rank = 2 offset = 0 sizes = [4, 4] strides = [4, 1] data =
[[0,   0,   0,   0],
 [1,   1,   1,   1],
 [2,   2,   2,   2],
 [3,   3,   3,   3]]
```

We can extend the lowering to SME to support this.
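
That is, something like the following sketch, assuming the permutation map is the 2-D transpose:

```mlir
// A transposed transfer_read...
%0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
  : memref<?x?xf32>, vector<[4]x[4]xf32>
// ...could lower to a tile load with vertical slice layout:
%1 = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
```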

Contributor

@banach-space left a comment

LGTM, thanks!

Contributor

@dcaballe left a comment

Awesome! Just a small refactoring comment.

Comment on lines +211 to +233
```cpp
switch (tileElementWidth) {
default:
  llvm_unreachable("unexpected element type!");
case 8:
  rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
      loc, allActiveMask, ptr, tileI32, tileSliceI32);
  break;
case 16:
  rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
      loc, allActiveMask, ptr, tileI32, tileSliceI32);
  break;
case 32:
  rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
      loc, allActiveMask, ptr, tileI32, tileSliceI32);
  break;
case 64:
  rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
      loc, allActiveMask, ptr, tileI32, tileSliceI32);
  break;
case 128:
  rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
      loc, allActiveMask, ptr, tileI32, tileSliceI32);
  break;
```
Contributor


Worth refactoring this into a template function?

Collaborator Author


It's not obvious to me how I could improve this with a template function; could you give us some pointers?

Contributor


This is the only thing that came to my mind:

```cpp
// Sketch: primary template dispatches on element width; one explicit
// specialization per width calls the matching store intrinsic.
template <int N>
void callStoreIntrinsic(ConversionPatternRewriter &rewriter,
                        arm_sme::StoreTileSliceOp storeTileSliceOp,
                        mlir::vector::SplatOp allActiveMask, Value ptr,
                        Value tileI32, mlir::arith::IndexCastUIOp tileSliceI32) {}

template <>
void callStoreIntrinsic<8>(ConversionPatternRewriter &rewriter,
                           arm_sme::StoreTileSliceOp storeTileSliceOp,
                           mlir::vector::SplatOp allActiveMask, Value ptr,
                           Value tileI32,
                           mlir::arith::IndexCastUIOp tileSliceI32) {
  rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
      storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
}
```

But that wouldn't be an improvement IMHO, so I'd go ahead with what you have here already.

Contributor


You could have a template function with a single switch and then:
callLoadIntrinsic<arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz...>(...)
callLoadIntrinsic<the vertical ones>(...)

That should reduce the code by 2x but maybe I'm missing something :)
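
For reference, a rough sketch of that suggestion (a hypothetical helper; the name and signature are illustrative, not from the patch):

```cpp
// Hypothetical helper: the intrinsic op for each element width is passed as a
// template argument, so the horizontal and vertical lowerings can share one
// switch over the element width.
template <typename Op8, typename Op16, typename Op32, typename Op64,
          typename Op128>
static void createLoadTileSliceIntrinsic(ConversionPatternRewriter &rewriter,
                                         Location loc,
                                         unsigned tileElementWidth,
                                         Value allActiveMask, Value ptr,
                                         Value tileI32, Value tileSliceI32) {
  switch (tileElementWidth) {
  default:
    llvm_unreachable("unexpected element type!");
  case 8:
    rewriter.create<Op8>(loc, allActiveMask, ptr, tileI32, tileSliceI32);
    break;
  case 16:
    rewriter.create<Op16>(loc, allActiveMask, ptr, tileI32, tileSliceI32);
    break;
  case 32:
    rewriter.create<Op32>(loc, allActiveMask, ptr, tileI32, tileSliceI32);
    break;
  case 64:
    rewriter.create<Op64>(loc, allActiveMask, ptr, tileI32, tileSliceI32);
    break;
  case 128:
    rewriter.create<Op128>(loc, allActiveMask, ptr, tileI32, tileSliceI32);
    break;
  }
}

// Usage (sketch): one call with the horizontal intrinsics, one with the
// vertical ones.
//   createLoadTileSliceIntrinsic<arm_sme::aarch64_sme_ld1b_horiz,
//                                arm_sme::aarch64_sme_ld1h_horiz,
//                                arm_sme::aarch64_sme_ld1w_horiz,
//                                arm_sme::aarch64_sme_ld1d_horiz,
//                                arm_sme::aarch64_sme_ld1q_horiz>(
//       rewriter, loc, tileElementWidth, allActiveMask, ptr, tileI32,
//       tileSliceI32);
```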

```cpp
case 128:
  rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
      storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
  break;
```
Contributor


same?

@c-rhodes force-pushed the arm-sme-support-vertical-tile-slice-layout branch from a3322e8 to 802cbec on September 22, 2023
@c-rhodes
Collaborator Author

Thanks for reviewing, all. I've updated this to make the tile slice layout optional, with a default of horizontal, as suggested by @dcaballe on #66760. Optional custom attributes aren't supported as the first argument in the asm string, so I've moved the layout to the end of the asm string for both loads and stores, for consistency. This caused quite a lot of churn unfortunately, so I'd appreciate a second pair of eyes on the changes before I land it.

@c-rhodes
Collaborator Author

Forgot to mention: this also changes the layout format in the IR from <hor> to <horizontal> and from <ver> to <vertical>, as suggested on #66760.
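
Roughly, the op syntax after this update reads as follows (an illustrative sketch; the layout now trails the operands and may be omitted, with exact punctuation per the merged revision):

```mlir
// Layout omitted: defaults to horizontal.
%tile_h = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
// Vertical layout spelled out at the end of the operand list.
%tile_v = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```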

@github-actions

github-actions bot commented Sep 22, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@c-rhodes merged commit 75a71c2 into llvm:main on Sep 25, 2023
c-rhodes added a commit to c-rhodes/llvm-project that referenced this pull request Sep 25, 2023
This patch adds support for lowering vector.transpose to ArmSME. It's implemented by storing the input tile of the transpose to memory and reloading it vertically, building on top of the tile slice layout support.

Transposing via memory is obviously expensive; the current intention is to avoid the transpose if possible, so this is intended as a fallback and to provide base support for Vector ops. If it turns out transposes can't be avoided, this should be replaced with a more optimal implementation, perhaps with tile <-> vector (MOVA) ops.

Depends on llvm#66758.
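
A minimal sketch of that transpose-via-memory approach (hypothetical values; %svl_s is assumed to hold the scalable tile extent, and the trailing-layout syntax from the final revision is used):

```mlir
// Store the tile with horizontal slices, then reload it with vertical
// slices; the round trip through memory performs the transpose.
%buf = memref.alloca(%svl_s, %svl_s) : memref<?x?xf32>
arm_sme.tile_store %tile, %buf[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
%transposed = arm_sme.tile_load %buf[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
```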
c-rhodes added a commit that referenced this pull request Sep 25, 2023