Skip to content

Commit 6b03bae

Browse files
committed
Revert "[mlir] Extract offsets-sizes-strides computation from makeTiledShape(s)."
This reverts commit 56d94b3.
1 parent 1bd31a6 commit 6b03bae

File tree

8 files changed

+87
-168
lines changed

8 files changed

+87
-168
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -214,44 +214,6 @@ Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
214214
Value materializeOpFoldResult(OpBuilder &b, Location loc,
215215
OpFoldResult opFoldResult);
216216

217-
/// A struct containg offsets-sizes-strides arguments of the tiled shape.
218-
struct SliceParameters {
219-
SmallVector<OpFoldResult, 3> offsets;
220-
SmallVector<OpFoldResult, 3> sizes;
221-
SmallVector<OpFoldResult, 3> strides;
222-
};
223-
224-
/// Computes SliceParameters for a single `valueToTile`. `omitPartialTileCheck`
225-
/// controls whether to omit the partial/boundary tile condition check in cases
226-
/// where we statically know that it is unnecessary.
227-
SliceParameters
228-
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
229-
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
230-
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
231-
ArrayRef<OpFoldResult> subShapeSizes,
232-
bool omitPartialTileCheck);
233-
234-
/// Computes SliceParamaters for all `valuesToTile` of the given
235-
/// `linalgOp`, assuming `linalgOp` is being fused into a loop
236-
/// nest for tiling with the given induction variables `ivs` and tile sizes
237-
/// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the
238-
/// implicit loops in `linalgOp`. `omitPartialTileCheck` controls whether to
239-
/// omit the partial/boundary tile condition check in cases where we statically
240-
/// know that it is unnecessary.
241-
///
242-
/// Note that a constant zero in `tileSizes` means no tiling at that implicit
243-
/// loop. The number of non-zero values in `tileSizes` should be equal to the
244-
/// number of values in `ivs`.
245-
///
246-
/// Some of the `valuesToTile` won't be affected by tiling. For these values,
247-
/// llvm::None will be returned.
248-
SmallVector<Optional<SliceParameters>>
249-
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
250-
ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
251-
ArrayRef<OpFoldResult> tileSizes,
252-
ArrayRef<OpFoldResult> sizeBounds,
253-
bool omitPartialTileCheck);
254-
255217
/// Creates an extract_slice/subview op for a single `valueToTile` with
256218
/// `builder`. This new operation extracts a tile of `valueToTile`, starting
257219
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 43 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -802,61 +802,28 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
802802
assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
803803
}
804804

805-
static Value materializeTiledShape(OpBuilder &builder, Location loc,
806-
Value valueToTile,
807-
const SliceParameters &sliceParams) {
808-
auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
809-
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
810-
.Case([&](MemRefType) {
811-
return builder.create<memref::SubViewOp>(
812-
loc, valueToTile, sliceParams.offsets,
813-
sliceParams.sizes, sliceParams.strides);
814-
})
815-
.Case([&](RankedTensorType) {
816-
return makeComposedExtractSliceOp(
817-
builder, loc, valueToTile, sliceParams.offsets,
818-
sliceParams.sizes, sliceParams.strides);
819-
})
820-
.Default([](ShapedType) -> Operation * {
821-
llvm_unreachable("Unexpected shaped type");
822-
});
823-
return sliceOp->getResult(0);
824-
}
825-
826805
Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
827806
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
828807
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
829808
ArrayRef<OpFoldResult> subShapeSizes,
830809
bool omitPartialTileCheck) {
831-
SliceParameters sliceParams =
832-
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
833-
ubs, subShapeSizes, omitPartialTileCheck);
834-
return materializeTiledShape(builder, loc, valueToTile, sliceParams);
835-
}
836-
837-
SliceParameters
838-
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
839-
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
840-
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
841-
ArrayRef<OpFoldResult> subShapeSizes,
842-
bool omitPartialTileCheck) {
843810
auto shapedType = valueToTile.getType().dyn_cast<ShapedType>();
844811
assert(shapedType && "only shaped types can be tiled");
845812
ArrayRef<int64_t> shape = shapedType.getShape();
846813
int64_t rank = shapedType.getRank();
847814

848815
// Construct a new subview / extract_slice for the tile.
849-
SliceParameters sliceParams;
850-
sliceParams.offsets.reserve(rank);
851-
sliceParams.sizes.reserve(rank);
852-
sliceParams.strides.reserve(rank);
816+
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
817+
offsets.reserve(rank);
818+
sizes.reserve(rank);
819+
strides.reserve(rank);
853820
for (unsigned r = 0; r < rank; ++r) {
854-
LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
821+
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: for dim#" << r);
855822
if (!isTiled(map.getSubMap({r}), tileSizes)) {
856-
sliceParams.offsets.push_back(builder.getIndexAttr(0));
823+
offsets.push_back(builder.getIndexAttr(0));
857824
OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
858-
sliceParams.sizes.push_back(dim);
859-
sliceParams.strides.push_back(builder.getIndexAttr(1));
825+
sizes.push_back(dim);
826+
strides.push_back(builder.getIndexAttr(1));
860827
LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
861828
continue;
862829
}
@@ -865,27 +832,26 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
865832
// Tiling creates a new slice at the proper index, the slice step is 1
866833
// (i.e. the op does not subsample, stepping occurs in the loop).
867834
auto m = map.getSubMap({r});
868-
LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
835+
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: submap: " << m << "\n");
869836
IRRewriter rewriter(builder);
870837
OpFoldResult offset = makeComposedFoldedAffineApply(rewriter, loc, m, lbs);
871-
sliceParams.offsets.push_back(offset);
838+
offsets.push_back(offset);
872839
OpFoldResult closedIntSize =
873840
makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
874841
// Resulting size needs to be made half open interval again.
875842
AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
876843
OpFoldResult size =
877844
makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
845+
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: raw size: " << size << "\n");
878846
LLVM_DEBUG(llvm::dbgs()
879-
<< "computeSliceParameters: raw size: " << size << "\n");
880-
LLVM_DEBUG(llvm::dbgs()
881-
<< "computeSliceParameters: new offset: " << offset << "\n");
882-
sliceParams.strides.push_back(builder.getIndexAttr(1));
847+
<< "makeTiledShape: new offset: " << offset << "\n");
848+
strides.push_back(builder.getIndexAttr(1));
883849

884850
if (omitPartialTileCheck) {
885851
// We statically know that the partial/boundary tile condition is
886852
// unnecessary.
887853
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
888-
sliceParams.sizes.push_back(size);
854+
sizes.push_back(size);
889855
continue;
890856
}
891857

@@ -937,9 +903,22 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
937903
makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
938904
}
939905
LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
940-
sliceParams.sizes.push_back(size);
906+
sizes.push_back(size);
941907
}
942-
return sliceParams;
908+
909+
auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
910+
.Case([&](MemRefType) {
911+
return builder.create<memref::SubViewOp>(
912+
loc, valueToTile, offsets, sizes, strides);
913+
})
914+
.Case([&](RankedTensorType) {
915+
return makeComposedExtractSliceOp(
916+
builder, loc, valueToTile, offsets, sizes, strides);
917+
})
918+
.Default([](ShapedType) -> Operation * {
919+
llvm_unreachable("Unexpected shaped type");
920+
});
921+
return sliceOp->getResult(0);
943922
}
944923

945924
SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
@@ -1024,29 +1003,28 @@ Value materializeOpFoldResult(OpBuilder &builder, Location loc,
10241003
return materializeOpFoldResult(b, opFoldResult);
10251004
}
10261005

1027-
SmallVector<Optional<SliceParameters>>
1028-
computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
1029-
ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
1030-
ArrayRef<OpFoldResult> tileSizes,
1031-
ArrayRef<OpFoldResult> sizeBounds,
1032-
bool omitPartialTileCheck) {
1006+
SmallVector<Value> makeTiledShapes(OpBuilder &b, Location loc,
1007+
LinalgOp linalgOp, ValueRange valuesToTile,
1008+
ArrayRef<OpFoldResult> ivs,
1009+
ArrayRef<OpFoldResult> tileSizes,
1010+
ArrayRef<OpFoldResult> sizeBounds,
1011+
bool omitPartialTileCheck) {
10331012
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
10341013
llvm::make_range(tileSizes.begin(), tileSizes.end()),
10351014
[](OpFoldResult v) { return !isZero(v); })) &&
10361015
"expected as many ivs as non-zero sizes");
10371016

10381017
// Construct (potentially temporary) mins and maxes on which to apply maps
10391018
// that define tile subshapes.
1040-
SmallVector<OpFoldResult> lbs =
1041-
computeTileOffsets(builder, loc, ivs, tileSizes);
1019+
SmallVector<OpFoldResult> lbs = computeTileOffsets(b, loc, ivs, tileSizes);
10421020
SmallVector<OpFoldResult> subShapeSizes =
1043-
computeTileSizes(builder, loc, tileSizes, sizeBounds);
1021+
computeTileSizes(b, loc, tileSizes, sizeBounds);
10441022

10451023
assert(static_cast<int64_t>(valuesToTile.size()) ==
10461024
linalgOp.getNumInputsAndOutputs() &&
10471025
"expected one value to tile for every operand");
1048-
SmallVector<Optional<SliceParameters>> allSliceParams;
1049-
allSliceParams.reserve(valuesToTile.size());
1026+
SmallVector<Value> tiledShapes;
1027+
tiledShapes.reserve(valuesToTile.size());
10501028
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
10511029
Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
10521030
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
@@ -1057,39 +1035,18 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
10571035
// extract/insert slice pairs make the accessed iteration argument
10581036
// subdomains explicit.
10591037
if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
1060-
allSliceParams.push_back(llvm::None);
1038+
tiledShapes.push_back(shapedOp);
10611039
LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
10621040
<< opOperand->get().getType() << "\n");
10631041
continue;
10641042
}
10651043
LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
10661044

1067-
allSliceParams.push_back(computeSliceParameters(
1068-
builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
1069-
omitPartialTileCheck));
1045+
tiledShapes.push_back(makeTiledShape(b, loc, shapedOp, tileSizes, map, lbs,
1046+
sizeBounds, subShapeSizes,
1047+
omitPartialTileCheck));
10701048
}
10711049

1072-
return allSliceParams;
1073-
}
1074-
1075-
SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
1076-
LinalgOp linalgOp, ValueRange valuesToTile,
1077-
ArrayRef<OpFoldResult> ivs,
1078-
ArrayRef<OpFoldResult> tileSizes,
1079-
ArrayRef<OpFoldResult> sizeBounds,
1080-
bool omitPartialTileCheck) {
1081-
SmallVector<Optional<SliceParameters>> allSliceParameter =
1082-
computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
1083-
tileSizes, sizeBounds, omitPartialTileCheck);
1084-
SmallVector<Value> tiledShapes;
1085-
for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
1086-
Value valueToTile = std::get<0>(item);
1087-
Optional<SliceParameters> sliceParams = std::get<1>(item);
1088-
tiledShapes.push_back(
1089-
sliceParams.hasValue()
1090-
? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
1091-
: valueToTile);
1092-
}
10931050
return tiledShapes;
10941051
}
10951052

mlir/test/Dialect/Linalg/tile-and-distribute.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ func.func @gemm1(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
1616
// CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
1717
// CHECK: scf.for %[[ARG3:.*]] =
1818
// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
19-
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
20-
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
21-
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
2219
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
20+
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
2321
// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
24-
// CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
22+
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
23+
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
24+
// CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX]]]
2525
// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
2626

2727
// -----
@@ -48,11 +48,11 @@ func.func @gemm2(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
4848
// CHECK: scf.if %[[INBOUNDS]]
4949
// CHECK: scf.for %[[ARG3:.*]] =
5050
// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
51+
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
5152
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
53+
// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
5254
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
5355
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
54-
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
55-
// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
5656
// CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
5757
// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
5858

@@ -106,11 +106,11 @@ func.func @gemm4(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
106106
// CHECK: scf.if %[[INBOUNDS]]
107107
// CHECK: scf.for %[[ARG3:.*]] =
108108
// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
109+
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
109110
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
111+
// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
110112
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
111113
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
112-
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG3]]]
113-
// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG3]], %[[OFFSETX]]]
114114
// CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[OFFSETX_2]]]
115115
// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
116116

@@ -139,9 +139,9 @@ func.func @gemm5(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
139139
// CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBX]]) to (%{{.*}}) step (%[[STEPX]])
140140
// CHECK: scf.for %[[ARG4:.*]] =
141141
// CHECK: %[[OFFSETY:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
142-
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
143142
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[OFFSETY]], %[[ARG4]]]
144143
// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
144+
// CHECK: %[[OFFSETY_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
145145
// CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[OFFSETY_2]], %[[ARG3]]]
146146
// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
147147

@@ -166,10 +166,10 @@ func.func @gemm6(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32
166166
// CHECK: %[[STEPY:.*]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
167167
// CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[LBY]]) to (%{{.*}}) step (%[[STEPY]])
168168
// CHECK: scf.for %[[ARG4:.*]] =
169-
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
170-
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
171169
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]]
170+
// CHECK: %[[OFFSETX:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
172171
// CHECK: %[[SV2:.*]] = memref.subview %[[ARG1]][%[[ARG4]], %[[OFFSETX]]]
172+
// CHECK: %[[OFFSETX_2:.*]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
173173
// CHECK: %[[SV3:.*]] = memref.subview %[[ARG2]][%[[ARG3]], %[[OFFSETX_2]]]
174174
// CHECK: linalg.matmul ins(%[[SV1]], %[[SV2]]{{.*}} outs(%[[SV3]]
175175

mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,9 +241,9 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
241241
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
242242
// CHECK-NEXT: %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
243243
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[FILL_W]], %[[FILTER_W]]]
244-
// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
245244
// CHECK-NEXT: %[[ST_INPUT:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
246245
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
246+
// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
247247
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
248248
// CHECK-NEXT: %[[ST_ELEM:.+]] = tensor.extract_slice %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
249249
// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]

mlir/test/Dialect/Linalg/tile-conv.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ func.func @conv(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref
2626
// CHECK: scf.for %[[ARG4:.*]] = %[[C0]] to %[[T3]] step %[[C3]]
2727
// CHECK: %[[T4:.*]] = affine.min #[[MAP0]](%[[ARG3]])[%[[T2]], %[[T0]]]
2828
// CHECK: %[[T5:.*]] = affine.min #[[MAP1]](%[[ARG4]])[%[[T3]], %[[T1]]]
29+
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]] [%[[T4]], %[[T5]]]
2930
// CHECK: %[[T6:.*]] = affine.min #[[MAP2]](%[[ARG3]])[%[[T2]]
3031
// CHECK: %[[T7:.*]] = affine.min #[[MAP3]](%[[ARG4]])[%[[T3]]]
31-
// CHECK: %[[SV1:.*]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]]] [%[[T4]], %[[T5]]]
3232
// CHECK: %[[SV2:.*]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]] [%[[T6]], %[[T7]]]
3333
// CHECK: linalg.conv_2d
3434
// CHECK-SAME: ins(%[[SV1]], %[[ARG1]]

0 commit comments

Comments
 (0)