Skip to content

[mlir][tensor] Implement constant folder for tensor.pad #92691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,14 @@ def ApplyReassociativeReshapeFoldingPatternsOp : Op<Transform_Dialect,
def ApplyRewriteTensorOpsAsConstantPatternsOp : Op<Transform_Dialect,
"apply_patterns.tensor.rewrite_as_constant",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let arguments = (ins UnitAttr:$aggressive);
let description = [{
Indicates that tensor ops (such as tensor.generate) should be replaced with
constants (arith.constant) when possible.
}];

let assemblyFormat = "attr-dict";
let assemblyFormat =
"(`aggressive` $aggressive^)? attr-dict";
}

def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">;
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,12 @@ void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
/// respectively.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns);

using ControlFoldFn = std::function<bool(OpOperand *)>;

/// Populates `patterns` with patterns that replace tensor ops (such as
/// tensor.generate) with constants when possible.
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns);
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
const ControlFoldFn &controlFn);

//===----------------------------------------------------------------------===//
// Transform helpers
Expand Down
15 changes: 14 additions & 1 deletion mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,20 @@ void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(

void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
tensor::populateRewriteAsConstantPatterns(patterns);
ControlFoldFn defaultControlFn = [](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
return producer && producer->hasOneUse();
};

ControlFoldFn aggressiveControlFn = [](OpOperand *fusedOperand) {
return true;
};

// Add folding with reshape by expansion patterns.
if (getAggressive())
tensor::populateRewriteAsConstantPatterns(patterns, aggressiveControlFn);
else
tensor::populateRewriteAsConstantPatterns(patterns, defaultControlFn);
}

//===----------------------------------------------------------------------===//
Expand Down
165 changes: 164 additions & 1 deletion mlir/lib/Dialect/Tensor/Transforms/RewriteAsConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::tensor;

Expand Down Expand Up @@ -45,9 +48,169 @@ struct GenerateToConstant : public OpRewritePattern<GenerateOp> {
}
};

/// Transform a linear index from one indexing space to another given:
///
/// - the shape of the source indexing space,
/// - the strides of the target indexing space,
/// - a linear index into the source indexing space.
///
/// This function is logically a sequence of linearize/delinearize over
/// different bases but avoids allocating intermediate SmallVectors.
int64_t transformIndexSpace(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> outputStrides,
int64_t srcLinearIndex) {
assert(inputShape.size() == outputStrides.size());

int64_t dstLinearIndex = 0;

for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
// Compute the index into the current dimension of the source tensor.
// `quotient` is the remaining linear index after accounting for the
// current dimension.
//
// `remainder` is the index into the source tensor for the current
// dimension.
auto [quotient, remainder] = std::div(srcLinearIndex, inputShape[dim]);

srcLinearIndex = quotient;

// Add the contribution of the current dimension to the output using the
// permutation map.
dstLinearIndex += outputStrides[dim] * remainder;
}

return dstLinearIndex;
}

template <typename ElemType, typename AttrType>
Value constantFoldPadOp(PatternRewriter &rewriter, Location loc,
DenseElementsAttr input, AttrType padValue,
ArrayRef<int64_t> padLow, ArrayRef<int64_t> padHigh) {
auto inputValues = input.tryGetValues<ElemType>();
if (failed(inputValues))
return nullptr;

auto oldShape = input.getType().getShape();

// Compute the output shape of the new value.
auto newShape =
llvm::map_to_vector(llvm::zip(oldShape, padLow, padHigh),
[](std::tuple<int64_t, int64_t, int64_t> pack) {
auto [old, low, high] = pack;
return old + low + high;
});

int64_t outputSize = computeProduct(newShape);

// Fully initialize the vector with the padding value.
// The non-padded area will then be copied.
SmallVector<ElemType> values(outputSize, padValue.getValue());

// Strides for input and output are used to transform between the indexing
// space of the input and output tensors.
SmallVector<int64_t> outputStrides = computeStrides(newShape);

// The contribution of the low padding to the offset in the output tensor.
// This is the starting position of the source tensor within the padding
// tensor.
int64_t startingOffset = linearize(padLow, outputStrides);

// Copy values from the input tensor to the corresponding sub-region
// of the output tensor.
for (auto [inputIndex, inputValue] : llvm::enumerate(*inputValues)) {
auto outputIndex = transformIndexSpace(oldShape, outputStrides, inputIndex);
values[outputIndex + startingOffset] = inputValue;
}

// Create an attribute for the folded value.
auto newType = input.getType().clone(newShape);
auto newAttr = DenseElementsAttr::get(newType, values);

Operation *constantOp =
rewriter.getContext()
->getLoadedDialect<TensorDialect>()
->materializeConstant(rewriter, newAttr, newType, loc);

return constantOp ? constantOp->getResult(0) : nullptr;
}

struct PadOpToConstant final : public OpRewritePattern<PadOp> {

PadOpToConstant(MLIRContext *context, const ControlFoldFn &controlFn,
PatternBenefit benefit = 1)
: OpRewritePattern<PadOp>(context, benefit), controlFn{controlFn} {}

LogicalResult matchAndRewrite(PadOp padTensorOp,
PatternRewriter &rewriter) const override {
if (padTensorOp.getNofold())
return rewriter.notifyMatchFailure(
padTensorOp, "refusing to fold nofold pad operation");

TypedValue<RankedTensorType> input = padTensorOp.getSource();
RankedTensorType resultType = padTensorOp.getResult().getType();

DenseElementsAttr inputAttr = nullptr;
if (!matchPattern(input, m_Constant(&inputAttr)))
return failure();

Value paddingValue = padTensorOp.getConstantPaddingValue();

// Extract the constant value used for padding or bail out.
Attribute paddingAttr = nullptr;
if (!paddingValue || !matchPattern(paddingValue, m_Constant(&paddingAttr)))
return rewriter.notifyMatchFailure(padTensorOp,
"unable to get constant value");

// Try to extract the constant values of the low and high padding.
auto lowPad = getConstantIntValues(padTensorOp.getMixedLowPad());
auto highPad = getConstantIntValues(padTensorOp.getMixedHighPad());

// If the padding cannot be extracted, bail out.
if (!lowPad || !highPad)
return rewriter.notifyMatchFailure(padTensorOp,
"unable to extract constant padding");

// We have a potential candidate, consult the control function to
// determine if the op should fold.
if (!controlFn(&padTensorOp.getSourceMutable()))
return rewriter.notifyMatchFailure(padTensorOp,
"not folding due to cost function");

Location loc = padTensorOp.getLoc();

// Try constant folding the supported cases of integer and float values.
Value newOp =
llvm::TypeSwitch<Attribute, Value>(paddingAttr)
.Case([&](FloatAttr floatAttr) {
return constantFoldPadOp<llvm::APFloat>(
rewriter, loc, inputAttr, floatAttr, *lowPad, *highPad);
})
.Case([&](IntegerAttr integerAttr) {
return constantFoldPadOp<llvm::APInt>(
rewriter, loc, inputAttr, integerAttr, *lowPad, *highPad);
})
.Default(Value());

if (!newOp)
return rewriter.notifyMatchFailure(padTensorOp,
"tensor type not supported");

if (newOp.getType() != resultType)
newOp = rewriter.create<tensor::CastOp>(loc, resultType, newOp);

rewriter.replaceOp(padTensorOp, newOp);
return success();
}

private:
ControlFoldFn controlFn;
};

} // namespace

void mlir::tensor::populateRewriteAsConstantPatterns(
RewritePatternSet &patterns) {
RewritePatternSet &patterns, const ControlFoldFn &controlFn) {
patterns.add<GenerateToConstant>(patterns.getContext());

patterns.add<PadOpToConstant>(patterns.getContext(), controlFn);
}
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Utils/IndexingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ int64_t mlir::computeProduct(ArrayRef<int64_t> basis) {
assert(llvm::all_of(basis, [](int64_t s) { return s > 0; }) &&
"basis must be nonnegative");
if (basis.empty())
return 0;
return 1;
return std::accumulate(basis.begin(), basis.end(), 1,
std::multiplies<int64_t>());
}
Expand Down
135 changes: 135 additions & 0 deletions mlir/test/Dialect/Tensor/rewrite-as-constant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,138 @@ func.func @tensor_generate_constant() -> tensor<2x3x5xf32> {
} : tensor<2x3x5xf32>
return %0 : tensor<2x3x5xf32>
}

// CHECK-LABEL: func @pad_of_ints(
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [0, 0, 0, 0],
// CHECK-SAME{LITERAL}: [0, 6, 7, 0],
// CHECK-SAME{LITERAL}: [0, 8, 9, 0],
// CHECK-SAME{LITERAL}: [0, 0, 0, 0]
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xi32>
// CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
// CHECK: return %[[cast]]
func.func @pad_of_ints() -> tensor<?x?xi32> {
%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
%pad_value = arith.constant 0 : i32

%c1 = arith.constant 1 : index

%0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : i32
} : tensor<2x2xi32> to tensor<?x?xi32>

return %0 : tensor<?x?xi32>
}

// CHECK-LABEL: func @pad_of_floats(
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}: [0.000000e+00, 6.000000e+00, 7.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}: [0.000000e+00, 8.000000e+00, 9.000000e+00, 0.000000e+00],
// CHECK-SAME{LITERAL}: [0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00]
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xf32>
// CHECK: return %[[cst]]

func.func @pad_of_floats() -> tensor<4x4xf32> {
%init = arith.constant dense<[[6.0, 7.0], [8.0, 9.0]]> : tensor<2x2xf32>
%pad_value = arith.constant 0.0 : f32

%0 = tensor.pad %init low[1, 1] high[1, 1] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : f32
} : tensor<2x2xf32> to tensor<4x4xf32>

return %0 : tensor<4x4xf32>
}

// CHECK-LABEL: func @pad_of_ints_no_low_dims(
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [6, 7, 0],
// CHECK-SAME{LITERAL}: [8, 9, 0],
// CHECK-SAME{LITERAL}: [0, 0, 0]
// CHECK-SAME{LITERAL}: ]> : tensor<3x3xi32>
// CHECK: return %[[cst]]
func.func @pad_of_ints_no_low_dims() -> tensor<3x3xi32> {
%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
%pad_value = arith.constant 0 : i32

%0 = tensor.pad %init low[0, 0] high[1, 1] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : i32
} : tensor<2x2xi32> to tensor<3x3xi32>

return %0 : tensor<3x3xi32>
}

// CHECK-LABEL: func @pad_of_ints_no_high_dims(
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [0, 0, 0],
// CHECK-SAME{LITERAL}: [0, 6, 7],
// CHECK-SAME{LITERAL}: [0, 8, 9]
// CHECK-SAME{LITERAL}: ]> : tensor<3x3xi32>
// CHECK: return %[[cst]]
func.func @pad_of_ints_no_high_dims() -> tensor<3x3xi32> {
%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
%pad_value = arith.constant 0 : i32

%0 = tensor.pad %init low[1, 1] high[0, 0] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : i32
} : tensor<2x2xi32> to tensor<3x3xi32>

return %0 : tensor<3x3xi32>
}

// CHECK-LABEL: func @pad_multi_use_do_not_fold(
// CHECK: %[[pad:.+]] = tensor.pad
// CHECK: return %[[pad]]
func.func @pad_multi_use_do_not_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
%init = arith.constant dense<[[6, 7], [8, 9]]> : tensor<2x2xi32>
%pad_value = arith.constant 0 : i32

%c1 = arith.constant 1 : index

%0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : i32
} : tensor<2x2xi32> to tensor<?x?xi32>

return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
}

// -----

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.tensor.rewrite_as_constant aggressive
} : !transform.op<"func.func">
transform.yield
}
}

// CHECK-LABEL: func @pad_aggressive_fold(
// CHECK: %[[init:.*]] = arith.constant dense<7> : tensor<2x2xi32>
// CHECK: %[[cst:.*]] = arith.constant dense<[
// CHECK-SAME{LITERAL}: [0, 0, 0, 0],
// CHECK-SAME{LITERAL}: [0, 7, 7, 0],
// CHECK-SAME{LITERAL}: [0, 7, 7, 0],
// CHECK-SAME{LITERAL}: [0, 0, 0, 0]
// CHECK-SAME{LITERAL}: ]> : tensor<4x4xi32>
// CHECK: %[[cast:.*]] = tensor.cast %[[cst]] : tensor<4x4xi32> to tensor<?x?xi32>
// CHECK: return %[[cast]]
func.func @pad_aggressive_fold() -> (tensor<?x?xi32>, tensor<2x2xi32>) {
%init = arith.constant dense<7> : tensor<2x2xi32>
%pad_value = arith.constant 0 : i32

%c1 = arith.constant 1 : index

%0 = tensor.pad %init low[%c1, %c1] high[%c1, %c1] {
^bb0(%arg1: index, %arg2: index):
tensor.yield %pad_value : i32
} : tensor<2x2xi32> to tensor<?x?xi32>

return %0, %init : tensor<?x?xi32>, tensor<2x2xi32>
}
Loading