Skip to content

Commit b190ecb

Browse files
committed
enhance config
1 parent 37c7f67 commit b190ecb

File tree

7 files changed

+570
-129
lines changed

7 files changed

+570
-129
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ endif()
9696

9797
set(GC_LIB_LINKED_LIBS
9898
GCPasses
99+
GCAnalysis
99100
MLIROneDNNGraph
100101
)
101102
add_library(graph_compiler SHARED ${GC_LIB_SOURCES})
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
//===-- MatmulConfigAnalysis.h - DESC ---------------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
10+
#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H
11+
12+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
13+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
14+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Support/LLVM.h"
17+
#include "llvm/ADT/DenseMap.h"
18+
#include <llvm/Support/Debug.h>
19+
#include <memory>
20+
#include <numeric>
21+
22+
namespace mlir {
23+
namespace gc {
24+
25+
using namespace mlir;
26+
27+
struct SystemDesc {
28+
// get runtime OMP_NUM_THREADS
29+
uint32_t getNumThreads() {
30+
char *numThreads = getenv("OMP_NUM_THREADS");
31+
if (numThreads) {
32+
return std::stoi(numThreads);
33+
}
34+
return 1;
35+
}
36+
// get cache size by cacheLevel
37+
size_t getCacheSize(uint8_t cacheLevel) {
38+
if (cacheLevel == 1) {
39+
char *cacheSize = getenv("L1_CACHE_SIZE");
40+
if (cacheSize) {
41+
return std::stoi(cacheSize);
42+
}
43+
} else if (cacheLevel == 2) {
44+
char *cacheSize = getenv("L2_CACHE_SIZE");
45+
if (cacheSize) {
46+
return std::stoi(cacheSize);
47+
}
48+
} else if (cacheLevel == 3) {
49+
char *cacheSize = getenv("L3_CACHE_SIZE");
50+
if (cacheSize) {
51+
return std::stoi(cacheSize);
52+
}
53+
}
54+
return 0;
55+
}
56+
57+
SmallVector<size_t> getContractionOperationMaxVectorLength() {
58+
return {512UL, 512UL};
59+
}
60+
};
61+
62+
struct MatmulConfig {
63+
uint32_t MBlock, NBlock, KBlock;
64+
uint32_t MThreads, NThreads, KThreads;
65+
uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
66+
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
67+
const MatmulConfig &config);
68+
};
69+
70+
enum DimType { Batch, M, N, K };
71+
72+
[[maybe_unused]] static SmallVector<unsigned>
73+
extractDimTypeIdx(ArrayRef<DimType> tyList, DimType ty) {
74+
SmallVector<unsigned> idxList;
75+
for (auto [idx, type] : llvm::enumerate(tyList)) {
76+
if (type == ty) {
77+
idxList.push_back(idx);
78+
}
79+
}
80+
return idxList;
81+
}
82+
83+
static FailureOr<SmallVector<SmallVector<DimType>>>
84+
getOprandDimType(linalg::LinalgOp &linalgOp) {
85+
if (isa<linalg::MatmulOp>(linalgOp)) {
86+
return SmallVector<SmallVector<DimType>>{
87+
SmallVector<DimType>{DimType::M, DimType::K},
88+
SmallVector<DimType>{DimType::K, DimType::N},
89+
SmallVector<DimType>{DimType::M, DimType::N}};
90+
} else if (llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
91+
return SmallVector<SmallVector<DimType>>{
92+
SmallVector<DimType>{DimType::M, DimType::K},
93+
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
94+
DimType::K},
95+
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
96+
} else if (llvm::isa<linalgx::Mm4DVnniOp>(linalgOp)) {
97+
return SmallVector<SmallVector<DimType>>{
98+
SmallVector<DimType>{DimType::M, DimType::K, DimType::M, DimType::K},
99+
SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
100+
DimType::K},
101+
SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
102+
} else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
103+
return SmallVector<SmallVector<DimType>>{
104+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
105+
SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
106+
SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
107+
}
108+
return failure();
109+
}
110+
111+
struct MatmulConfigAnalysis {
112+
public:
113+
explicit MatmulConfigAnalysis(Operation *root);
114+
MatmulConfig getConfig() { return config; }
115+
116+
private:
117+
MatmulConfig config;
118+
};
119+
120+
} // namespace gc
121+
} // namespace mlir
122+
123+
#endif

lib/gc/Analysis/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
2+
MLIRIR
3+
MLIRSupport)
4+
5+
add_mlir_library(GCAnalysis
6+
MatmulConfigAnalysis.cpp
7+
8+
ADDITIONAL_HEADER_DIRS
9+
${PROJECT_SOURCE_DIR}/include
10+
11+
DEPENDS
12+
GraphCompilerPassIncGen
13+
14+
LINK_LIBS PUBLIC
15+
${mlir_dialect_libs}
16+
${MLIR_LINK_COMPONENTS}
17+
)
18+
19+
set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCAnalysis)

0 commit comments

Comments
 (0)