|
| 1 | +//===-- Passes.td - ArmSVE pass definition file ------------*- tablegen -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD |
| 10 | +#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD |
| 11 | + |
| 12 | +include "mlir/Pass/PassBase.td" |
| 13 | + |
| 14 | +def LegalizeVectorStorage |
| 15 | + : Pass<"arm-sve-legalize-vector-storage", "mlir::func::FuncOp"> { |
| 16 | + let summary = "Ensures stores of SVE vector types will be legal"; |
| 17 | + let description = [{ |
| 18 | + This pass ensures that loads, stores, and allocations of SVE vector types |
| 19 | + will be legal in the LLVM backend. It does this at the memref level, so this |
| 20 | + pass must be applied before lowering all the way to LLVM. |
| 21 | + |
| 22 | + This pass currently addresses two issues. |
| 23 | + |
| 24 | + ## Loading and storing predicate types |
| 25 | + |
| 26 | + It is only legal to load/store predicate types equal to (or greater than) a |
| 27 | + full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller |
| 28 | + predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full |
| 29 | + predicate type (referred to as a `svbool`) before and after storing and |
| 30 | + loading respectively. This pass does this by widening allocations and |
| 31 | + inserting conversion intrinsics. Note: Non-powers-of-two masks (e.g. |
| 32 | + `vector<[7]xi1>`), which are not SVE predicates, are ignored. |
| 33 | + |
| 34 | + For example: |
| 35 | + |
| 36 | + ```mlir |
| 37 | + %alloca = memref.alloca() : memref<vector<[4]xi1>> |
| 38 | + %mask = vector.constant_mask [4] : vector<[4]xi1> |
| 39 | + memref.store %mask, %alloca[] : memref<vector<[4]xi1>> |
| 40 | + %reload = memref.load %alloca[] : memref<vector<[4]xi1>> |
| 41 | + ``` |
| 42 | + Becomes: |
| 43 | + ```mlir |
| 44 | + %alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>> |
| 45 | + %mask = vector.constant_mask [4] : vector<[4]xi1> |
| 46 | + %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1> |
| 47 | + memref.store %svbool, %alloca[] : memref<vector<[16]xi1>> |
| 48 | + %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>> |
| 49 | + %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1> |
| 50 | + ``` |
| 51 | + |
| 52 | + ## Relax alignments for SVE vector allocas |
| 53 | + |
| 54 | + The storage for SVE vector types only needs to have an alignment that |
| 55 | + matches the element type (for example 4 byte alignment for `f32`s). However, |
| 56 | + the LLVM backend currently defaults to aligning to `base size` x |
| 57 | + `element size` bytes. For non-legal vector types like `vector<[8]xf32>` this |
| 58 | + results in 8 x 4 = 32-byte alignment, but the backend only supports up to |
| 59 | + 16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller |
| 60 | + alignment prevents this issue. |
| 61 | + }]; |
| 62 | + let constructor = "mlir::arm_sve::createLegalizeVectorStoragePass()"; |
| 63 | + let dependentDialects = ["func::FuncDialect", |
| 64 | + "memref::MemRefDialect", "vector::VectorDialect", |
| 65 | + "arm_sve::ArmSVEDialect"]; |
| 66 | +} |
| 67 | + |
| 68 | +#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD |
0 commit comments