@@ -802,61 +802,28 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
802
802
assert (ivs.size () == iteratorTypes.size () && " did not generate enough loops" );
803
803
}
804
804
805
- static Value materializeTiledShape (OpBuilder &builder, Location loc,
806
- Value valueToTile,
807
- const SliceParameters &sliceParams) {
808
- auto shapedType = valueToTile.getType ().dyn_cast <ShapedType>();
809
- auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
810
- .Case ([&](MemRefType) {
811
- return builder.create <memref::SubViewOp>(
812
- loc, valueToTile, sliceParams.offsets ,
813
- sliceParams.sizes , sliceParams.strides );
814
- })
815
- .Case ([&](RankedTensorType) {
816
- return makeComposedExtractSliceOp (
817
- builder, loc, valueToTile, sliceParams.offsets ,
818
- sliceParams.sizes , sliceParams.strides );
819
- })
820
- .Default ([](ShapedType) -> Operation * {
821
- llvm_unreachable (" Unexpected shaped type" );
822
- });
823
- return sliceOp->getResult (0 );
824
- }
825
-
826
805
Value makeTiledShape (OpBuilder &builder, Location loc, Value valueToTile,
827
806
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
828
807
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
829
808
ArrayRef<OpFoldResult> subShapeSizes,
830
809
bool omitPartialTileCheck) {
831
- SliceParameters sliceParams =
832
- computeSliceParameters (builder, loc, valueToTile, tileSizes, map, lbs,
833
- ubs, subShapeSizes, omitPartialTileCheck);
834
- return materializeTiledShape (builder, loc, valueToTile, sliceParams);
835
- }
836
-
837
- SliceParameters
838
- computeSliceParameters (OpBuilder &builder, Location loc, Value valueToTile,
839
- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
840
- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
841
- ArrayRef<OpFoldResult> subShapeSizes,
842
- bool omitPartialTileCheck) {
843
810
auto shapedType = valueToTile.getType ().dyn_cast <ShapedType>();
844
811
assert (shapedType && " only shaped types can be tiled" );
845
812
ArrayRef<int64_t > shape = shapedType.getShape ();
846
813
int64_t rank = shapedType.getRank ();
847
814
848
815
// Construct a new subview / extract_slice for the tile.
849
- SliceParameters sliceParams ;
850
- sliceParams. offsets .reserve (rank);
851
- sliceParams. sizes .reserve (rank);
852
- sliceParams. strides .reserve (rank);
816
+ SmallVector<OpFoldResult, 4 > offsets, sizes, strides ;
817
+ offsets.reserve (rank);
818
+ sizes.reserve (rank);
819
+ strides.reserve (rank);
853
820
for (unsigned r = 0 ; r < rank; ++r) {
854
- LLVM_DEBUG (llvm::dbgs () << " computeSliceParameters : for dim#" << r);
821
+ LLVM_DEBUG (llvm::dbgs () << " makeTiledShape : for dim#" << r);
855
822
if (!isTiled (map.getSubMap ({r}), tileSizes)) {
856
- sliceParams. offsets .push_back (builder.getIndexAttr (0 ));
823
+ offsets.push_back (builder.getIndexAttr (0 ));
857
824
OpFoldResult dim = createFoldedDimOp (builder, loc, valueToTile, r);
858
- sliceParams. sizes .push_back (dim);
859
- sliceParams. strides .push_back (builder.getIndexAttr (1 ));
825
+ sizes.push_back (dim);
826
+ strides.push_back (builder.getIndexAttr (1 ));
860
827
LLVM_DEBUG (llvm::dbgs () << " : not tiled: use size: " << dim << " \n " );
861
828
continue ;
862
829
}
@@ -865,27 +832,26 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
865
832
// Tiling creates a new slice at the proper index, the slice step is 1
866
833
// (i.e. the op does not subsample, stepping occurs in the loop).
867
834
auto m = map.getSubMap ({r});
868
- LLVM_DEBUG (llvm::dbgs () << " computeSliceParameters : submap: " << m << " \n " );
835
+ LLVM_DEBUG (llvm::dbgs () << " makeTiledShape : submap: " << m << " \n " );
869
836
IRRewriter rewriter (builder);
870
837
OpFoldResult offset = makeComposedFoldedAffineApply (rewriter, loc, m, lbs);
871
- sliceParams. offsets .push_back (offset);
838
+ offsets.push_back (offset);
872
839
OpFoldResult closedIntSize =
873
840
makeComposedFoldedAffineApply (rewriter, loc, m, subShapeSizes);
874
841
// Resulting size needs to be made half open interval again.
875
842
AffineExpr s0 = getAffineSymbolExpr (0 , builder.getContext ());
876
843
OpFoldResult size =
877
844
makeComposedFoldedAffineApply (rewriter, loc, s0 + 1 , closedIntSize);
845
+ LLVM_DEBUG (llvm::dbgs () << " makeTiledShape: raw size: " << size << " \n " );
878
846
LLVM_DEBUG (llvm::dbgs ()
879
- << " computeSliceParameters: raw size: " << size << " \n " );
880
- LLVM_DEBUG (llvm::dbgs ()
881
- << " computeSliceParameters: new offset: " << offset << " \n " );
882
- sliceParams.strides .push_back (builder.getIndexAttr (1 ));
847
+ << " makeTiledShape: new offset: " << offset << " \n " );
848
+ strides.push_back (builder.getIndexAttr (1 ));
883
849
884
850
if (omitPartialTileCheck) {
885
851
// We statically know that the partial/boundary tile condition is
886
852
// unnecessary.
887
853
LLVM_DEBUG (llvm::dbgs () << " makeTiledShape: new size: " << size << " \n " );
888
- sliceParams. sizes .push_back (size);
854
+ sizes.push_back (size);
889
855
continue ;
890
856
}
891
857
@@ -937,9 +903,22 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
937
903
makeComposedFoldedAffineMin (rewriter, loc, minMap, {size, d, offset});
938
904
}
939
905
LLVM_DEBUG (llvm::dbgs () << " makeTiledShape: new size: " << size << " \n " );
940
- sliceParams. sizes .push_back (size);
906
+ sizes.push_back (size);
941
907
}
942
- return sliceParams;
908
+
909
+ auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
910
+ .Case ([&](MemRefType) {
911
+ return builder.create <memref::SubViewOp>(
912
+ loc, valueToTile, offsets, sizes, strides);
913
+ })
914
+ .Case ([&](RankedTensorType) {
915
+ return makeComposedExtractSliceOp (
916
+ builder, loc, valueToTile, offsets, sizes, strides);
917
+ })
918
+ .Default ([](ShapedType) -> Operation * {
919
+ llvm_unreachable (" Unexpected shaped type" );
920
+ });
921
+ return sliceOp->getResult (0 );
943
922
}
944
923
945
924
SmallVector<OpFoldResult> computeTileOffsets (OpBuilder &b, Location loc,
@@ -1024,29 +1003,28 @@ Value materializeOpFoldResult(OpBuilder &builder, Location loc,
1024
1003
return materializeOpFoldResult (b, opFoldResult);
1025
1004
}
1026
1005
1027
- SmallVector<Optional<SliceParameters>>
1028
- computeAllSliceParameters (OpBuilder &builder, Location loc, LinalgOp linalgOp,
1029
- ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
1030
- ArrayRef<OpFoldResult> tileSizes,
1031
- ArrayRef<OpFoldResult> sizeBounds,
1032
- bool omitPartialTileCheck) {
1006
+ SmallVector<Value> makeTiledShapes (OpBuilder &b, Location loc,
1007
+ LinalgOp linalgOp, ValueRange valuesToTile ,
1008
+ ArrayRef<OpFoldResult> ivs,
1009
+ ArrayRef<OpFoldResult> tileSizes,
1010
+ ArrayRef<OpFoldResult> sizeBounds,
1011
+ bool omitPartialTileCheck) {
1033
1012
assert (ivs.size () == static_cast <size_t >(llvm::count_if (
1034
1013
llvm::make_range (tileSizes.begin (), tileSizes.end ()),
1035
1014
[](OpFoldResult v) { return !isZero (v); })) &&
1036
1015
" expected as many ivs as non-zero sizes" );
1037
1016
1038
1017
// Construct (potentially temporary) mins and maxes on which to apply maps
1039
1018
// that define tile subshapes.
1040
- SmallVector<OpFoldResult> lbs =
1041
- computeTileOffsets (builder, loc, ivs, tileSizes);
1019
+ SmallVector<OpFoldResult> lbs = computeTileOffsets (b, loc, ivs, tileSizes);
1042
1020
SmallVector<OpFoldResult> subShapeSizes =
1043
- computeTileSizes (builder , loc, tileSizes, sizeBounds);
1021
+ computeTileSizes (b , loc, tileSizes, sizeBounds);
1044
1022
1045
1023
assert (static_cast <int64_t >(valuesToTile.size ()) ==
1046
1024
linalgOp.getNumInputsAndOutputs () &&
1047
1025
" expected one value to tile for every operand" );
1048
- SmallVector<Optional<SliceParameters>> allSliceParams ;
1049
- allSliceParams .reserve (valuesToTile.size ());
1026
+ SmallVector<Value> tiledShapes ;
1027
+ tiledShapes .reserve (valuesToTile.size ());
1050
1028
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands ()) {
1051
1029
Value shapedOp = valuesToTile[opOperand->getOperandNumber ()];
1052
1030
LLVM_DEBUG (llvm::dbgs () << " makeTiledShapes: for operand " << shapedOp);
@@ -1057,39 +1035,18 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
1057
1035
// extract/insert slice pairs make the accessed iteration argument
1058
1036
// subdomains explicit.
1059
1037
if (!isTiled (map, tileSizes) && !linalgOp.isOutputTensor (opOperand)) {
1060
- allSliceParams .push_back (llvm::None );
1038
+ tiledShapes .push_back (shapedOp );
1061
1039
LLVM_DEBUG (llvm::dbgs () << " : not tiled: use shape: "
1062
1040
<< opOperand->get ().getType () << " \n " );
1063
1041
continue ;
1064
1042
}
1065
1043
LLVM_DEBUG (llvm::dbgs () << " : tiled: figure out subshape...\n " );
1066
1044
1067
- allSliceParams .push_back (computeSliceParameters (
1068
- builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
1069
- omitPartialTileCheck));
1045
+ tiledShapes .push_back (makeTiledShape (b, loc, shapedOp, tileSizes, map, lbs,
1046
+ sizeBounds, subShapeSizes,
1047
+ omitPartialTileCheck));
1070
1048
}
1071
1049
1072
- return allSliceParams;
1073
- }
1074
-
1075
- SmallVector<Value> makeTiledShapes (OpBuilder &builder, Location loc,
1076
- LinalgOp linalgOp, ValueRange valuesToTile,
1077
- ArrayRef<OpFoldResult> ivs,
1078
- ArrayRef<OpFoldResult> tileSizes,
1079
- ArrayRef<OpFoldResult> sizeBounds,
1080
- bool omitPartialTileCheck) {
1081
- SmallVector<Optional<SliceParameters>> allSliceParameter =
1082
- computeAllSliceParameters (builder, loc, linalgOp, valuesToTile, ivs,
1083
- tileSizes, sizeBounds, omitPartialTileCheck);
1084
- SmallVector<Value> tiledShapes;
1085
- for (auto item : llvm::zip (valuesToTile, allSliceParameter)) {
1086
- Value valueToTile = std::get<0 >(item);
1087
- Optional<SliceParameters> sliceParams = std::get<1 >(item);
1088
- tiledShapes.push_back (
1089
- sliceParams.hasValue ()
1090
- ? materializeTiledShape (builder, loc, valueToTile, *sliceParams)
1091
- : valueToTile);
1092
- }
1093
1050
return tiledShapes;
1094
1051
}
1095
1052
0 commit comments