Skip to content

Commit 7f6defd

Browse files
committed
Merge remote-tracking branch 'origin/zhicong/deep_tile_matmul' into layout_propagation
2 parents 57e4f22 + 23dfa97 commit 7f6defd

40 files changed

+1210
-153
lines changed

.github/workflows/build-llvm.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ name: LLVM Build
22

33
on:
44
workflow_dispatch:
5+
push:
6+
paths:
7+
- cmake/llvm-version.txt
8+
- .github/workflows/build-llvm.yml
59

610
permissions: read-all
711

812
jobs:
913
build:
1014
name: Build
11-
runs-on: [self-hosted, 0.0.1]
15+
runs-on: [self-hosted]
1216

1317
steps:
1418
- uses: actions/checkout@v4
@@ -27,7 +31,7 @@ jobs:
2731
python3 -m pip install -r mlir/python/requirements.txt
2832
mkdir llvm-install
2933
cmake -G Ninja llvm -B build -DCMAKE_INSTALL_PREFIX=llvm-install -DMLIR_ENABLE_BINDINGS_PYTHON=ON -DPython3_EXECUTABLE=$(which python3) \
30-
-DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=true -DLLVM_ENABLE_PROJECTS="mlir" -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_INSTALL_UTILS=true -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DLLVM_INSTALL_GTEST=ON
34+
-DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=true -DLLVM_ENABLE_PROJECTS="mlir" -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD="SPIRV" -DLLVM_TARGETS_TO_BUILD="X86" -DLLVM_INSTALL_UTILS=true -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DLLVM_INSTALL_GTEST=ON
3135
cmake --build build --target install
3236
cd llvm-install
3337
tar -zcf ../llvm.tgz .

README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,15 @@ cmake --build . --target gc-check
5858
Notes:
5959
* `/PATH/TO/llvm-project/llvm-install` should be the install path of LLVM. If you installed LLVM elsewhere by `-DCMAKE_INSTALL_PREFIX` option when building LLVM, you need to change the path in `-DMLIR_DIR` accordingly.
6060
* The cmake option `-DLLVM_EXTERNAL_LIT` is for the tests of this project. It requires the `lit` tool to be installed in the system. You can install it via `pip install lit`. If you don't need to run the tests of this repo, you can omit this option in the command line.
61-
* If GPU components are on (`-DGC_USE_GPU=ON`), make sure the Level-zero runtime is installed in your system. Either install Level-zero runtime via system package managers (e.g. `apt`), or follow the instructions of [IMEX](https://github.com/intel/mlir-extensions).
61+
62+
More notes if GPU components are on (`-DGC_USE_GPU=ON`):
63+
* make sure the OpenCL runtime is installed in your system. You can either
64+
install using OS-provided package (Ubuntu 22.04)
65+
```sh
66+
sudo apt install -y intel-opencl-icd opencl-c-headers
67+
```
68+
Or, download and install package from: https://github.com/intel/compute-runtime/releases
69+
* the LLVM codebase needs to be patched to support XeGPU lowering (from IMEX). Please follow instructions of [IMEX](https://github.com/intel/mlir-extensions) on patching LLVM.
6270

6371
Graph Compiler supports the following build-time options.
6472

cmake/imex.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ get_property(IMEX_INCLUDES GLOBAL PROPERTY IMEX_INCLUDES)
44
if (NOT DEFINED IMEX_INCLUDES)
55
include(functions)
66
set(IMEX_CHECK_LLVM_VERSION ON)
7-
set(IMEX_ENABLE_L0_RUNTIME 1)
7+
set(IMEX_ENABLE_L0_RUNTIME 0)
88
# TODO: Change to main https://github.com/oneapi-src/oneDNN.git when all the
99
# required functionality is merged.
1010
gc_fetch_content(imex 496b240093b5e132b60c5ee69878300fe69be300 https://github.com/Menooker/mlir-extensions
11-
CMAKE_ARGS "-DMLIR_DIR=${MLIR_DIR};-DIMEX_CHECK_LLVM_VERSION=ON;-DIMEX_ENABLE_L0_RUNTIME=1"
11+
CMAKE_ARGS "-DMLIR_DIR=${MLIR_DIR};-DIMEX_CHECK_LLVM_VERSION=ON;-DIMEX_ENABLE_L0_RUNTIME=0"
1212
)
1313

1414
set(IMEX_INCLUDES

cmake/llvm-version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
37661a17e26d9002ae9ade8c0de3932c22f16360
1+
89946bda5e1c7ceaf6d26634cc8c8c9498d9f7be

include/gc/Analysis/MatmulConfigAnalysis.h

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,63 +22,57 @@ using namespace mlir;
2222
struct SystemDesc {
2323
// get runtime OMP_NUM_THREADS
2424
uint32_t getNumThreads() {
25-
DataLayout layout = DataLayout(module);
26-
MLIRContext *ctx = module->getContext();
2725
std::optional<Attribute> numThreads = layout.getDevicePropertyValue(
2826
Builder(ctx).getStringAttr("CPU" /* device ID*/),
2927
Builder(ctx).getStringAttr("num_threads"));
30-
if (numThreads && dyn_cast<IntegerAttr>(*numThreads)) {
31-
return cast<IntegerAttr>(*numThreads).getInt();
28+
if (numThreads && isa<IntegerAttr>(*numThreads)) {
29+
return dyn_cast<IntegerAttr>(*numThreads).getInt();
3230
}
3331
return 1;
3432
}
3533
// get cache size by cacheLevel
3634
size_t getCacheSize(uint8_t cacheLevel) {
37-
DataLayout layout = DataLayout(module);
38-
MLIRContext *ctx = module->getContext();
39-
4035
if (cacheLevel == 1) {
4136
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
4237
Builder(ctx).getStringAttr("CPU" /* device ID*/),
4338
Builder(ctx).getStringAttr("L1_cache_size_in_bytes"));
44-
if (cacheSize && dyn_cast<IntegerAttr>(*cacheSize)) {
45-
return cast<IntegerAttr>(*cacheSize).getInt();
39+
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
40+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
4641
}
4742
} else if (cacheLevel == 2) {
4843
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
4944
Builder(ctx).getStringAttr("CPU" /* device ID*/),
5045
Builder(ctx).getStringAttr("L2_cache_size_in_bytes"));
51-
if (cacheSize && dyn_cast<IntegerAttr>(*cacheSize)) {
52-
return cast<IntegerAttr>(*cacheSize).getInt();
46+
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
47+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
5348
}
5449
} else if (cacheLevel == 3) {
5550
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
5651
Builder(ctx).getStringAttr("CPU" /* device ID*/),
5752
Builder(ctx).getStringAttr("L3_cache_size_in_bytes"));
58-
if (cacheSize && dyn_cast<IntegerAttr>(*cacheSize)) {
59-
return cast<IntegerAttr>(*cacheSize).getInt();
53+
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
54+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
6055
}
6156
}
6257
return 0;
6358
}
6459

6560
// get the maximum vector length in bits
6661
size_t getMaxVectorLength() {
67-
DataLayout layout = DataLayout(module);
68-
MLIRContext *ctx = module->getContext();
6962
std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue(
7063
Builder(ctx).getStringAttr("CPU" /* device ID*/),
7164
Builder(ctx).getStringAttr("max_vector_width"));
72-
if (maxVectorLength && dyn_cast<IntegerAttr>(*maxVectorLength)) {
73-
return cast<IntegerAttr>(*maxVectorLength).getInt();
65+
if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) {
66+
return dyn_cast<IntegerAttr>(*maxVectorLength).getInt();
7467
}
7568
return 512;
7669
}
7770

78-
SystemDesc(ModuleOp m) : module(m) {}
71+
SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {}
7972

8073
private:
81-
ModuleOp module;
74+
DataLayout layout;
75+
MLIRContext *ctx;
8276
};
8377

8478
// The configuration for matmul tiling

include/gc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
add_subdirectory(Dialect)
2-
add_subdirectory(Transforms)
2+
add_subdirectory(Transforms)

include/gc/Dialect/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
add_subdirectory(CPURuntime)
22
add_subdirectory(OneDNNGraph)
33
add_subdirectory(Microkernel)
4-
add_subdirectory(Linalgx)
4+
add_subdirectory(Linalgx)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
set(LLVM_TARGET_DEFINITIONS MicrokernelEnum.td)
2+
mlir_tablegen(MicrokernelEnum.h.inc -gen-enum-decls)
3+
mlir_tablegen(MicrokernelEnum.cpp.inc -gen-enum-defs)
4+
add_public_tablegen_target(MLIRMicrokernelAttrDefIncGen)
5+
16
add_mlir_dialect(MicrokernelOps microkernel)
27
add_mlir_doc(MicrokernelOps MicrokernelOps gc/Dialect/Microkernel/ -gen-op-doc)
38
add_mlir_doc(MicrokernelDialect MicrokernelDialect gc/Dialect/Microkernel/ -gen-dialect-doc)

include/gc/Dialect/Microkernel/MicrokernelDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define GC_DIALECTS_MICROKERNELDIALECT_H
1111

1212
#include "mlir/IR/Dialect.h"
13+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1314

1415
#include "gc/Dialect/Microkernel/MicrokernelOpsDialect.h.inc"
1516

include/gc/Dialect/Microkernel/MicrokernelDialect.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@ include "mlir/IR/OpBase.td"
1515
// Microkernel dialect definition.
1616
//===----------------------------------------------------------------------===//
1717

18-
def MicrokernelDialect : Dialect {
18+
def Microkernel_Dialect : Dialect {
1919
let name = "microkernel";
2020
let summary = "A dialect for microkernel abstraction.";
2121
let description = [{
22-
The dialect wraps the BRGEMM API to set up the HW context etc.
22+
This dialect contains wrappers for microkernel primitives like BRGEMM.
2323
}];
2424
let cppNamespace = "::mlir::microkernel";
25-
26-
let useDefaultTypePrinterParser = 1;
2725
}
2826

27+
//===----------------------------------------------------------------------===//
28+
// Base microkernel operation definition.
29+
//===----------------------------------------------------------------------===//
30+
31+
class Microkernel_Op<string mnemonic, list<Trait> traits = []> :
32+
Op<Microkernel_Dialect, mnemonic, traits>;
33+
2934
#endif // MICROKERNEL_DIALECT
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===- MicrokernelEnum.h - microkernel dialect enums ------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_DIALECTS_MICROKERNELENUM_H
10+
#define GC_DIALECTS_MICROKERNELENUM_H
11+
12+
#include "mlir/IR/Attributes.h"
13+
#include "mlir/IR/DialectImplementation.h"
14+
15+
#define GET_ATTRDEF_CLASSES
16+
#include "gc/Dialect/Microkernel/MicrokernelEnum.h.inc"
17+
18+
#endif // GC_DIALECTS_MICROKERNELENUM_H
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- MicrokernelEnum.td - microkernel dialect enum -------*- tablegen -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MICROKERNEL_ENUM
10+
#define MICROKERNEL_ENUM
11+
12+
include "mlir/IR/EnumAttr.td"
13+
include "gc/Dialect/Microkernel/MicrokernelDialect.td"
14+
15+
def Microkernel_BrgemmFlags : I64EnumAttr<
16+
"BrgemmFlags", "Flags for indicating optional behaviours of Brgemm",
17+
[
18+
I64EnumAttrCase<"NONE", 0, "none">,
19+
I64EnumAttrCase<"BETA_0", 1, "beta_0">,
20+
I64EnumAttrCase<"STRIDE", 2, "stride">,
21+
I64EnumAttrCase<"LIST", 4, "list">
22+
]> {
23+
let cppNamespace = "::mlir::microkernel";
24+
}
25+
26+
#endif // MICROKERNEL_ENUM

include/gc/Dialect/Microkernel/MicrokernelOps.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@
99
#ifndef GC_DIALECTS_MICROKERNELOPS_H
1010
#define GC_DIALECTS_MICROKERNELOPS_H
1111

12+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
13+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14+
#include "mlir/Dialect/SCF/IR/SCF.h"
15+
#include "mlir/IR/BuiltinTypes.h"
16+
#include "mlir/IR/Dialect.h"
1217
#include "mlir/IR/OpDefinition.h"
18+
#include "mlir/Interfaces/SideEffectInterfaces.h"
19+
20+
#include "gc/Dialect/Microkernel/MicrokernelDialect.h"
21+
#include "gc/Dialect/Microkernel/MicrokernelEnum.h"
1322

1423
#define GET_OP_CLASSES
1524
#include "gc/Dialect/Microkernel/MicrokernelOps.h.inc"

include/gc/Dialect/Microkernel/MicrokernelOps.td

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,107 @@
1010
#define MICROKERNEL_OPS
1111

1212
include "MicrokernelDialect.td"
13+
include "gc/Dialect/Microkernel/MicrokernelEnum.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
1315

14-
#endif // MICROKERNEL_OPS
16+
class StaticMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
17+
Type<And<[MemRefOf<allowedTypes>.predicate,
18+
HasAnyRankOfPred<ranks>, HasStaticShapePred]>,
19+
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " static " #
20+
MemRefOf<allowedTypes>.summary, "::mlir::MemRefType">;
21+
22+
def Microkernel_BrgemmDispatchOp : Microkernel_Op<"brgemm.dispatch", [Pure]> {
23+
let summary = "JIT the brgemm microkernel given the parameters";
24+
let description = [{
25+
The operation has the following arguments: 1) m, n, k, lda, ldb, ldc, stride_a and stride_b.
26+
Inputs is a dense attribute of I64 elements. 2) flags carry information on
27+
the different flags that can be used for brgemm like whether beta == 0 or strided batch. For
28+
more details, see: `Microkernel_BrgemmFlags`. 3) data_types of operand A & B.
29+
Output is the id of the JITed kernel.
30+
}];
31+
32+
let arguments = (ins
33+
ConfinedAttr<DenseI64ArrayAttr,
34+
[DenseArrayNonNegative<DenseI64ArrayAttr>]>:$inputs,
35+
TypedArrayAttrBase<Microkernel_BrgemmFlags, "brgemm flags">:$flags,
36+
TypedArrayAttrBase<TypeAttr, "brgemm dtypes">:$data_type);
37+
38+
let results = (outs I64:$results);
39+
let hasCustomAssemblyFormat = 1;
40+
let hasVerifier = 1;
41+
}
42+
43+
def Microkernel_BrgemmPrologueOp : Microkernel_Op<"brgemm.prologue"> {
44+
let summary = "Prologue before executing the JITed brgemm "
45+
"microkernel, and the context is considered core-level";
46+
let description = [{
47+
The operation has the following arguments: Input is the id of the JITed kernel.
48+
There is no output.
49+
}];
50+
51+
let arguments = (ins I64:$inputs);
52+
53+
let assemblyFormat = [{
54+
`(` $inputs `)`
55+
attr-dict `:` functional-type($inputs, results)
56+
}];
57+
}
58+
59+
def Microkernel_BrgemmEpilogueOp : Microkernel_Op<"brgemm.epilogue"> {
60+
let summary = "Epilogue after executing the JITed brgemm microkernel";
61+
let description = [{
62+
The operation has the following arguments: Input is the id of JITed kernel.
63+
There is no output.
64+
}];
65+
66+
let arguments = (ins I64:$inputs);
67+
68+
let assemblyFormat = [{
69+
`(` $inputs `)`
70+
attr-dict `:` functional-type($inputs, results)
71+
}];
72+
}
73+
74+
/* A generic input type of Microkernel_BrgemmOp, allowing for `BrgemmMemRef` and I64.
75+
* The `BrgemmMemRef` should be a static MemRef, and for each operand its shape should be:
76+
* Operand A: StaticMemRefRankOf<[F32, BF16, SI8, UI8], [3]>;
77+
* Operand B (non-VNNI): StaticMemRefRankOf<[F32], [3]>;
78+
* Operand B (VNNI): StaticMemRefRankOf<[BF16, SI8, UI8], [4]>;
79+
* Operand C: StaticMemRefRankOf<[F32, SI32], [2]>;
80+
*/
81+
def BrgemmMemRefOrI64 : AnyTypeOf<[StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, I64]>;
82+
83+
def Microkernel_BrgemmOp : Microkernel_Op<"brgemm"> {
84+
let summary = "execute the JITed brgemm kernel.";
85+
let description = [{
86+
The operation has the following arguments:
87+
1) For stride mode, id of JITed kernel, MemRef of operand A/B/C, and the batch size;
88+
2) For addr mode, plus the length of addr list at the end.
89+
There is no output.
90+
}];
91+
92+
let arguments = (ins Variadic<BrgemmMemRefOrI64>:$inputs);
93+
94+
let assemblyFormat = [{
95+
`(` $inputs `)`
96+
attr-dict `:` functional-type($inputs, results)
97+
}];
98+
99+
let extraClassDeclaration = [{
100+
Value getDispatch() { return getInputs()[0]; }
101+
102+
Value getOperandA() { return getInputs()[1]; }
103+
104+
Value getOperandB() { return getInputs()[2]; }
105+
106+
Value getOutput() { return getInputs()[3]; }
107+
108+
Value getBatch() { return getInputs()[4]; }
109+
110+
Value getAddrLen() { return getInputs()[5]; }
111+
}];
112+
113+
let hasVerifier = 1;
114+
}
115+
116+
#endif // MICROKERNEL_OPS

include/gc/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
2121
];
2222
}
2323

24-
24+
#ifdef GC_USE_GPU
2525
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
2626
let summary = "Convert linalg dialect to XeGPU dialect.";
2727
let description = [{
@@ -46,6 +46,7 @@ def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
4646
"DPAS register block sizes MxNxK">,
4747
];
4848
}
49+
#endif
4950

5051
def DeepTileContractionNamedOp
5152
: Pass<"deep-tile-contraction-named-op", "func::FuncOp"> {
@@ -73,7 +74,6 @@ def MergeNestedForall : Pass<"merge-nested-forall"> {
7374
let dependentDialects = ["scf::SCFDialect"];
7475
}
7576

76-
7777
def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> {
7878
let summary = "Insert and propagate tensor.pack to pack the computation of linalg named ops and tensor ops.";
7979
let description = [{

0 commit comments

Comments
 (0)