|
33 | 33 | #include "mlir/IR/BuiltinOps.h"
|
34 | 34 | #include "mlir/IR/IRMapping.h"
|
35 | 35 | #include "mlir/IR/PatternMatch.h"
|
| 36 | +#include "mlir/IR/SymbolTable.h" |
36 | 37 | #include "mlir/IR/TypeUtilities.h"
|
37 | 38 | #include "mlir/Support/LogicalResult.h"
|
38 | 39 | #include "mlir/Support/MathExtras.h"
|
|
48 | 49 | #include "llvm/Support/FormatVariadic.h"
|
49 | 50 | #include <algorithm>
|
50 | 51 | #include <functional>
|
| 52 | +#include <optional> |
51 | 53 |
|
52 | 54 | namespace mlir {
|
53 | 55 | #define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
|
@@ -601,19 +603,38 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
|
601 | 603 | }
|
602 | 604 | };
|
603 | 605 |
|
604 |
| -struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> { |
605 |
| - using Super::Super; |
| 606 | +class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> { |
| 607 | +public: |
| 608 | + CallOpLowering(const LLVMTypeConverter &typeConverter, |
| 609 | + // Can be nullptr. |
| 610 | + const SymbolTable *symbolTable, PatternBenefit benefit = 1) |
| 611 | + : CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit), |
| 612 | + symbolTable(symbolTable) {} |
606 | 613 |
|
607 | 614 | LogicalResult
|
608 | 615 | matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
|
609 | 616 | ConversionPatternRewriter &rewriter) const override {
|
610 | 617 | bool useBarePtrCallConv = false;
|
611 |
| - if (Operation *callee = SymbolTable::lookupNearestSymbolFrom( |
612 |
| - callOp, callOp.getCalleeAttr())) { |
613 |
| - useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter()); |
| 618 | + if (getTypeConverter()->getOptions().useBarePtrCallConv) { |
| 619 | + useBarePtrCallConv = true; |
| 620 | + } else if (symbolTable != nullptr) { |
| 621 | + // Fast lookup. |
| 622 | + Operation *callee = |
| 623 | + symbolTable->lookup(callOp.getCalleeAttr().getValue()); |
| 624 | + useBarePtrCallConv = |
| 625 | + callee != nullptr && callee->hasAttr(barePtrAttrName); |
| 626 | + } else { |
| 627 | + // Warning: This is a linear lookup. |
| 628 | + Operation *callee = |
| 629 | + SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr()); |
| 630 | + useBarePtrCallConv = |
| 631 | + callee != nullptr && callee->hasAttr(barePtrAttrName); |
614 | 632 | }
|
615 | 633 | return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
|
616 | 634 | }
|
| 635 | + |
| 636 | +private: |
| 637 | + const SymbolTable *symbolTable = nullptr; |
617 | 638 | };
|
618 | 639 |
|
619 | 640 | struct CallIndirectOpLowering
|
@@ -728,16 +749,14 @@ void mlir::populateFuncToLLVMFuncOpConversionPattern(
|
728 | 749 | patterns.add<FuncOpConversion>(converter);
|
729 | 750 | }
|
730 | 751 |
|
731 |
| -void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter, |
732 |
| - RewritePatternSet &patterns) { |
| 752 | +void mlir::populateFuncToLLVMConversionPatterns( |
| 753 | + LLVMTypeConverter &converter, RewritePatternSet &patterns, |
| 754 | + const SymbolTable *symbolTable) { |
733 | 755 | populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
|
734 |
| - // clang-format off |
735 |
| - patterns.add< |
736 |
| - CallIndirectOpLowering, |
737 |
| - CallOpLowering, |
738 |
| - ConstantOpLowering, |
739 |
| - ReturnOpLowering>(converter); |
740 |
| - // clang-format on |
| 756 | + patterns.add<CallIndirectOpLowering>(converter); |
| 757 | + patterns.add<CallOpLowering>(converter, symbolTable); |
| 758 | + patterns.add<ConstantOpLowering>(converter); |
| 759 | + patterns.add<ReturnOpLowering>(converter); |
741 | 760 | }
|
742 | 761 |
|
743 | 762 | namespace {
|
@@ -776,8 +795,15 @@ struct ConvertFuncToLLVMPass
|
776 | 795 | LLVMTypeConverter typeConverter(&getContext(), options,
|
777 | 796 | &dataLayoutAnalysis);
|
778 | 797 |
|
| 798 | + std::optional<SymbolTable> optSymbolTable = std::nullopt; |
| 799 | + const SymbolTable *symbolTable = nullptr; |
| 800 | + if (!options.useBarePtrCallConv) { |
| 801 | + optSymbolTable.emplace(m); |
| 802 | + symbolTable = &optSymbolTable.value(); |
| 803 | + } |
| 804 | + |
779 | 805 | RewritePatternSet patterns(&getContext());
|
780 |
| - populateFuncToLLVMConversionPatterns(typeConverter, patterns); |
| 806 | + populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable); |
781 | 807 |
|
782 | 808 | // TODO: Remove these in favor of their dedicated conversion passes.
|
783 | 809 | arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
|
|
0 commit comments