Skip to content

Commit 0b9355b

Browse files
committed
simplified to support single nested scf.for
1 parent b401193 commit 0b9355b

File tree

2 files changed

+43
-199
lines changed

2 files changed

+43
-199
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 11 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,38 +1464,12 @@ checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
14641464
/// failure otherwise.
14651465
static FailureOr<OpOperand *> getConsumerFromUses(Value val,
14661466
Block *containingOpBlock) {
1467-
// Step 1. Check that the value has exactly one use excluding `insertSliceOp`
1468-
// or `ParallelInsertSliceOp`.
1469-
OpOperand *operand = nullptr;
1470-
for (auto &use : val.getUses()) {
1471-
Operation *user = use.getOwner();
1472-
if (isa<tensor::ParallelInsertSliceOp>(user))
1473-
continue;
1474-
if (isa<tensor::InsertSliceOp>(user)) {
1475-
// The only expected use is a dummy extractSliceOp without any uses.
1476-
// For more details, please refer to:
1477-
// https://github.com/llvm/llvm-project/pull/88712#discussion_r1609384470
1478-
if (user->hasOneUse()) {
1479-
if (auto extractOp =
1480-
dyn_cast<tensor::ExtractSliceOp>(*user->getUsers().begin());
1481-
extractOp && extractOp->use_empty()) {
1482-
// Erase dummy extractSliceOp.
1483-
extractOp.erase();
1484-
// DO NOT erase `user` inside iteration of `getUses`.
1485-
user->moveBefore(&containingOpBlock->getOperations().back());
1486-
continue;
1487-
}
1488-
}
1489-
// Otherwise return.
1490-
return failure();
1491-
}
1492-
// Only one valid use expected
1493-
if (operand)
1494-
return failure();
1495-
operand = &use;
1496-
}
1467+
// Step 1. Check that the value has exactly one use.
1468+
if (!llvm::hasSingleElement(val.getUses()))
1469+
return failure();
14971470
// Step 2. Get uses.
1498-
Operation *consumerOp = operand->getOwner();
1471+
OpOperand &operand = (*val.getUses().begin());
1472+
Operation *consumerOp = operand.getOwner();
14991473
// TODO: We have to init result of consumer before scf.for, use
15001474
// DestinationStyleOpInterface to get result shape from init for now.
15011475
// Add support for other op such as op has InferTypeOpInterface.
@@ -1504,7 +1478,7 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
15041478
return failure();
15051479
if (containingOpBlock != consumerOp->getBlock())
15061480
return failure();
1507-
return operand;
1481+
return &operand;
15081482
}
15091483

15101484
/// Recursively find the outer nest loops of given loop(included) while the
@@ -1693,9 +1667,9 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
16931667

16941668
/// Implementation of fusing consumer of a single slice by computing the
16951669
/// slice of the consumer in-place for scf loop.
1696-
static FailureOr<scf::SCFFuseConsumerOfSliceResult>
1697-
tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
1698-
Operation *candidateSliceOp) {
1670+
FailureOr<scf::SCFFuseConsumerOfSliceResult>
1671+
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
1672+
Operation *candidateSliceOp) {
16991673
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
17001674
candidateSliceOp))
17011675
return failure();
@@ -1742,7 +1716,7 @@ tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
17421716
// 2. inner-most `scf.for` inside nested `scf.loop` structure, where the
17431717
// top-level loop is the outer-most one of these nested loops.
17441718
Operation *oldTopLevelLoop = oldLoopOp;
1745-
SmallVector<LoopLikeOpInterface> oldNestedForOps, newNestedForOps;
1719+
SmallVector<LoopLikeOpInterface> oldNestedForOps;
17461720
if (isInsertSliceOp) {
17471721
oldNestedForOps =
17481722
getOuterNestLoopsWhile(cast<LoopLikeOpInterface>(oldTopLevelLoop),
@@ -1781,6 +1755,7 @@ tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
17811755
// 3.a Create new outer scf loops with new Inits only if nested `scf.for`
17821756
// case was found.
17831757
bool isNestedForOps = isInsertSliceOp && oldNestedForOps.size() > 1;
1758+
SmallVector<LoopLikeOpInterface> newNestedForOps;
17841759
if (isNestedForOps) {
17851760
for (auto &&[index, loopOp] :
17861761
llvm::enumerate(MutableArrayRef(oldNestedForOps).drop_back())) {
@@ -1979,110 +1954,6 @@ tileAndFuseConsumerOfSliceImpl(RewriterBase &rewriter,
19791954
tileAndFuseResult->tiledOps};
19801955
}
19811956

1982-
/// Get the real consumers from candidate InsertSliceOp. E.g
1983-
///
1984-
/// ```
1985-
/// %1 = scf.for
1986-
/// %2 = scf.for
1987-
/// %3 = scf.for
1988-
/// ...
1989-
/// %4 = insert
1990-
/// yield %4
1991-
/// %5 = insert %3
1992-
/// yield %5
1993-
/// yield %2
1994-
/// %6 = consumerOp ins(%1)
1995-
/// ```
1996-
///
1997-
/// @param candidateSliceOp: %4 = insert
1998-
/// @param forwardSlice: in-out parameter populated by forward insertSliceOps
1999-
/// @return OpOperand consumers: %6 = consumerOp ins(%1)
2000-
static FailureOr<SmallVector<OpOperand *>> getRealConsumersFromInsertSliceOp(
2001-
Operation *candidateSliceOp,
2002-
SmallVector<OffsetSizeAndStrideOpInterface> &forwardSlice,
2003-
unsigned curDepth = 0, unsigned maxDepth = 5) {
2004-
if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
2005-
candidateSliceOp))
2006-
return failure();
2007-
// Limit recursion depth to avoid stack overflow
2008-
if (curDepth > maxDepth)
2009-
return failure();
2010-
2011-
forwardSlice.push_back(
2012-
cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp));
2013-
Value resultOfLoop;
2014-
if (auto sliceOp =
2015-
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
2016-
Value destValue = sliceOp.getDest();
2017-
auto iterArg = cast<BlockArgument>(destValue);
2018-
auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
2019-
if (!forallOp)
2020-
return failure();
2021-
resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
2022-
} else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(candidateSliceOp)) {
2023-
Value resultValue = sliceOp.getResult();
2024-
for (auto &useOperand : resultValue.getUses()) {
2025-
if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
2026-
if (llvm::detail::isPresent(resultOfLoop))
2027-
return failure();
2028-
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
2029-
if (!forOp)
2030-
return failure();
2031-
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
2032-
}
2033-
}
2034-
}
2035-
2036-
if (!llvm::detail::isPresent(resultOfLoop))
2037-
return failure();
2038-
2039-
bool traverseUpperLoop;
2040-
do {
2041-
traverseUpperLoop = false;
2042-
for (OpOperand &useOperand : resultOfLoop.getUses()) {
2043-
if (auto sliceOp =
2044-
dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
2045-
return getRealConsumersFromInsertSliceOp(sliceOp, forwardSlice,
2046-
curDepth + 1);
2047-
}
2048-
if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
2049-
// Walk through outer loop.
2050-
auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
2051-
if (!forOp)
2052-
return failure();
2053-
resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
2054-
traverseUpperLoop = true;
2055-
break;
2056-
}
2057-
}
2058-
} while (traverseUpperLoop);
2059-
// Return all operands using result of top level loop.
2060-
return llvm::map_to_vector(resultOfLoop.getUses(),
2061-
[](OpOperand &u) -> OpOperand * { return &u; });
2062-
}
2063-
2064-
/// Fusing real consumer of a single slice even within complex nested loops via
2065-
/// multiple application of `tileAndFuseConsumerOfSliceImpl`.
2066-
FailureOr<scf::SCFFuseConsumerOfSliceResult>
2067-
mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
2068-
Operation *candidateSliceOp) {
2069-
SmallVector<OffsetSizeAndStrideOpInterface> forwardSlice;
2070-
if (failed(getRealConsumersFromInsertSliceOp(candidateSliceOp, forwardSlice)))
2071-
return failure();
2072-
2073-
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResult;
2074-
// Reverse forward slice from outer to inner.
2075-
std::reverse(forwardSlice.begin(), forwardSlice.end());
2076-
// Multiple application of `tileAndFuseConsumerOfSliceImpl`.
2077-
for (auto &sliceOp : forwardSlice) {
2078-
fuseConsumerResult = tileAndFuseConsumerOfSliceImpl(rewriter, sliceOp);
2079-
if (failed(fuseConsumerResult))
2080-
return rewriter.notifyMatchFailure(sliceOp,
2081-
"could not fuse consumer of sliceOp");
2082-
}
2083-
return fuseConsumerResult;
2084-
}
2085-
20861957
//===----------------------------------------------------------------------===//
20871958
// lowerToLoopsUsingSCFForOp implementation.
20881959
//===----------------------------------------------------------------------===//

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 32 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -374,38 +374,27 @@ module attributes {transform.with_named_sequence} {
374374

375375
// -----
376376

377-
#map = affine_map<(d0) -> (d0 * 128)>
378377
module {
379-
func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
378+
func.func @fuse_add_consumer_into_nested_scf_for(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
380379
%c0 = arith.constant 0 : index
381380
%c64 = arith.constant 64 : index
382-
%c128 = arith.constant 128 : index
381+
%c256 = arith.constant 256 : index
383382
%cst = arith.constant 0.000000e+00 : f32
384383
%dest0 = tensor.empty() : tensor<256x256xf32>
385384
%dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
386-
%1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
387-
%iv0 = affine.apply #map(%arg3)
388-
%iv1 = affine.apply #map(%arg4)
389-
%extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
390-
%extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
391-
%extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
392-
%2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
393-
%3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
394-
%extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
395-
%extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
396-
%extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
397-
%4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
398-
%insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
399-
scf.yield %insert_slice : tensor<128x128xf32>
400-
}
401-
scf.yield %3 : tensor<128x128xf32>
402-
}
403-
scf.forall.in_parallel {
404-
tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
385+
%1 = scf.for %arg3 = %c0 to %c256 step %c64 iter_args(%arg4 = %dest1) -> (tensor<256x256xf32>) {
386+
%2 = scf.for %arg5 = %c0 to %c256 step %c64 iter_args(%arg6 = %arg4) -> (tensor<256x256xf32>) {
387+
%extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
388+
%extracted_slice_2 = tensor.extract_slice %arg0[%arg3, 0] [64, 512] [1, 1] : tensor<256x512xf32> to tensor<64x512xf32>
389+
%extracted_slice_3 = tensor.extract_slice %arg1[0, %arg5] [512, 64] [1, 1] : tensor<512x256xf32> to tensor<512x64xf32>
390+
%3 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_1 : tensor<64x64xf32>) -> tensor<64x64xf32>
391+
%insert_slice = tensor.insert_slice %3 into %arg6[%arg3, %arg5] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<256x256xf32>
392+
scf.yield %insert_slice : tensor<256x256xf32>
405393
}
394+
scf.yield %2 : tensor<256x256xf32>
406395
}
407-
%5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
408-
return %5 : tensor<256x256xf32>
396+
%4 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
397+
return %4 : tensor<256x256xf32>
409398
}
410399
}
411400

@@ -418,49 +407,33 @@ module attributes {transform.with_named_sequence} {
418407
transform.yield
419408
}
420409
}
421-
// CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
422-
// CHECK: func.func @fuse_tilable_consumer_nested_scf_loop(
410+
// CHECK: func.func @fuse_add_consumer_into_nested_scf_for(
423411
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
424412
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
425413
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
426414
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
427415
// CHECK: %[[dest1:.*]] = linalg.fill
428416
// CHECK-SAME: outs(%[[dest0]] :
429-
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
430-
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG0:.*]] = %[[dest1]], %[[SECOND_OUT_ARG0:.*]] = %[[dest0]])
417+
// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV1:.*]] = %[[C0]]
418+
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[dest1]], %[[SECOND_OUT_ARG1:.*]] = %[[dest0]])
431419
// CHECK-SAME: {
432-
// CHECK: %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
433-
// CHECK: %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
434-
// CHECK: %[[MAT_OUT_SLICE0:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
435-
// CHECK: %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
436-
// CHECK: %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
437-
// CHECK: %[[ADD_OPERAND2_SLICE0:.*]] = tensor.extract_slice %[[ARG2]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
438-
// CHECK: %[[ADD_OUT_SLICE0:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
439-
// CHECK: %[[LOOP_RESULT1:.*]]:2 = scf.for %[[IV3:.*]] = %[[C0]]
440-
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG1:.*]] = %[[MAT_OUT_SLICE0]], %[[SECOND_OUT_ARG1:.*]] = %[[ADD_OUT_SLICE0]])
441-
// CHECK-SAME: {
442-
// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV4:.*]] = %[[C0]]
443-
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
444-
// CHECK-SAME: {
445-
// CHECK: %[[MAT_OUT_SLICE1:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
446-
// CHECK: %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
447-
// CHECK: %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
420+
// CHECK: %[[LOOP_RESULT2:.*]]:2 = scf.for %[[IV2:.*]] = %[[C0]]
421+
// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG2:.*]] = %[[FIRST_OUT_ARG1]], %[[SECOND_OUT_ARG2:.*]] = %[[SECOND_OUT_ARG1]])
422+
// CHECK-SAME: {
423+
// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
424+
// CHECK: %[[INPUT_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 512] [1, 1]
425+
// CHECK: %[[WEIGHT_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[IV2]]] [512, 64] [1, 1]
448426
// CHECK: %[[TILED_MAT_OUT:.*]] = linalg.matmul
449-
// CHECK-SAME: outs(%[[MAT_OUT_SLICE1]] :
450-
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
451-
// CHECK: %[[ADD_OPERAND2_SLICE1:.*]] = tensor.extract_slice %[[ADD_OPERAND2_SLICE0]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
452-
// CHECK: %[[ADD_OUT_SLICE1:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
427+
// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] :
428+
// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[FIRST_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
429+
// CHECK: %[[ADD_OPERAND2_SLICE:.*]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
430+
// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
453431
// CHECK: %[[TILED_ADD_OUT:.*]] = linalg.add
454-
// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE1]] :
455-
// CHECK-SAME: outs(%[[ADD_OUT_SLICE1]] :
456-
// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
432+
// CHECK-SAME: ins(%[[TILED_MAT_OUT]], %[[ADD_OPERAND2_SLICE]] :
433+
// CHECK-SAME: outs(%[[ADD_OUT_SLICE]] :
434+
// CHECK: %[[INSERT_ADD:.*]] = tensor.insert_slice %[[TILED_ADD_OUT]] into %[[SECOND_OUT_ARG2]][%[[IV1]], %[[IV2]]] [64, 64] [1, 1]
457435
// CHECK: scf.yield %[[INSERT_MAT]], %[[INSERT_ADD]] :
458-
// CHECK: }
459-
// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
460-
// CHECK: }
461-
// CHECK: scf.forall.in_parallel {
462-
// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#1 into %[[SECOND_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
463-
// CHECK: tensor.parallel_insert_slice %[[LOOP_RESULT1]]#0 into %[[FIRST_OUT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
464-
// CHECK: }
436+
// CHECK: }
437+
// CHECK: scf.yield %[[LOOP_RESULT2]]#0, %[[LOOP_RESULT2]]#1 :
465438
// CHECK: }
466-
// CHECK: return %[[FINAL_RESULT]]#1 :
439+
// CHECK: return %[[LOOP_RESULT1]]#1 :

0 commit comments

Comments
 (0)