Skip to content

Commit a9c4ff3

Browse files
committed
init layout propagation
1 parent 0225874 commit a9c4ff3

File tree

8 files changed

+468
-1
lines changed

8 files changed

+468
-1
lines changed

include/gc/Analysis/GlobalAnalysis.h

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*******************************************************************************
2+
* Copyright 2024 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
#ifndef MLIR_ANALYSIS_GLOBALANALYSIS_H
18+
#define MLIR_ANALYSIS_GLOBALANALYSIS_H
19+
20+
#include <iostream>
21+
#include <memory>
22+
#include <numeric>
23+
24+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
25+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
26+
#include "mlir/Support/LLVM.h"
27+
#include "llvm/ADT/DenseMap.h"
28+
29+
namespace mlir {
30+
namespace gc {
31+
32+
using namespace mlir;
33+
34+
class TensorLayout {
35+
public:
36+
TensorLayout(ArrayRef<int64_t> outerAxis, ArrayRef<int64_t> innerAxis,
37+
ArrayRef<int64_t> tileSizes) {
38+
assert(innerAxis.size() == tileSizes.size());
39+
for (auto oa : outerAxis) {
40+
OuterAxis.push_back(oa);
41+
}
42+
for (auto ia : innerAxis) {
43+
InnerAxis.push_back(ia);
44+
}
45+
for (auto ts : tileSizes) {
46+
TileSizes.push_back(ts);
47+
}
48+
}
49+
50+
bool isPlainLayout() const {
51+
for (int64_t i = 0; i < static_cast<int64_t>(OuterAxis.size()); ++i) {
52+
if (i != OuterAxis[i])
53+
return false;
54+
}
55+
return TileSizes.empty() && InnerAxis.empty();
56+
}
57+
58+
static TensorLayout createPlainLayout(int64_t rank) {
59+
SmallVector<int64_t> outerAxis(rank, 0);
60+
std::iota(outerAxis.begin(), outerAxis.end(), 0);
61+
return TensorLayout(outerAxis, SmallVector<int64_t>{},
62+
SmallVector<int64_t>{});
63+
}
64+
65+
size_t getTensorRank() const { return OuterAxis.size(); }
66+
67+
SmallVector<int64_t> getOuterAxis() const { return OuterAxis; }
68+
69+
SmallVector<int64_t> getInnerAxis() const { return InnerAxis; }
70+
71+
SmallVector<int64_t> getTileSizes() const { return TileSizes; }
72+
73+
friend std::ostream &operator<<(std::ostream &ss, const TensorLayout &layout);
74+
75+
bool operator==(const TensorLayout &layout);
76+
77+
private:
78+
SmallVector<int64_t> OuterAxis;
79+
SmallVector<int64_t> InnerAxis;
80+
SmallVector<int64_t> TileSizes;
81+
};
82+
83+
class OperatorLayout {
84+
public:
85+
OperatorLayout() {}
86+
87+
OperatorLayout(SmallVector<TensorLayout> inputLayouts,
88+
SmallVector<TensorLayout> outputLayouts) {
89+
supportedInputLayouts = inputLayouts;
90+
supportedOutputLayouts = outputLayouts;
91+
}
92+
93+
SmallVector<TensorLayout> getSupportedInputLayouts() const {
94+
return supportedInputLayouts;
95+
}
96+
97+
SmallVector<TensorLayout> getSupportedOutputLayouts() const {
98+
return supportedOutputLayouts;
99+
}
100+
101+
TensorLayout getOutputLayout(int64_t idx) const {
102+
assert(idx < static_cast<int64_t>(supportedOutputLayouts.size()));
103+
return supportedOutputLayouts[idx];
104+
}
105+
106+
friend std::ostream &operator<<(std::ostream &ss,
107+
const OperatorLayout &opLayout);
108+
109+
private:
110+
SmallVector<TensorLayout> supportedInputLayouts;
111+
SmallVector<TensorLayout> supportedOutputLayouts;
112+
};
113+
114+
class GlobalAnalysis {
115+
public:
116+
explicit GlobalAnalysis(Operation *root);
117+
118+
FailureOr<OperatorLayout> getOpLayout(Operation *op) {
119+
if (layout.find(op) != layout.end())
120+
return layout[op];
121+
else
122+
return op->emitError("Current op does not have layout information.");
123+
}
124+
125+
private:
126+
DenseMap<Operation *, OperatorLayout> layout;
127+
};
128+
129+
} // namespace gc
130+
} // namespace mlir
131+
132+
#endif

include/gc/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,12 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
3131
];
3232
}
3333

34+
def PropagateLayout : Pass<"propagate-layout"> {
35+
let summary = "Insert and propagte tensor.pack to pack the computation of general linalg named ops and tensor ops.";
36+
let description = [{
37+
Insert and propagte tensor.pack
38+
}];
39+
let dependentDialects = ["mlir::tensor::TensorDialect", "mlir::linalg::LinalgDialect"];
40+
}
41+
3442
#endif // GC_DIALECT_GC_PASSES

lib/gc/Analysis/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
2+
MLIRIR
3+
MLIRSupport)
4+
5+
add_mlir_library(GCAnalysis
6+
GlobalAnalysis.cpp
7+
8+
ADDITIONAL_HEADER_DIRS
9+
${PROJECT_SOURCE_DIR}/include
10+
11+
DEPENDS
12+
GraphCompilerPassIncGen
13+
14+
LINK_LIBS PUBLIC
15+
${mlir_dialect_libs}
16+
${MLIR_LINK_COMPONENTS}
17+
)
18+
19+
set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCAnalysis)

0 commit comments

Comments
 (0)