Skip to content

Commit 4790578

Browse files
[mlir] Make overloads of SymbolTable::replaceAllSymbolUses consistent. (#68320)
This function has several overloads that allow to specify the symbol that should be renamed and the scope for that renaming in different ways. The overloads were inconsistent in the following way (quoted strings are `StringAttr`s, other variables are `Operation *`): * `replaceAllSymbolUses(symbolOp, "new_symbol", scopeOp)` would traverse into the nested regions of `scopeOp` and hence rename the symbol inside of `scopeOp`. * `replaceAllSymbolUses("symbol", "new_symbol", scopeOp)` would *not* traverse into the nested regions of `scopeOp` and hence *not* rename the symbol. The underlying behavior was spread over different places and is somewhat hard to understand. The two overloads above mainly differed by what `collectSymbolScopes` computed, which is itself overloaded. If `scopeOp` is a top-level module, then the overload on `(Operation *, Operation *)`, which is used in the first of the above cases, computes a scope where the body region of the module is the `limit`; however, the overload on `(StringAttr, Operation *)` computed the module op itself as the `limit`. Later, `walkSymbolTable` would walk the body of the module if it was given as a region but it would *not* enter the regions of the module op because that op has a symbol table (which was assumed to be a *different* scope). The fix in this commit is change the behavior of `collectSymbolScopes` such that the `(StringAttr, Operation *)` overload returns a scope for each region in the `limit` argument.
1 parent 0d0f219 commit 4790578

File tree

4 files changed

+160
-4
lines changed

4 files changed

+160
-4
lines changed

mlir/lib/IR/SymbolTable.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -729,12 +729,20 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
729729
scopes.back().limit = limit;
730730
return scopes;
731731
}
732-
template <typename IRUnit>
733732
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
734-
IRUnit *limit) {
733+
Region *limit) {
735734
return {{SymbolRefAttr::get(symbol), limit}};
736735
}
737736

737+
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
738+
Operation *limit) {
739+
SmallVector<SymbolScope, 1> scopes;
740+
auto symbolRef = SymbolRefAttr::get(symbol);
741+
for (auto &region : limit->getRegions())
742+
scopes.push_back({symbolRef, &region});
743+
return scopes;
744+
}
745+
738746
/// Returns true if the given reference 'SubRef' is a sub reference of the
739747
/// reference 'ref', i.e. 'ref' is a further qualified reference.
740748
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {

mlir/test/python/ir/symbol_table.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ def testSymbolTableRAUW():
106106
"""
107107
)
108108
foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
109+
110+
# Do renaming just within `foo`.
109111
SymbolTable.set_symbol_name(bar, "bam")
110-
# Note that module.operation counts as a "nested symbol table" which won't
111-
# be traversed into, so it is necessary to traverse its children.
112112
SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
113113
# CHECK: call @bam()
114114
# CHECK: func private @bam
@@ -118,6 +118,17 @@ def testSymbolTableRAUW():
118118
print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
119119
print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
120120

121+
# Do renaming within the module.
122+
SymbolTable.set_symbol_name(bar, "baz")
123+
SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
124+
# CHECK: call @baz()
125+
# CHECK: func private @baz
126+
print(m)
127+
# CHECK: Foo symbol: StringAttr("foo")
128+
# CHECK: Bar symbol: StringAttr("baz")
129+
print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
130+
print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
131+
121132

122133
# CHECK-LABEL: testSymbolTableVisibility
123134
@run

mlir/unittests/IR/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
88
OperationSupportTest.cpp
99
PatternMatchTest.cpp
1010
ShapedTypeTest.cpp
11+
SymbolTableTest.cpp
1112
TypeTest.cpp
1213
OpPropertiesTest.cpp
1314

mlir/unittests/IR/SymbolTableTest.cpp

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "mlir/IR/SymbolTable.h"
9+
#include "mlir/IR/BuiltinOps.h"
10+
#include "mlir/IR/Verifier.h"
11+
#include "mlir/Interfaces/CallInterfaces.h"
12+
#include "mlir/Interfaces/FunctionInterfaces.h"
13+
#include "mlir/Parser/Parser.h"
14+
15+
#include "gtest/gtest.h"
16+
17+
using namespace mlir;
18+
19+
namespace test {
20+
void registerTestDialect(DialectRegistry &);
21+
} // namespace test
22+
23+
class ReplaceAllSymbolUsesTest : public ::testing::Test {
24+
protected:
25+
using ReplaceFnType = llvm::function_ref<LogicalResult(
26+
SymbolTable, ModuleOp, Operation *, Operation *)>;
27+
28+
void SetUp() override {
29+
::test::registerTestDialect(registry);
30+
context = std::make_unique<MLIRContext>(registry);
31+
}
32+
33+
void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
34+
// Set up IR and find func ops.
35+
OwningOpRef<ModuleOp> module =
36+
parseSourceString<ModuleOp>(kInput, context.get());
37+
SymbolTable symbolTable(module.get());
38+
auto opIterator = module->getBody(0)->getOperations().begin();
39+
auto fooOp = cast<FunctionOpInterface>(opIterator++);
40+
auto barOp = cast<FunctionOpInterface>(opIterator++);
41+
ASSERT_EQ(fooOp.getNameAttr(), "foo");
42+
ASSERT_EQ(barOp.getNameAttr(), "bar");
43+
44+
// Call test function that does symbol replacement.
45+
LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
46+
ASSERT_TRUE(succeeded(res));
47+
ASSERT_TRUE(succeeded(verify(module.get())));
48+
49+
// Check that it got renamed.
50+
bool calleeFound = false;
51+
fooOp->walk([&](CallOpInterface callOp) {
52+
StringAttr callee = callOp.getCallableForCallee()
53+
.dyn_cast<SymbolRefAttr>()
54+
.getLeafReference();
55+
EXPECT_EQ(callee, "baz");
56+
calleeFound = true;
57+
});
58+
EXPECT_TRUE(calleeFound);
59+
}
60+
61+
std::unique_ptr<MLIRContext> context;
62+
63+
private:
64+
constexpr static llvm::StringLiteral kInput = R"MLIR(
65+
module {
66+
test.conversion_func_op private @foo() {
67+
"test.conversion_call_op"() { callee=@bar } : () -> ()
68+
"test.return"() : () -> ()
69+
}
70+
test.conversion_func_op private @bar()
71+
}
72+
)MLIR";
73+
74+
DialectRegistry registry;
75+
};
76+
77+
namespace {
78+
79+
TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
80+
// Symbol as `Operation *`, rename within module.
81+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
82+
auto barOp) -> LogicalResult {
83+
return symbolTable.replaceAllSymbolUses(
84+
barOp, StringAttr::get(context.get(), "baz"), module);
85+
});
86+
}
87+
88+
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
89+
// Symbol as `StringAttr`, rename within module.
90+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
91+
auto barOp) -> LogicalResult {
92+
return symbolTable.replaceAllSymbolUses(
93+
StringAttr::get(context.get(), "bar"),
94+
StringAttr::get(context.get(), "baz"), module);
95+
});
96+
}
97+
98+
TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
99+
// Symbol as `Operation *`, rename within module body.
100+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
101+
auto barOp) -> LogicalResult {
102+
return symbolTable.replaceAllSymbolUses(
103+
barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
104+
});
105+
}
106+
107+
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
108+
// Symbol as `StringAttr`, rename within module body.
109+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
110+
auto barOp) -> LogicalResult {
111+
return symbolTable.replaceAllSymbolUses(
112+
StringAttr::get(context.get(), "bar"),
113+
StringAttr::get(context.get(), "baz"), &module->getRegion(0));
114+
});
115+
}
116+
117+
TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
118+
// Symbol as `Operation *`, rename within function.
119+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
120+
auto barOp) -> LogicalResult {
121+
return symbolTable.replaceAllSymbolUses(
122+
barOp, StringAttr::get(context.get(), "baz"), fooOp);
123+
});
124+
}
125+
126+
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
127+
// Symbol as `StringAttr`, rename within function.
128+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
129+
auto barOp) -> LogicalResult {
130+
return symbolTable.replaceAllSymbolUses(
131+
StringAttr::get(context.get(), "bar"),
132+
StringAttr::get(context.get(), "baz"), fooOp);
133+
});
134+
}
135+
136+
} // namespace

0 commit comments

Comments
 (0)