|
8 | 8 | #include <memory>
|
9 | 9 |
|
10 | 10 | #include "gc/Analysis/GlobalAnalysis.h"
|
| 11 | +#include "gc/Analysis/MatmulConfigAnalysis.h" |
11 | 12 |
|
12 | 13 | namespace mlir {
|
13 | 14 | namespace gc {
|
@@ -176,25 +177,26 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
|
176 | 177 | IRRewriter rewriter(linalgOp);
|
177 | 178 | if (mlir::linalg::isaContractionOpInterface(linalgOp)) {
|
178 | 179 | // query the cost model
|
179 |
| - // OperatorLayout suggestedLayout = costModel->queryLayout(linalgOp, |
180 |
| - // curInputLayouts); |
181 |
| - |
182 |
| - // hardcode one for now |
183 |
| - // A side layoutCache, [0, 1, 0, 1]; {32, 32} |
| 180 | + MatmulConfig cfg = |
| 181 | + MatmulConfigAnalysis(linalgOp.getOperation()).getConfig(); |
| 182 | + uint32_t iim = cfg.innerMostKBlock, iin = cfg.innerMostNBlock, |
| 183 | + iik = cfg.innerMostKBlock; |
| 184 | + // hardcode outer axis for now |
| 185 | + // A side layoutCache, [0, 1, 0, 1]; {iim, iik} |
184 | 186 | TensorLayout A_layout(
|
185 | 187 | {0, 1}, {0, 1},
|
186 |
| - SmallVector<OpFoldResult>{rewriter.getIndexAttr(32), |
187 |
| - rewriter.getIndexAttr(32)}); |
188 |
| - // B side layoutCache, [1, 0, 0, 1]; {32, 32} |
| 188 | + SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim), |
| 189 | + rewriter.getIndexAttr(iik)}); |
| 190 | + // B side layoutCache, [1, 0, 0, 1]; {iik, iin} |
189 | 191 | TensorLayout B_layout(
|
190 | 192 | {1, 0}, {0, 1},
|
191 |
| - SmallVector<OpFoldResult>{rewriter.getIndexAttr(32), |
192 |
| - rewriter.getIndexAttr(32)}); |
193 |
| - // C side layoutCache, [0, 1, 0, 1]; {32, 32} |
| 193 | + SmallVector<OpFoldResult>{rewriter.getIndexAttr(iik), |
| 194 | + rewriter.getIndexAttr(iin)}); |
| 195 | + // C side layoutCache, [0, 1, 0, 1]; {iim, iin} |
194 | 196 | TensorLayout C_layout(
|
195 | 197 | {0, 1}, {0, 1},
|
196 |
| - SmallVector<OpFoldResult>{rewriter.getIndexAttr(32), |
197 |
| - rewriter.getIndexAttr(32)}); |
| 198 | + SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim), |
| 199 | + rewriter.getIndexAttr(iin)}); |
198 | 200 | OperatorLayout suggestedLayout({A_layout, B_layout}, {C_layout});
|
199 | 201 | layoutCache[linalgOp] = suggestedLayout;
|
200 | 202 | } else {
|
|
0 commit comments