Skip to content

Commit a9205c5

Browse files
sabaumaSpenser Bauman
and
Spenser Bauman
authored
[mlir][tensor] Implement constant folder for tensor.pad (#92691)
Extend the folding ability of the RewriteAsConstant patterns to include tensor.pad operations on constants. The new pattern constant-folds tensor.pad operations that operate on tensor constants and have statically resolvable padding sizes/values. For example, %init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32> %pad_value = arith.constant 0 : i32 %0 = tensor.pad %init low[1, 1] high[1, 1] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : i32 } : tensor<2x2xi32> to tensor<4x4xi32> becomes %cst = arith.constant dense<[[0, 0, 0, 0], [0, 6, 7, 0], [0, 8, 9, 0], [0, 0, 0, 0]]> : tensor<4x4xi32> Co-authored-by: Spenser Bauman <sabauma@fastmail>
1 parent 2ec47e5 commit a9205c5

File tree

6 files changed

+321
-5
lines changed

6 files changed

+321
-5
lines changed

mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,14 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
114114
def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
115115
"apply_patterns.tensor.rewrite_as_constant",
116116
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
117+
let arguments = (ins UnitAttr:$aggressive);
117118
let description = [{
118119
Indicates that tensor ops (such as tensor.generate) should be replaced with
119120
constants (arith.constant) when possible.
120121
}];
121122

122-
let assemblyFormat = "attr-dict";
123+
let assemblyFormat =
124+
"(`aggressive` $aggressive^)? attr-dict";
123125
}
124126

125127
def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">;

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,12 @@ void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
9191
/// respectively.
9292
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);
9393

94+
using ControlFoldFn = std::function<bool(OpOperand *)>;
95+
9496
/// Populates `patterns` with patterns that replace tensor ops (such as
9597
/// tensor.generate) with constants when possible.
96-
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns);
98+
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
99+
const ControlFoldFn &controlFn);
97100

98101
//===----------------------------------------------------------------------===//
99102
// Transform helpers

mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,20 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
127127

128128
void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
129129
RewritePatternSet &patterns) {
130-
tensor::populateRewriteAsConstantPatterns(patterns);
130+
ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
131+
Operation *producer = fusedOperand->get().getDefiningOp();
132+
return producer && producer->hasOneUse();
133+
};
134+
135+
ControlFoldFn aggressiveControlFn = [](OpOperand *fusedOperand) {
136+
return true;
137+
};
138+
139+
// Fold unconditionally when `aggressive` is set; otherwise only fold when the constant producer has a single use.
140+
if (getAggressive())
141+
tensor::populateRewriteAsConstantPatterns(patterns, aggressiveControlFn);
142+
else
143+
tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
131144
}
132145

133146
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp

Lines changed: 164 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@
88
//
99
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1010
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
11+
#include "mlir/Dialect/Utils/IndexingUtils.h"
1112
#include "mlir/IR/Matchers.h"
1213
#include "mlir/IR/PatternMatch.h"
1314

15+
#include "llvm/ADT/TypeSwitch.h"
16+
1417
using namespace mlir;
1518
using namespace mlir::tensor;
1619

@@ -45,9 +48,169 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
4548
}
4649
};
4750

51+
/// Transform a linear index from one indexing space to another given:
52+
///
53+
/// - the shape of the source indexing space,
54+
/// - the strides of the target indexing space,
55+
/// - a linear index into the source indexing space.
56+
///
57+
/// This function is logically a sequence of linearize/delinearize over
58+
/// different bases but avoids allocating intermediate SmallVectors.
59+
int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
60+
ArrayRef<int64_t> outputStrides,
61+
int64_t srcLinearIndex) {
62+
assert(inputShape.size() == outputStrides.size());
63+
64+
int64_t dstLinearIndex = 0;
65+
66+
for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
67+
// Compute the index into the current dimension of the source tensor.
68+
// `quotient` is the remaining linear index after accounting for the
69+
// current dimension.
70+
//
71+
// `remainder` is the index into the source tensor for the current
72+
// dimension.
73+
auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);
74+
75+
srcLinearIndex = quotient;
76+
77+
// Add the contribution of the current dimension to the output using the
78+
// permutation map.
79+
dstLinearIndex += outputStrides[dim] * remainder;
80+
}
81+
82+
return dstLinearIndex;
83+
}
84+
85+
template <typename ElemType, typename AttrType>
86+
Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
87+
DenseElementsAttr input, AttrType padValue,
88+
ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
89+
auto inputValues = input.tryGetValues<ElemType>();
90+
if (failed(inputValues))
91+
return nullptr;
92+
93+
auto oldShape = input.getType().getShape();
94+
95+
// Compute the output shape of the new value.
96+
auto newShape =
97+
llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
98+
[](std::tuple<int64_t, int64_t, int64_t> pack) {
99+
auto [old, low, high] = pack;
100+
return old + low + high;
101+
});
102+
103+
int64_t outputSize = computeProduct(newShape);
104+
105+
// Fully initialize the vector with the padding value.
106+
// The non-padded area will then be copied.
107+
SmallVector<ElemType> values(outputSize, padValue.getValue());
108+
109+
// Strides for input and output are used to transform between the indexing
110+
// space of the input and output tensors.
111+
SmallVector<int64_t> outputStrides = computeStrides(newShape);
112+
113+
// The contribution of the low padding to the offset in the output tensor.
114+
// This is the starting position of the source tensor within the padding
115+
// tensor.
116+
int64_t startingOffset = linearize(padLow, outputStrides);
117+
118+
// Copy values from the input tensor to the corresponding sub-region
119+
// of the output tensor.
120+
for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
121+
auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
122+
values[outputIndex + startingOffset] = inputValue;
123+
}
124+
125+
// Create an attribute for the folded value.
126+
auto newType = input.getType().clone(newShape);
127+
auto newAttr = DenseElementsAttr::get(newType, values);
128+
129+
Operation *constantOp =
130+
rewriter.getContext()
131+
->getLoadedDialect<TensorDialect>()
132+
->materializeConstant(rewriter, newAttr, newType, loc);
133+
134+
return constantOp ? constantOp->getResult(0) : nullptr;
135+
}
136+
137+
/// Rewrite pattern that replaces a `tensor.pad` of a constant tensor with a
/// single constant holding the padded result.
///
/// The pattern only applies when the pad is not marked `nofold`, its source
/// and padding value are both defined by constants, and every low/high padding
/// amount is a constant. The control function supplied at construction is then
/// consulted with the source operand and has the final say on whether the op
/// is folded.
struct PadOpToConstant final : public OpRewritePattern<PadOp> {

  PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<PadOp>(context, benefit), controlFn(controlFn) {}

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    // Reports a match failure with `reason` on the op being rewritten.
    auto bail = [&](const char *reason) -> LogicalResult {
      return rewriter.notifyMatchFailure(padTensorOp, reason);
    };

    if (padTensorOp.getNofold())
      return bail("refusing to fold nofold pad operation");

    // The source has to be defined by a constant with dense elements.
    DenseElementsAttr sourceElements = nullptr;
    TypedValue<RankedTensorType> source = padTensorOp.getSource();
    if (!matchPattern(source, m_Constant(&sourceElements)))
      return failure();

    // A padding value must be available and itself be defined by a constant.
    Attribute padConstant = nullptr;
    Value yielded = padTensorOp.getConstantPaddingValue();
    if (!yielded)
      return bail("unable to get constant value");
    if (!matchPattern(yielded, m_Constant(&padConstant)))
      return bail("unable to get constant value");

    // Every low and high padding amount has to be a known constant.
    auto staticLow = getConstantIntValues(padTensorOp.getMixedLowPad());
    auto staticHigh = getConstantIntValues(padTensorOp.getMixedHighPad());
    if (!(staticLow && staticHigh))
      return bail("unable to extract constant padding");

    // The op is foldable; let the control function decide whether it should
    // actually be folded.
    if (!controlFn(&padTensorOp.getSourceMutable()))
      return bail("not folding due to cost function");

    Location loc = padTensorOp.getLoc();

    // Only float and integer padding constants are handled; anything else
    // yields a null value.
    Value folded =
        llvm::TypeSwitch<Attribute, Value>(padConstant)
            .Case([&](FloatAttr asFloat) {
              return constantFoldPadOp<llvm::APFloat>(rewriter, loc,
                                                      sourceElements, asFloat,
                                                      *staticLow, *staticHigh);
            })
            .Case([&](IntegerAttr asInteger) {
              return constantFoldPadOp<llvm::APInt>(rewriter, loc,
                                                    sourceElements, asInteger,
                                                    *staticLow, *staticHigh);
            })
            .Default(Value());

    if (!folded)
      return bail("tensor type not supported");

    // Insert a cast when the type of the folded constant differs from the
    // declared result type of the pad (e.g. a dynamically shaped result).
    RankedTensorType resultType = padTensorOp.getResult().getType();
    if (folded.getType() != resultType)
      folded = rewriter.create<tensor::CastOp>(loc, resultType, folded);

    rewriter.replaceOp(padTensorOp, folded);
    return success();
  }

private:
  ControlFoldFn controlFn;
};
208+
48209
} // namespace
49210

50211
void mlir::tensor::populateRewriteAsConstantPatterns(
51-
RewritePatternSet &patterns) {
212+
RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
52213
patterns.add<GenerateToConstant>(patterns.getContext());
214+
215+
patterns.add<PadOpToConstant>(patterns.getContext(), controlFn);
53216
}

mlir/lib/Dialect/Utils/IndexingUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
9292
assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
9393
"basis must be nonnegative");
9494
if (basis.empty())
95-
return 0;
95+
return 1;
9696
return std::accumulate(basis.begin(), basis.end(), 1,
9797
std::multiplies<int64_t>());
9898
}

mlir/test/Dialect/Tensor/rewrite-as-constant.mlir

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,138 @@ func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
2121
} : tensor<2x3x5xf32>
2222
return %0 : tensor<2x3x5xf32>
2323
}
24+
25+
// CHECK-LABEL: func @pad_of_ints(
26+
// CHECK: %[[cst:.*]] = arith.constant dense<[
27+
// CHECK-SAME{LITERAL}: [0, 0, 0, 0],
28+
// CHECK-SAME{LITERAL}: [0, 6, 7, 0],
29+
// CHECK-SAME{LITERAL}: [0, 8, 9, 0],
30+
// CHECK-SAME{LITERAL}: [0, 0, 0, 0]
31+
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xi32>
32+
// CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
33+
// CHECK: return %[[cast]]
34+
func.func @pad_of_ints() -> tensor<?x?xi32> {
35+
%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
36+
%pad_value = arith.constant 0 : i32
37+
38+
%c1 = arith.constant 1 : index
39+
40+
%0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
41+
^bb0(%arg1: index, %arg2: index):
42+
tensor.yield %pad_value : i32
43+
} : tensor<2x2xi32> to tensor<?x?xi32>
44+
45+
return %0 : tensor<?x?xi32>
46+
}
47+
48+
// CHECK-LABEL: func @pad_of_floats(
49+
// CHECK: %[[cst:.*]] = arith.constant dense<[
50+
// CHECK-SAME{LITERAL}: [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
51+
// CHECK-SAME{LITERAL}: [0.000000e+00, 6.000000e+00, 7.000000e+00, 0.000000e+00],
52+
// CHECK-SAME{LITERAL}: [0.000000e+00, 8.000000e+00, 9.000000e+00, 0.000000e+00],
53+
// CHECK-SAME{LITERAL}: [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
54+
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xf32>
55+
// CHECK: return %[[cst]]
56+
57+
func.func @pad_of_floats() -> tensor<4x4xf32> {
58+
%init = arith.constant dense<[[6.0, 7.0], [8.0, 9.0]]> : tensor<2x2xf32>
59+
%pad_value = arith.constant 0.0 : f32
60+
61+
%0 = tensor.pad %init low[1, 1] high[1, 1] {
62+
^bb0(%arg1: index, %arg2: index):
63+
tensor.yield %pad_value : f32
64+
} : tensor<2x2xf32> to tensor<4x4xf32>
65+
66+
return %0 : tensor<4x4xf32>
67+
}
68+
69+
// CHECK-LABEL: func @pad_of_ints_no_low_dims(
70+
// CHECK: %[[cst:.*]] = arith.constant dense<[
71+
// CHECK-SAME{LITERAL}: [6, 7, 0],
72+
// CHECK-SAME{LITERAL}: [8, 9, 0],
73+
// CHECK-SAME{LITERAL}: [0, 0, 0]
74+
// CHECK-SAME{LITERAL}: ]> : tensor<3x3xi32>
75+
// CHECK: return %[[cst]]
76+
func.func @pad_of_ints_no_low_dims() -> tensor<3x3xi32> {
77+
%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
78+
%pad_value = arith.constant 0 : i32
79+
80+
%0 = tensor.pad %init low[0, 0] high[1, 1] {
81+
^bb0(%arg1: index, %arg2: index):
82+
tensor.yield %pad_value : i32
83+
} : tensor<2x2xi32> to tensor<3x3xi32>
84+
85+
return %0 : tensor<3x3xi32>
86+
}
87+
88+
// CHECK-LABEL: func @pad_of_ints_no_high_dims(
89+
// CHECK: %[[cst:.*]] = arith.constant dense<[
90+
// CHECK-SAME{LITERAL}: [0, 0, 0],
91+
// CHECK-SAME{LITERAL}: [0, 6, 7],
92+
// CHECK-SAME{LITERAL}: [0, 8, 9]
93+
// CHECK-SAME{LITERAL}: ]> : tensor<3x3xi32>
94+
// CHECK: return %[[cst]]
95+
func.func @pad_of_ints_no_high_dims() -> tensor<3x3xi32> {
96+
%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
97+
%pad_value = arith.constant 0 : i32
98+
99+
%0 = tensor.pad %init low[1, 1] high[0, 0] {
100+
^bb0(%arg1: index, %arg2: index):
101+
tensor.yield %pad_value : i32
102+
} : tensor<2x2xi32> to tensor<3x3xi32>
103+
104+
return %0 : tensor<3x3xi32>
105+
}
106+
107+
// CHECK-LABEL: func @pad_multi_use_do_not_fold(
108+
// CHECK: %[[pad:.+]] = tensor.pad
109+
// CHECK: return %[[pad]]
110+
func.func @pad_multi_use_do_not_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
111+
%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
112+
%pad_value = arith.constant 0 : i32
113+
114+
%c1 = arith.constant 1 : index
115+
116+
%0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
117+
^bb0(%arg1: index, %arg2: index):
118+
tensor.yield %pad_value : i32
119+
} : tensor<2x2xi32> to tensor<?x?xi32>
120+
121+
return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
122+
}
123+
124+
// -----
125+
126+
module attributes {transform.with_named_sequence} {
127+
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
128+
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
129+
transform.apply_patterns to %func_op {
130+
transform.apply_patterns.tensor.rewrite_as_constant aggressive
131+
} : !transform.op<"func.func">
132+
transform.yield
133+
}
134+
}
135+
136+
// CHECK-LABEL: func @pad_aggressive_fold(
137+
// CHECK: %[[init:.*]] = arith.constant dense<7> : tensor<2x2xi32>
138+
// CHECK: %[[cst:.*]] = arith.constant dense<[
139+
// CHECK-SAME{LITERAL}: [0, 0, 0, 0],
140+
// CHECK-SAME{LITERAL}: [0, 7, 7, 0],
141+
// CHECK-SAME{LITERAL}: [0, 7, 7, 0],
142+
// CHECK-SAME{LITERAL}: [0, 0, 0, 0]
143+
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xi32>
144+
// CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
145+
// CHECK: return %[[cast]]
146+
func.func @pad_aggressive_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
147+
%init = arith.constant dense<7> : tensor<2x2xi32>
148+
%pad_value = arith.constant 0 : i32
149+
150+
%c1 = arith.constant 1 : index
151+
152+
%0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
153+
^bb0(%arg1: index, %arg2: index):
154+
tensor.yield %pad_value : i32
155+
} : tensor<2x2xi32> to tensor<?x?xi32>
156+
157+
return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
158+
}

0 commit comments

Comments
 (0)