Skip to content

Commit fe6f573

Browse files
committed
update matmul config analysis
1 parent b21ef81 commit fe6f573

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

include/gc/Analysis/GlobalAnalysis.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,15 @@ class OperatorLayout {
112112
return supportedOutputLayouts[idx];
113113
}
114114

115+
bool isPlain() const {
116+
for (const auto &layout : llvm::concat<const TensorLayout>(
117+
supportedInputLayouts, supportedOutputLayouts)) {
118+
if (!layout.isPlainLayout())
119+
return false;
120+
}
121+
return true;
122+
}
123+
115124
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
116125
const OperatorLayout &opLayout);
117126

lib/gc/Analysis/GlobalAnalysis.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <memory>
99

1010
#include "gc/Analysis/GlobalAnalysis.h"
11+
#include "gc/Analysis/MatmulConfigAnalysis.h"
1112

1213
namespace mlir {
1314
namespace gc {
@@ -176,25 +177,26 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
176177
IRRewriter rewriter(linalgOp);
177178
if (mlir::linalg::isaContractionOpInterface(linalgOp)) {
178179
// query the cost model
179-
// OperatorLayout suggestedLayout = costModel->queryLayout(linalgOp,
180-
// curInputLayouts);
181-
182-
// hardcode one for now
183-
// A side layoutCache, [0, 1, 0, 1]; {32, 32}
180+
MatmulConfig cfg =
181+
MatmulConfigAnalysis(linalgOp.getOperation()).getConfig();
182+
uint32_t iim = cfg.innerMostKBlock, iin = cfg.innerMostNBlock,
183+
iik = cfg.innerMostKBlock;
184+
// hardcode outer axis for now
185+
// A side layoutCache, [0, 1, 0, 1]; {iim, iik}
184186
TensorLayout A_layout(
185187
{0, 1}, {0, 1},
186-
SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
187-
rewriter.getIndexAttr(32)});
188-
// B side layoutCache, [1, 0, 0, 1]; {32, 32}
188+
SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
189+
rewriter.getIndexAttr(iik)});
190+
// B side layoutCache, [1, 0, 0, 1]; {iik, iin}
189191
TensorLayout B_layout(
190192
{1, 0}, {0, 1},
191-
SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
192-
rewriter.getIndexAttr(32)});
193-
// C side layoutCache, [0, 1, 0, 1]; {32, 32}
193+
SmallVector<OpFoldResult>{rewriter.getIndexAttr(iik),
194+
rewriter.getIndexAttr(iin)});
195+
// C side layoutCache, [0, 1, 0, 1]; {iim, iin}
194196
TensorLayout C_layout(
195197
{0, 1}, {0, 1},
196-
SmallVector<OpFoldResult>{rewriter.getIndexAttr(32),
197-
rewriter.getIndexAttr(32)});
198+
SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
199+
rewriter.getIndexAttr(iin)});
198200
OperatorLayout suggestedLayout({A_layout, B_layout}, {C_layout});
199201
layoutCache[linalgOp] = suggestedLayout;
200202
} else {

0 commit comments

Comments
 (0)