|
21 | 21 |
|
22 | 22 | #include "mlir/IR/Operation.h"
|
23 | 23 | #include "llvm/Support/PointerLikeTypeTraits.h"
|
| 24 | + |
24 | 25 | #include <type_traits>
|
25 | 26 |
|
26 | 27 | namespace mlir {
|
@@ -277,7 +278,16 @@ class FoldingHook {
|
277 | 278 | /// AbstractOperation.
|
278 | 279 | static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
|
279 | 280 | SmallVectorImpl<OpFoldResult> &results) {
|
280 |
| - return cast<ConcreteType>(op).fold(operands, results); |
| 281 | + auto operationFoldResult = cast<ConcreteType>(op).fold(operands, results); |
| 282 | + // Failure to fold or in place fold both mean we can continue folding. |
| 283 | + if (failed(operationFoldResult) || results.empty()) { |
| 284 | + auto traitFoldResult = ConcreteType::foldTraits(op, operands, results); |
| 285 | + // Only return the trait fold result if it is a success since |
| 286 | + // operationFoldResult might have been a success originally. |
| 287 | + if (succeeded(traitFoldResult)) |
| 288 | + return traitFoldResult; |
| 289 | + } |
| 290 | + return operationFoldResult; |
281 | 291 | }
|
282 | 292 |
|
283 | 293 | /// This hook implements a generalized folder for this operation. Operations
|
@@ -326,6 +336,14 @@ class FoldingHook<ConcreteType, isSingleResult,
|
326 | 336 | static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
|
327 | 337 | SmallVectorImpl<OpFoldResult> &results) {
|
328 | 338 | auto result = cast<ConcreteType>(op).fold(operands);
|
| 339 | + // Failure to fold or in place fold both mean we can continue folding. |
| 340 | + if (!result || result.template dyn_cast<Value>() == op->getResult(0)) { |
| 341 | + // Only consider the trait fold result if it is a success since |
| 342 | + // the operation fold might have been a success originally. |
| 343 | + if (auto traitFoldResult = ConcreteType::foldTraits(op, operands)) |
| 344 | + result = traitFoldResult; |
| 345 | + } |
| 346 | + |
329 | 347 | if (!result)
|
330 | 348 | return failure();
|
331 | 349 |
|
@@ -370,9 +388,11 @@ namespace OpTrait {
|
370 | 388 | // corresponding trait classes. This avoids them being template
|
371 | 389 | // instantiated/duplicated.
|
372 | 390 | namespace impl {
|
| 391 | +OpFoldResult foldInvolution(Operation *op); |
373 | 392 | LogicalResult verifyZeroOperands(Operation *op);
|
374 | 393 | LogicalResult verifyOneOperand(Operation *op);
|
375 | 394 | LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
|
| 395 | +LogicalResult verifyIsInvolution(Operation *op); |
376 | 396 | LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
|
377 | 397 | LogicalResult verifyOperandsAreFloatLike(Operation *op);
|
378 | 398 | LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
|
@@ -426,6 +446,23 @@ class TraitBase {
|
426 | 446 | static AbstractOperation::OperationProperties getTraitProperties() {
|
427 | 447 | return 0;
|
428 | 448 | }
|
| 449 | + |
| 450 | + static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { |
| 451 | + SmallVector<OpFoldResult, 1> results; |
| 452 | + if (failed(foldTrait(op, operands, results))) |
| 453 | + return {}; |
| 454 | + if (results.empty()) |
| 455 | + return op->getResult(0); |
| 456 | + assert(results.size() == 1 && |
| 457 | + "Single result op cannot return multiple fold results"); |
| 458 | + |
| 459 | + return results[0]; |
| 460 | + } |
| 461 | + |
| 462 | + static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands, |
| 463 | + SmallVectorImpl<OpFoldResult> &results) { |
| 464 | + return failure(); |
| 465 | + } |
429 | 466 | };
|
430 | 467 |
|
431 | 468 | //===----------------------------------------------------------------------===//
|
@@ -974,6 +1011,26 @@ class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
|
974 | 1011 | }
|
975 | 1012 | };
|
976 | 1013 |
|
| 1014 | +/// This class adds property that the operation is an involution. |
| 1015 | +/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x) |
| 1016 | +template <typename ConcreteType> |
| 1017 | +class IsInvolution : public TraitBase<ConcreteType, IsInvolution> { |
| 1018 | +public: |
| 1019 | + static LogicalResult verifyTrait(Operation *op) { |
| 1020 | + static_assert(ConcreteType::template hasTrait<OneResult>(), |
| 1021 | + "expected operation to produce one result"); |
| 1022 | + static_assert(ConcreteType::template hasTrait<OneOperand>(), |
| 1023 | + "expected operation to take one operand"); |
| 1024 | + static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), |
| 1025 | + "expected operation to preserve type"); |
| 1026 | + return impl::verifyIsInvolution(op); |
| 1027 | + } |
| 1028 | + |
| 1029 | + static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { |
| 1030 | + return impl::foldInvolution(op); |
| 1031 | + } |
| 1032 | +}; |
| 1033 | + |
977 | 1034 | /// This class verifies that all operands of the specified op have a float type,
|
978 | 1035 | /// a vector thereof, or a tensor thereof.
|
979 | 1036 | template <typename ConcreteType>
|
@@ -1306,6 +1363,19 @@ class Op : public OpState,
|
1306 | 1363 | failed(cast<ConcreteType>(op).verify()));
|
1307 | 1364 | }
|
1308 | 1365 |
|
| 1366 | + /// This is the hook that tries to fold the given operation according to its |
| 1367 | + /// traits. It delegates to the Traits for their policy implementations, and |
| 1368 | + /// allows the user to specify their own fold() method. |
| 1369 | + static OpFoldResult foldTraits(Operation *op, ArrayRef<Attribute> operands) { |
| 1370 | + return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands); |
| 1371 | + } |
| 1372 | + |
| 1373 | + static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands, |
| 1374 | + SmallVectorImpl<OpFoldResult> &results) { |
| 1375 | + return BaseFolder<Traits<ConcreteType>...>::foldTraits(op, operands, |
| 1376 | + results); |
| 1377 | + } |
| 1378 | + |
1309 | 1379 | // Returns the properties of an operation by combining the properties of the
|
1310 | 1380 | // traits of the op.
|
1311 | 1381 | static AbstractOperation::OperationProperties getOperationProperties() {
|
@@ -1358,6 +1428,53 @@ class Op : public OpState,
|
1358 | 1428 | }
|
1359 | 1429 | };
|
1360 | 1430 |
|
| 1431 | + template <typename... Types> |
| 1432 | + struct BaseFolder; |
| 1433 | + |
| 1434 | + template <typename First, typename... Rest> |
| 1435 | + struct BaseFolder<First, Rest...> { |
| 1436 | + static OpFoldResult foldTraits(Operation *op, |
| 1437 | + ArrayRef<Attribute> operands) { |
| 1438 | + auto result = First::foldTrait(op, operands); |
| 1439 | + // Failure to fold or in place fold both mean we can continue folding. |
| 1440 | + if (!result || result.template dyn_cast<Value>() == op->getResult(0)) { |
| 1441 | + // Only consider the trait fold result if it is a success since |
| 1442 | + // the operation fold might have been a success originally. |
| 1443 | + auto resultRemaining = BaseFolder<Rest...>::foldTraits(op, operands); |
| 1444 | + if (resultRemaining) |
| 1445 | + result = resultRemaining; |
| 1446 | + } |
| 1447 | + |
| 1448 | + return result; |
| 1449 | + } |
| 1450 | + |
| 1451 | + static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands, |
| 1452 | + SmallVectorImpl<OpFoldResult> &results) { |
| 1453 | + auto result = First::foldTrait(op, operands, results); |
| 1454 | + // Failure to fold or in place fold both mean we can continue folding. |
| 1455 | + if (failed(result) || results.empty()) { |
| 1456 | + auto resultRemaining = |
| 1457 | + BaseFolder<Rest...>::foldTraits(op, operands, results); |
| 1458 | + if (succeeded(resultRemaining)) |
| 1459 | + result = resultRemaining; |
| 1460 | + } |
| 1461 | + |
| 1462 | + return result; |
| 1463 | + } |
| 1464 | + }; |
| 1465 | + |
| 1466 | + template <typename...> |
| 1467 | + struct BaseFolder { |
| 1468 | + static OpFoldResult foldTraits(Operation *op, |
| 1469 | + ArrayRef<Attribute> operands) { |
| 1470 | + return {}; |
| 1471 | + } |
| 1472 | + static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands, |
| 1473 | + SmallVectorImpl<OpFoldResult> &results) { |
| 1474 | + return failure(); |
| 1475 | + } |
| 1476 | + }; |
| 1477 | + |
1361 | 1478 | template <typename...> struct BaseProperties {
|
1362 | 1479 | static AbstractOperation::OperationProperties getTraitProperties() {
|
1363 | 1480 | return 0;
|
|
0 commit comments