Skip to content

Commit a608830

Browse files
authored
[mlir] Speed up FuncToLLVM using a SymbolTable (#68082)
We have a project where this saves 23% of the compilation time. This means using hashmaps instead of searching in linked lists.
1 parent d3e4702 commit a608830

File tree

2 files changed

+54
-17
lines changed

2 files changed

+54
-17
lines changed

mlir/include/mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h

+13-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace mlir {
1919
class DialectRegistry;
2020
class LLVMTypeConverter;
2121
class RewritePatternSet;
22+
class SymbolTable;
2223

2324
/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If
2425
/// `emitCWrappers` is set, the pattern will also produce functions
@@ -31,8 +32,18 @@ void populateFuncToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter,
3132
/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions
3233
/// by reference meaning the references have to remain alive during the entire
3334
/// pattern lifetime.
34-
void populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
35-
RewritePatternSet &patterns);
35+
///
36+
/// The `symbolTable` parameter can be used to speed up function lookups in the
37+
/// module. It's good to provide it, but only if we know that the patterns will
38+
/// be applied to a single module and the symbols referenced by the symbol table
39+
/// will not be removed and new symbols will not be added during the usage of
40+
/// the patterns. If provided, the lookups will have O(calls) cumulative
41+
/// runtime, otherwise O(calls * functions). The symbol table is currently not
42+
/// needed if `converter.getOptions().useBarePtrCallConv` is `true`, but it's
43+
/// not an error to provide it anyway.
44+
void populateFuncToLLVMConversionPatterns(
45+
LLVMTypeConverter &converter, RewritePatternSet &patterns,
46+
const SymbolTable *symbolTable = nullptr);
3647

3748
void registerConvertFuncToLLVMInterface(DialectRegistry &registry);
3849

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

+41-15
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/IR/BuiltinOps.h"
3434
#include "mlir/IR/IRMapping.h"
3535
#include "mlir/IR/PatternMatch.h"
36+
#include "mlir/IR/SymbolTable.h"
3637
#include "mlir/IR/TypeUtilities.h"
3738
#include "mlir/Support/LogicalResult.h"
3839
#include "mlir/Support/MathExtras.h"
@@ -48,6 +49,7 @@
4849
#include "llvm/Support/FormatVariadic.h"
4950
#include <algorithm>
5051
#include <functional>
52+
#include <optional>
5153

5254
namespace mlir {
5355
#define GEN_PASS_DEF_CONVERTFUNCTOLLVMPASS
@@ -601,19 +603,38 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
601603
}
602604
};
603605

604-
struct CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
605-
using Super::Super;
606+
class CallOpLowering : public CallOpInterfaceLowering<func::CallOp> {
607+
public:
608+
CallOpLowering(const LLVMTypeConverter &typeConverter,
609+
// Can be nullptr.
610+
const SymbolTable *symbolTable, PatternBenefit benefit = 1)
611+
: CallOpInterfaceLowering<func::CallOp>(typeConverter, benefit),
612+
symbolTable(symbolTable) {}
606613

607614
LogicalResult
608615
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
609616
ConversionPatternRewriter &rewriter) const override {
610617
bool useBarePtrCallConv = false;
611-
if (Operation *callee = SymbolTable::lookupNearestSymbolFrom(
612-
callOp, callOp.getCalleeAttr())) {
613-
useBarePtrCallConv = shouldUseBarePtrCallConv(callee, getTypeConverter());
618+
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
619+
useBarePtrCallConv = true;
620+
} else if (symbolTable != nullptr) {
621+
// Fast lookup.
622+
Operation *callee =
623+
symbolTable->lookup(callOp.getCalleeAttr().getValue());
624+
useBarePtrCallConv =
625+
callee != nullptr && callee->hasAttr(barePtrAttrName);
626+
} else {
627+
// Warning: This is a linear lookup.
628+
Operation *callee =
629+
SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getCalleeAttr());
630+
useBarePtrCallConv =
631+
callee != nullptr && callee->hasAttr(barePtrAttrName);
614632
}
615633
return matchAndRewriteImpl(callOp, adaptor, rewriter, useBarePtrCallConv);
616634
}
635+
636+
private:
637+
const SymbolTable *symbolTable = nullptr;
617638
};
618639

619640
struct CallIndirectOpLowering
@@ -728,16 +749,14 @@ void mlir::populateFuncToLLVMFuncOpConversionPattern(
728749
patterns.add<FuncOpConversion>(converter);
729750
}
730751

731-
void mlir::populateFuncToLLVMConversionPatterns(LLVMTypeConverter &converter,
732-
RewritePatternSet &patterns) {
752+
void mlir::populateFuncToLLVMConversionPatterns(
753+
LLVMTypeConverter &converter, RewritePatternSet &patterns,
754+
const SymbolTable *symbolTable) {
733755
populateFuncToLLVMFuncOpConversionPattern(converter, patterns);
734-
// clang-format off
735-
patterns.add<
736-
CallIndirectOpLowering,
737-
CallOpLowering,
738-
ConstantOpLowering,
739-
ReturnOpLowering>(converter);
740-
// clang-format on
756+
patterns.add<CallIndirectOpLowering>(converter);
757+
patterns.add<CallOpLowering>(converter, symbolTable);
758+
patterns.add<ConstantOpLowering>(converter);
759+
patterns.add<ReturnOpLowering>(converter);
741760
}
742761

743762
namespace {
@@ -776,8 +795,15 @@ struct ConvertFuncToLLVMPass
776795
LLVMTypeConverter typeConverter(&getContext(), options,
777796
&dataLayoutAnalysis);
778797

798+
std::optional<SymbolTable> optSymbolTable = std::nullopt;
799+
const SymbolTable *symbolTable = nullptr;
800+
if (!options.useBarePtrCallConv) {
801+
optSymbolTable.emplace(m);
802+
symbolTable = &optSymbolTable.value();
803+
}
804+
779805
RewritePatternSet patterns(&getContext());
780-
populateFuncToLLVMConversionPatterns(typeConverter, patterns);
806+
populateFuncToLLVMConversionPatterns(typeConverter, patterns, symbolTable);
781807

782808
// TODO: Remove these in favor of their dedicated conversion passes.
783809
arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);

0 commit comments

Comments
 (0)