1
+ // ===-- MatmulConfigAnalysis.h - DESC ---------------------------*- C++ -*-===//
2
+ //
3
+ // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ // ===----------------------------------------------------------------------===//
8
+
9
+ #ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
10
+ #define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
11
+
12
+ #include " gc/Dialect/Linalgx/LinalgxOps.h"
13
+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
14
+ #include " mlir/Dialect/Tensor/IR/Tensor.h"
15
+ #include " mlir/Pass/Pass.h"
16
+ #include " mlir/Support/LLVM.h"
17
+ #include " llvm/ADT/DenseMap.h"
18
+ #include < llvm/Support/Debug.h>
19
+ #include < memory>
20
+ #include < numeric>
21
+
22
+ namespace mlir {
23
+ namespace gc {
24
+
25
+ using namespace mlir ;
26
+
27
+ struct SystemDesc {
28
+ // get runtime OMP_NUM_THREADS
29
+ uint32_t getNumThreads () {
30
+ char *numThreads = getenv (" OMP_NUM_THREADS" );
31
+ if (numThreads) {
32
+ return std::stoi (numThreads);
33
+ }
34
+ return 1 ;
35
+ }
36
+ // get cache size by cacheLevel
37
+ size_t getCacheSize (uint8_t cacheLevel) {
38
+ if (cacheLevel == 1 ) {
39
+ char *cacheSize = getenv (" L1_CACHE_SIZE" );
40
+ if (cacheSize) {
41
+ return std::stoi (cacheSize);
42
+ }
43
+ } else if (cacheLevel == 2 ) {
44
+ char *cacheSize = getenv (" L2_CACHE_SIZE" );
45
+ if (cacheSize) {
46
+ return std::stoi (cacheSize);
47
+ }
48
+ } else if (cacheLevel == 3 ) {
49
+ char *cacheSize = getenv (" L3_CACHE_SIZE" );
50
+ if (cacheSize) {
51
+ return std::stoi (cacheSize);
52
+ }
53
+ }
54
+ return 0 ;
55
+ }
56
+
57
+ SmallVector<size_t > getContractionOperationMaxVectorLength () {
58
+ return {512UL , 512UL };
59
+ }
60
+ };
61
+
62
+ struct MatmulConfig {
63
+ uint32_t MBlock, NBlock, KBlock;
64
+ uint32_t MThreads, NThreads, KThreads;
65
+ uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
66
+ friend llvm::raw_ostream &operator <<(llvm::raw_ostream &ss,
67
+ const MatmulConfig &config);
68
+ };
69
+
70
+ enum DimType { Batch, M, N, K };
71
+
72
+ [[maybe_unused]] static SmallVector<unsigned >
73
+ extractDimTypeIdx (ArrayRef<DimType> tyList, DimType ty) {
74
+ SmallVector<unsigned > idxList;
75
+ for (auto [idx, type] : llvm::enumerate (tyList)) {
76
+ if (type == ty) {
77
+ idxList.push_back (idx);
78
+ }
79
+ }
80
+ return idxList;
81
+ }
82
+
83
+ static FailureOr<SmallVector<SmallVector<DimType>>>
84
+ getOprandDimType (linalg::LinalgOp &linalgOp) {
85
+ if (isa<linalg::MatmulOp>(linalgOp)) {
86
+ return SmallVector<SmallVector<DimType>>{
87
+ SmallVector<DimType>{DimType::M, DimType::K},
88
+ SmallVector<DimType>{DimType::K, DimType::N},
89
+ SmallVector<DimType>{DimType::M, DimType::N}};
90
+ } else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
91
+ return SmallVector<SmallVector<DimType>>{
92
+ SmallVector<DimType>{DimType::M, DimType::K},
93
+ SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
94
+ DimType::K},
95
+ SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
96
+ } else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
97
+ return SmallVector<SmallVector<DimType>>{
98
+ SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
99
+ SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
100
+ DimType::K},
101
+ SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
102
+ } else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
103
+ return SmallVector<SmallVector<DimType>>{
104
+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
105
+ SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
106
+ SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
107
+ }
108
+ return failure ();
109
+ }
110
+
111
+ struct MatmulConfigAnalysis {
112
+ public:
113
+ explicit MatmulConfigAnalysis (Operation *root);
114
+ MatmulConfig getConfig () { return config; }
115
+
116
+ private:
117
+ MatmulConfig config;
118
+ };
119
+
120
+ } // namespace gc
121
+ } // namespace mlir
122
+
123
+ #endif
0 commit comments