@@ -44,13 +44,15 @@ bool TensorLayout::operator==(const TensorLayout &layout) {
44
44
45
45
llvm::raw_ostream &operator <<(llvm::raw_ostream &ss,
46
46
const OperatorLayout &opLayout) {
47
- for (auto &&[idx, layoutCache] :
48
- llvm::enumerate (opLayout.getSupportedInputLayouts ())) {
49
- ss << " input " << idx << " 's layout: " << layoutCache << " \n " ;
47
+ if (!opLayout.getSupportedInputLayouts ().empty ()) {
48
+ ss << " Input layouts: " ;
49
+ llvm::interleave (opLayout.getSupportedInputLayouts (), ss, " ; " );
50
+ ss << " . " ;
50
51
}
51
- for (auto &&[idx, layoutCache] :
52
- llvm::enumerate (opLayout.getSupportedOutputLayouts ())) {
53
- ss << " output " << idx << " 's layout: " << layoutCache << " \n " ;
52
+ if (!opLayout.getSupportedOutputLayouts ().empty ()) {
53
+ ss << " Output layouts: " ;
54
+ llvm::interleave (opLayout.getSupportedOutputLayouts (), ss, " ; " );
55
+ ss << " . " ;
54
56
}
55
57
return ss;
56
58
}
@@ -217,8 +219,6 @@ static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
217
219
GlobalAnalysis::GlobalAnalysis (Operation *root) {
218
220
root->walk ([&](Operation *op) {
219
221
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
220
- LLVM_DEBUG (llvm::dbgs ()
221
- << " Inferring layout of op: " << op->getName () << " \n " );
222
222
auto curInputs = linalgOp.getDpsInputOperands ();
223
223
auto curResults = linalgOp.getOperation ()->getResults ();
224
224
// ---------------- Get Current Input Layouts -------------------
@@ -277,8 +277,11 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
277
277
rewriter.getIndexAttr (iin)});
278
278
OperatorLayout suggestedLayout ({ALayout, BLayout}, {CLayout});
279
279
layoutCache[linalgOp] = suggestedLayout;
280
+ LLVM_DEBUG (llvm::dbgs () << " Inferred layout of op: " << op->getName ()
281
+ << " is: " << suggestedLayout << " \n " );
280
282
} else if (!mlir::linalg::isaContractionOpInterface (linalgOp) &&
281
- !mlir::linalg::isaConvolutionOpInterface (linalgOp) &&
283
+ !isa<linalg::ConvolutionOpInterface>(
284
+ linalgOp.getOperation ()) &&
282
285
!supportedContractionNamedOpList (linalgOp)) {
283
286
// infer layout for non-contraction/non-convolution linalg named ops
284
287
// and linalg generic ops
@@ -311,6 +314,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
311
314
outputLayouts.push_back (outputLayout);
312
315
OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
313
316
layoutCache[linalgOp] = suggestedLayout;
317
+ LLVM_DEBUG (llvm::dbgs () << " Inferred layout of op: " << op->getName ()
318
+ << " is: " << suggestedLayout << " \n " );
314
319
}
315
320
} else if (auto padOp = dyn_cast<tensor::PadOp>(op)) {
316
321
auto inputOperand = padOp.getSource ();
@@ -325,6 +330,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
325
330
outputLayouts{curInputLayout};
326
331
OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
327
332
layoutCache[padOp] = suggestedLayout;
333
+ LLVM_DEBUG (llvm::dbgs () << " Inferred layout of op: " << op->getName ()
334
+ << " is: " << suggestedLayout << " \n " );
328
335
} else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
329
336
SmallVector<ReassociationIndices> reassocIndices =
330
337
expandShapeOp.getReassociationIndices ();
@@ -343,8 +350,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
343
350
ArrayRef<int64_t > innerDimsPos = curInputLayout.getInnerAxis ();
344
351
ArrayRef<int64_t > outerDimsPerm = curInputLayout.getOuterAxis ();
345
352
SmallVector<int64_t > projectedInnerDimsPos =
346
- projectToInnerMostNonUnitDimsPos (curInputLayout. getInnerAxis () ,
347
- reassocIndices, staticOutputShape);
353
+ projectToInnerMostNonUnitDimsPos (innerDimsPos, reassocIndices ,
354
+ staticOutputShape);
348
355
349
356
if (!isDimsDivisibleByTileSizes (projectedInnerDimsPos, staticOutputShape,
350
357
innerTileSizes)) {
@@ -362,6 +369,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
362
369
outputLayouts{outputLayout};
363
370
OperatorLayout suggestedLayout (inputLayouts, outputLayouts);
364
371
layoutCache[expandShapeOp] = suggestedLayout;
372
+ LLVM_DEBUG (llvm::dbgs () << " Inferred layout of op: " << op->getName ()
373
+ << " is: " << suggestedLayout << " \n " );
365
374
}
366
375
return WalkResult::advance ();
367
376
});
0 commit comments