Skip to content

Commit 1ceaffd

Browse files
ahmedsabiejoker-eph
authored andcommitted
[MLIR] Add a foldTrait() mechanism to allow traits to define folding and test it with an Involution trait
This change allows folds to be done on a newly introduced involution trait rather than having to manually rewrite this optimization for every instance of an involution Reviewed By: rriddle, andyly, stephenneuendorffer Differential Revision: https://reviews.llvm.org/D88809
1 parent 26cfb6e commit 1ceaffd

File tree

9 files changed

+311
-1
lines changed

9 files changed

+311
-1
lines changed

mlir/docs/Traits.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,47 @@ Note: It is generally good practice to define the implementation of the
5656
`verifyTrait` hook out-of-line as a free function when possible to avoid
5757
instantiating the implementation for every concrete operation type.
5858

59+
Operation traits may also provide a `foldTrait` hook that is called when
60+
folding the concrete operation. The trait folders will only be invoked if
61+
the concrete operation fold is either not implemented, fails, or performs
62+
an in-place fold.
63+
64+
The following signature of fold will be called if it is implemented
65+
and the op has a single result.
66+
67+
```c++
68+
template <typename ConcreteType>
69+
class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> {
70+
public:
71+
/// Override the 'foldTrait' hook to support trait based folding on the
72+
/// concrete operation.
73+
static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { {
74+
// ...
75+
}
76+
};
77+
```
78+
79+
Otherwise, if the operation has a single result and the above signature is
80+
not implemented, or the operation has multiple results, then the following signature
81+
will be used (if implemented):
82+
83+
```c++
84+
template <typename ConcreteType>
85+
class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> {
86+
public:
87+
/// Override the 'foldTrait' hook to support trait based folding on the
88+
/// concrete operation.
89+
static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
90+
SmallVectorImpl<OpFoldResult> &results) { {
91+
// ...
92+
}
93+
};
94+
```
95+
96+
Note: It is generally good practice to define the implementation of the
97+
`foldTrait` hook out-of-line as a free function when possible to avoid
98+
instantiating the implementation for every concrete operation type.
99+
59100
### Parametric Traits
60101

61102
The above demonstrates the definition of a simple self-contained trait. It is

mlir/include/mlir/IR/OpBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,6 +1723,8 @@ def ResultsBroadcastableShape :
17231723
NativeOpTrait<"ResultsBroadcastableShape">;
17241724
// X op Y == Y op X
17251725
def Commutative : NativeOpTrait<"IsCommutative">;
1726+
// op op X == X
1727+
def Involution : NativeOpTrait<"IsInvolution">;
17261728
// Op behaves like a constant.
17271729
def ConstantLike : NativeOpTrait<"ConstantLike">;
17281730
// Op behaves like a function.

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "mlir/IR/Operation.h"
2323
#include "llvm/Support/PointerLikeTypeTraits.h"
24+
2425
#include <type_traits>
2526

2627
namespace mlir {
@@ -277,7 +278,16 @@ class FoldingHook {
277278
/// AbstractOperation.
278279
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
279280
SmallVectorImpl<OpFoldResult> &results) {
280-
return cast<ConcreteType>(op).fold(operands, results);
281+
auto operationFoldResult = cast<ConcreteType>(op).fold(operands, results);
282+
// Failure to fold or in place fold both mean we can continue folding.
283+
if (failed(operationFoldResult) || results.empty()) {
284+
auto traitFoldResult = ConcreteType::foldTraits(op, operands, results);
285+
// Only return the trait fold result if it is a success since
286+
// operationFoldResult might have been a success originally.
287+
if (succeeded(traitFoldResult))
288+
return traitFoldResult;
289+
}
290+
return operationFoldResult;
281291
}
282292

283293
/// This hook implements a generalized folder for this operation. Operations
@@ -326,6 +336,14 @@ class FoldingHook<ConcreteType, isSingleResult,
326336
static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
327337
SmallVectorImpl<OpFoldResult> &results) {
328338
auto result = cast<ConcreteType>(op).fold(operands);
339+
// Failure to fold or in place fold both mean we can continue folding.
340+
if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
341+
// Only consider the trait fold result if it is a success since
342+
// the operation fold might have been a success originally.
343+
if (auto traitFoldResult = ConcreteType::foldTraits(op, operands))
344+
result = traitFoldResult;
345+
}
346+
329347
if (!result)
330348
return failure();
331349

@@ -370,9 +388,11 @@ namespace OpTrait {
370388
// corresponding trait classes. This avoids them being template
371389
// instantiated/duplicated.
372390
namespace impl {
391+
OpFoldResult foldInvolution(Operation *op);
373392
LogicalResult verifyZeroOperands(Operation *op);
374393
LogicalResult verifyOneOperand(Operation *op);
375394
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
395+
LogicalResult verifyIsInvolution(Operation *op);
376396
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
377397
LogicalResult verifyOperandsAreFloatLike(Operation *op);
378398
LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
@@ -426,6 +446,23 @@ class TraitBase {
426446
static AbstractOperation::OperationProperties getTraitProperties() {
427447
return 0;
428448
}
449+
450+
static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
451+
SmallVector<OpFoldResult, 1> results;
452+
if (failed(foldTrait(op, operands, results)))
453+
return {};
454+
if (results.empty())
455+
return op->getResult(0);
456+
assert(results.size() == 1 &&
457+
"Single result op cannot return multiple fold results");
458+
459+
return results[0];
460+
}
461+
462+
static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
463+
SmallVectorImpl<OpFoldResult> &results) {
464+
return failure();
465+
}
429466
};
430467

431468
//===----------------------------------------------------------------------===//
@@ -974,6 +1011,26 @@ class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
9741011
}
9751012
};
9761013

1014+
/// This class adds property that the operation is an involution.
1015+
/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x)
1016+
template <typename ConcreteType>
1017+
class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1018+
public:
1019+
static LogicalResult verifyTrait(Operation *op) {
1020+
static_assert(ConcreteType::template hasTrait<OneResult>(),
1021+
"expected operation to produce one result");
1022+
static_assert(ConcreteType::template hasTrait<OneOperand>(),
1023+
"expected operation to take one operand");
1024+
static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1025+
"expected operation to preserve type");
1026+
return impl::verifyIsInvolution(op);
1027+
}
1028+
1029+
static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1030+
return impl::foldInvolution(op);
1031+
}
1032+
};
1033+
9771034
/// This class verifies that all operands of the specified op have a float type,
9781035
/// a vector thereof, or a tensor thereof.
9791036
template <typename ConcreteType>
@@ -1306,6 +1363,19 @@ class Op : public OpState,
13061363
failed(cast<ConcreteType>(op).verify()));
13071364
}
13081365

1366+
/// This is the hook that tries to fold the given operation according to its
1367+
/// traits. It delegates to the Traits for their policy implementations, and
1368+
/// allows the user to specify their own fold() method.
1369+
static OpFoldResult foldTraits(Operation *op, ArrayRef<Attribute> operands) {
1370+
return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands);
1371+
}
1372+
1373+
static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
1374+
SmallVectorImpl<OpFoldResult> &results) {
1375+
return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands,
1376+
results);
1377+
}
1378+
13091379
// Returns the properties of an operation by combining the properties of the
13101380
// traits of the op.
13111381
static AbstractOperation::OperationProperties getOperationProperties() {
@@ -1358,6 +1428,53 @@ class Op : public OpState,
13581428
}
13591429
};
13601430

1431+
template <typename... Types>
1432+
struct BaseFolder;
1433+
1434+
template <typename First, typename... Rest>
1435+
struct BaseFolder<First, Rest...> {
1436+
static OpFoldResult foldTraits(Operation *op,
1437+
ArrayRef<Attribute> operands) {
1438+
auto result = First::foldTrait(op, operands);
1439+
// Failure to fold or in place fold both mean we can continue folding.
1440+
if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1441+
// Only consider the trait fold result if it is a success since
1442+
// the operation fold might have been a success originally.
1443+
auto resultRemaining = BaseFolder<Rest...>::foldTraits(op, operands);
1444+
if (resultRemaining)
1445+
result = resultRemaining;
1446+
}
1447+
1448+
return result;
1449+
}
1450+
1451+
static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
1452+
SmallVectorImpl<OpFoldResult> &results) {
1453+
auto result = First::foldTrait(op, operands, results);
1454+
// Failure to fold or in place fold both mean we can continue folding.
1455+
if (failed(result) || results.empty()) {
1456+
auto resultRemaining =
1457+
BaseFolder<Rest...>::foldTraits(op, operands, results);
1458+
if (succeeded(resultRemaining))
1459+
result = resultRemaining;
1460+
}
1461+
1462+
return result;
1463+
}
1464+
};
1465+
1466+
template <typename...>
1467+
struct BaseFolder {
1468+
static OpFoldResult foldTraits(Operation *op,
1469+
ArrayRef<Attribute> operands) {
1470+
return {};
1471+
}
1472+
static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
1473+
SmallVectorImpl<OpFoldResult> &results) {
1474+
return failure();
1475+
}
1476+
};
1477+
13611478
template <typename...> struct BaseProperties {
13621479
static AbstractOperation::OperationProperties getTraitProperties() {
13631480
return 0;

mlir/lib/IR/Operation.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/StandardTypes.h"
1515
#include "mlir/IR/TypeUtilities.h"
1616
#include "mlir/Interfaces/FoldInterfaces.h"
17+
#include "mlir/Interfaces/SideEffectInterfaces.h"
1718
#include <numeric>
1819

1920
using namespace mlir;
@@ -679,6 +680,16 @@ InFlightDiagnostic OpState::emitRemark(const Twine &message) {
679680
// Op Trait implementations
680681
//===----------------------------------------------------------------------===//
681682

683+
OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
684+
auto *argumentOp = op->getOperand(0).getDefiningOp();
685+
if (argumentOp && op->getName() == argumentOp->getName()) {
686+
// Replace the outer involutions output with inner's input.
687+
return argumentOp->getOperand(0);
688+
}
689+
690+
return {};
691+
}
692+
682693
LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
683694
if (op->getNumOperands() != 0)
684695
return op->emitOpError() << "requires zero operands";
@@ -720,6 +731,12 @@ static Type getTensorOrVectorElementType(Type type) {
720731
return type;
721732
}
722733

734+
LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) {
735+
if (!MemoryEffectOpInterface::hasNoEffect(op))
736+
return op->emitOpError() << "requires operation to have no side effects";
737+
return success();
738+
}
739+
723740
LogicalResult
724741
OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) {
725742
for (auto opType : op->getOperandTypes()) {

mlir/test/lib/Dialect/Test/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
set(LLVM_OPTIONAL_SOURCES
22
TestDialect.cpp
33
TestPatterns.cpp
4+
TestTraits.cpp
45
)
56

67
set(LLVM_TARGET_DEFINITIONS TestInterfaces.td)
@@ -23,6 +24,7 @@ add_public_tablegen_target(MLIRTestOpsIncGen)
2324
add_mlir_library(MLIRTestDialect
2425
TestDialect.cpp
2526
TestPatterns.cpp
27+
TestTraits.cpp
2628

2729
EXCLUDE_FROM_LIBMLIR
2830

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,29 @@ def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
798798
let results = (outs I32);
799799
}
800800

801+
def TestInvolutionTraitNoOperationFolderOp
802+
: TEST_Op<"op_involution_trait_no_operation_fold",
803+
[SameOperandsAndResultType, NoSideEffect, Involution]> {
804+
let arguments = (ins I32:$op1);
805+
let results = (outs I32);
806+
}
807+
808+
def TestInvolutionTraitFailingOperationFolderOp
809+
: TEST_Op<"op_involution_trait_failing_operation_fold",
810+
[SameOperandsAndResultType, NoSideEffect, Involution]> {
811+
let arguments = (ins I32:$op1);
812+
let results = (outs I32);
813+
let hasFolder = 1;
814+
}
815+
816+
def TestInvolutionTraitSuccesfulOperationFolderOp
817+
: TEST_Op<"op_involution_trait_succesful_operation_fold",
818+
[SameOperandsAndResultType, NoSideEffect, Involution]> {
819+
let arguments = (ins I32:$op1);
820+
let results = (outs I32);
821+
let hasFolder = 1;
822+
}
823+
801824
def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
802825
let arguments = (ins I32);
803826
let results = (outs I32);
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===- TestTraits.cpp - Test trait folding --------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "TestDialect.h"
10+
#include "mlir/IR/PatternMatch.h"
11+
#include "mlir/Pass/Pass.h"
12+
#include "mlir/Transforms/FoldUtils.h"
13+
14+
using namespace mlir;
15+
16+
//===----------------------------------------------------------------------===//
17+
// Trait Folder.
18+
//===----------------------------------------------------------------------===//
19+
20+
OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
21+
ArrayRef<Attribute> operands) {
22+
// This failure should cause the trait fold to run instead.
23+
return {};
24+
}
25+
26+
OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
27+
ArrayRef<Attribute> operands) {
28+
auto argument_op = getOperand();
29+
// The success case should cause the trait fold to be supressed.
30+
return argument_op.getDefiningOp() ? argument_op : OpFoldResult{};
31+
}
32+
33+
namespace {
34+
struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
35+
void runOnFunction() override {
36+
applyPatternsAndFoldGreedily(getFunction(), {});
37+
}
38+
};
39+
} // end anonymous namespace
40+
41+
namespace mlir {
42+
void registerTestTraitsPass() {
43+
PassRegistration<TestTraitFolder>("test-trait-folder", "Run trait folding");
44+
}
45+
} // namespace mlir

0 commit comments

Comments
 (0)