Commit 004a3e0

Author: Ofri Frishman
[MLIR] Add pattern to bubble up tensor.extract_slice
Add a pattern that bubbles up tensor.extract_slice through tensor.expand_shape, and add a transform op to the tensor dialect to apply this pattern directly. When added as a cleanup pattern of the tile-and-fuse utility, this pattern enables tiling and fusing op chains that contain tensor.expand_shape; without it that would not be possible, since tensor.expand_shape does not implement the tiling interface. In addition, the pattern is registered as a cleanup pattern for transform.structured.fuse. The pattern was first implemented in the IREE project by Quinn Dawkins and is being upstreamed.
1 parent 73413bd commit 004a3e0
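
For illustration, a minimal sketch of the rewrite the new pattern performs, adapted from the doc comment and tests added in this commit (the function name and SSA value names are made up for the example):

func.func @bubble_up_example(%in: tensor<60xf32>, %x: index, %y: index, %z: index) -> tensor<1x1x5xf32> {
  // Before the rewrite: the slice is taken from the expanded tensor.
  %expand = tensor.expand_shape %in [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
  %slice = tensor.extract_slice %expand[%x, %y, %z] [1, 1, 5] [1, 1, 1] : tensor<2x3x10xf32> to tensor<1x1x5xf32>
  return %slice : tensor<1x1x5xf32>
}

// After the rewrite, the slice is taken directly from the collapsed source and
// then re-expanded; the new offset is the linearization of the old offsets:
//   %idx    = affine.linearize_index disjoint [%x, %y, %z] by (2, 3, 10) : index
//   %slice  = tensor.extract_slice %in[%idx] [5] [1] : tensor<60xf32> to tensor<5xf32>
//   %result = tensor.expand_shape %slice [[0, 1, 2]] output_shape [1, 1, 5] : tensor<5xf32> into tensor<1x1x5xf32>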

File tree: 7 files changed, +414 -0 lines changed

mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td

Lines changed: 10 additions & 0 deletions
@@ -111,6 +111,16 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyBubbleUpExtractSlicePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.bubble_up_extract_slice",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.extract_slice and its producer should swap locations.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.rewrite_as_constant",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
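
As a usage sketch of the op added above (the enclosing named sequence and the func.func match target are illustrative, not part of this commit), the pattern set can also be applied standalone through transform.apply_patterns:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.tensor.bubble_up_extract_slice
    } : !transform.any_op
    transform.yield
  }
}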

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,12 @@ void populateFoldTensorSubsetIntoVectorTransferPatterns(
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
+/// Appends patterns that are used to bubble up tensor.extract_slice ops above
+/// their producers. When used as cleanup patterns of tile and fuse, this
+/// enables fusing the producer with the consumer even if the producer does not
+/// implement the tiling interface.
+void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that drop redundant tensor.insert_slice
 /// rank expansions.
 void populateDropRedundantInsertSliceRankExpansionPatterns(

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 0 deletions
@@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
     RewritePatternSet patterns(context);
     tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
     tileAndFuseOptions.cleanupPatterns = std::move(patterns);
   }

mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
@@ -125,6 +125,11 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
   tensor::populateReassociativeReshapeFoldingPatterns(patterns);
 }
 
+void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
+}
+
 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 180 additions & 0 deletions
@@ -6,9 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/Support/Debug.h"
 
 using namespace mlir;
@@ -210,6 +213,178 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. If the transformation
+/// is not possible, or if the slice is rank reducing, the function returns
+/// failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note - this pattern could be reworked to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
+/// implemented only as a bubble up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice source not produced by expand_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unsupported: non-unit stride");
+    }
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        sizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+    }
+
+    // Helper variables and function for accumulating the new offset and length
+    // values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two integers.
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    auto isZeroOffsetAndFullSize =
+        [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
+          if (!isConstantIntValue(offset, 0))
+            return false;
+          FailureOr<bool> maybeEqual =
+              ValueBoundsConstraintSet::areEqual(sliceSize, size);
+          return llvm::succeeded(maybeEqual) && maybeEqual.value();
+        };
+
+    // First verify that this is a full slice of the expanded tensor.
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Find the first expanded dim after the first dim with non-unit extracted
+      // size.
+      for (; i < e; ++i) {
+        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+          // +1 to skip the first non-unit size dim.
+          i++;
+          break;
+        }
+      }
+
+      // Verify that all subsequent dimensions extract the full size of the
+      // source tensor.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                     outputShape[expandedDim])) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "Not a contiguous slice of the expanded tensor.");
+        }
+      }
+    }
+
+    // Compute new offsets, lengths, and strides.
+    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      OpFoldResult newSize = rewriter.getIndexAttr(1);
+      SmallVector<OpFoldResult> basis, delinOffsets;
+
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Offset = cumulative product of leading unit extracted dims.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isConstantIntValue(sizes[expandedDim], 1))
+          break;
+
+        basis.push_back(outputShape[expandedDim]);
+        delinOffsets.push_back(offsets[expandedDim]);
+      }
+
+      if (i != e) {
+        int64_t expandedDim = indices[i];
+        basis.push_back(outputShape[expandedDim]);
+        delinOffsets.push_back(offsets[expandedDim]);
+        newSize = sizes[expandedDim];
+        i++;
+      }
+
+      for (; i < e; ++i) {
+        OpFoldResult fullSize = outputShape[indices[i]];
+        basis.push_back(fullSize);
+        delinOffsets.push_back(rewriter.getIndexAttr(0));
+        newSize = mul(newSize, fullSize);
+      }
+      SmallVector<Value> offsetVals =
+          llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+            return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+          });
+      OpFoldResult newOffset =
+          rewriter
+              .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
+                                                      /*disjoint=*/true)
+              .getResult();
+      newOffsets.push_back(newOffset);
+      newLengths.push_back(newSize);
+
+      // Only unit stride supported.
+      newStrides.push_back(rewriter.getIndexAttr(1));
+    }
+
+    // The shape of the result can be obtained from the sizes passed in.
+    SmallVector<Value> dynDims;
+    SmallVector<int64_t> shape;
+    dispatchIndexOpFoldResults(sizes, dynDims, shape);
+    RankedTensorType resultType = RankedTensorType::get(
+        shape, expandShapeOp.getResultType().getElementType());
+
+    // Create a new ExtractSliceOp and ExpandShapeOp.
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
+    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
+        sliceOp, resultType, newSliceOp,
+        expandShapeOp.getReassociationIndices(), sizes);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -227,3 +402,8 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
     RewritePatternSet &patterns) {
   patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+}
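
For contrast with the contiguity check implemented above, a slice that is not contiguous within a reassociation group is rejected and the IR is left untouched (compare the no_swap test below); the SSA value names here are illustrative:

// Rejected: within the single group [[0, 1, 2]], the slice takes 2 of 4 along
// the middle expanded dimension while taking only 5 of 10 along the innermost
// one, so the extracted elements are not contiguous in the collapsed
// tensor<120xf32>.
%expand = tensor.expand_shape %in [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
%slice = tensor.extract_slice %expand[%x, %y, %z] [1, 2, 5] [1, 1, 1] : tensor<3x4x10xf32> to tensor<1x2x5xf32>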

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 138 additions & 0 deletions
@@ -278,3 +278,141 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[Z]]] by (2, 3, 10)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+func.func @swap_expand_shape_with_extract_slice(%0: tensor<60xf32>) -> tensor<2x3x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %empty = tensor.empty() : tensor<2x3x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<2x3x10xf32>) outs(%empty : tensor<2x3x10xf32>) -> tensor<2x3x10xf32>
+  return %exp : tensor<2x3x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_full_inner_dim
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]]{{.*}} by (3, 4, 10)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 2, 10]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+func.func @swap_expand_shape_with_extract_slice_full_inner_dim(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_swap_expand_shape_with_extract_slice_non_contiguous
+// CHECK: tensor.expand_shape
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: linalg.exp
+func.func @no_swap_expand_shape_with_extract_slice_non_contiguous(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[W:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX0:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[C0]]] by (3, 4, 10)
+// CHECK: %[[LINEAR_IDX1:.+]] = affine.linearize_index disjoint [%[[Z]], %[[W]]] by (7, 8)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX0]], %[[LINEAR_IDX1]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+module {
+  func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims(%0: tensor<120x56xf32>) -> tensor<3x4x10x7x8xf32> {
+    %expand = tensor.expand_shape %0 [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32>
+    %empty = tensor.empty() : tensor<3x4x10x7x8xf32>
+    %exp = linalg.exp ins(%expand : tensor<3x4x10x7x8xf32>) outs(%empty : tensor<3x4x10x7x8xf32>) -> tensor<3x4x10x7x8xf32>
+    return %exp : tensor<3x4x10x7x8xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], {{.*}} by (8, 32)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[0, 0, %[[LINEAR_IDX]]] [1, 1800, 32] [1, 1, 1] : tensor<1x1800x256xf32> to tensor<1x1800x32xf32>
+// CHECK: %[[ABS:.+]] = linalg.abs ins(%[[SLICE]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ABS]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 1800, 1, 32]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+module {
+  func.func @swap_expand_shape_with_extract_slice_and_fuse_with_expand_producer(%0: tensor<1x1800x256xf32>) -> tensor<1x1800x8x32xf32> {
+    %empty1 = tensor.empty() : tensor<1x1800x256xf32>
+    %exp1 = linalg.abs ins(%0 : tensor<1x1800x256xf32>) outs(%empty1 : tensor<1x1800x256xf32>) -> tensor<1x1800x256xf32>
+    %expand = tensor.expand_shape %exp1 [[0], [1], [2, 3]] output_shape [1, 1800, 8, 32] : tensor<1x1800x256xf32> into tensor<1x1800x8x32xf32>
+    %empty2 = tensor.empty() : tensor<1x1800x8x32xf32>
+    %exp2 = linalg.exp ins(%expand : tensor<1x1800x8x32xf32>) outs(%empty2 : tensor<1x1800x8x32xf32>) -> tensor<1x1800x8x32xf32>
+    return %exp2 : tensor<1x1800x8x32xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true :
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield
+  }
+}
