Skip to content

Commit b86a9c5

Browse files
authored
[mlir][irdl] Lookup symbols near dialects instead of locally (#92819)
Because symbols cannot refer to operations outside of their symbol tables, it was impossible to refer to operations outside of the dialect currently being defined. This PR modifies the lookup logic to happen relative to the symbol table containing the dialect-defining operations. This is a bit of hack but should unblock the situation here.
1 parent ae86278 commit b86a9c5

File tree

10 files changed

+108
-29
lines changed

10 files changed

+108
-29
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Manages lookup logic for IRDL dialect-absolute symbols.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
14+
#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
15+
16+
#include "mlir/IR/Operation.h"
17+
#include "mlir/IR/SymbolTable.h"
18+
19+
namespace mlir {
20+
namespace irdl {
21+
22+
/// Looks up a symbol from the symbol table containing the source operation's
23+
/// dialect definition operation. The source operation must be nested within an
24+
/// IRDL dialect definition operation. This exploits SymbolTableCollection for
25+
/// better symbol table lookup.
26+
Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
27+
Operation *source, SymbolRefAttr symbol);
28+
29+
/// Looks up a symbol from the symbol table containing the source operation's
30+
/// dialect definition operation. The source operation must be nested within an
31+
/// IRDL dialect definition operation.
32+
Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol);
33+
34+
} // namespace irdl
35+
} // namespace mlir
36+
37+
#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H

mlir/lib/Dialect/IRDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL
22
IR/IRDL.cpp
33
IR/IRDLOps.cpp
44
IRDLLoading.cpp
5+
IRDLSymbols.cpp
56
IRDLVerifiers.cpp
67

78
DEPENDS

mlir/lib/Dialect/IRDL/IR/IRDL.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/IRDL/IR/IRDL.h"
10+
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
1011
#include "mlir/IR/Builders.h"
1112
#include "mlir/IR/BuiltinAttributes.h"
1213
#include "mlir/IR/Diagnostics.h"
@@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() {
132133
return success();
133134
}
134135

136+
/// Finds whether the provided symbol is an IRDL type or attribute definition.
137+
/// The source operation must be within a DialectOp.
135138
static LogicalResult
136139
checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
137140
Operation *source, SymbolRefAttr symbol) {
138-
Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol);
141+
Operation *targetOp =
142+
irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
143+
139144
if (!targetOp)
140145
return source->emitOpError() << "symbol '" << symbol << "' not found";
141146

mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/IRDL/IR/IRDL.h"
10+
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
1011
#include "mlir/IR/ValueRange.h"
1112
#include <optional>
1213

@@ -47,8 +48,9 @@ std::unique_ptr<Constraint> BaseOp::getVerifier(
4748
// Case where the input is a symbol reference.
4849
// This corresponds to the case where the base is an IRDL type or attribute.
4950
if (auto baseRef = getBaseRef()) {
51+
// The verifier for BaseOp guarantees it is within a dialect.
5052
Operation *defOp =
51-
SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
53+
irdl::lookupSymbolNearDialect(getOperation(), baseRef.value());
5254

5355
// Type case.
5456
if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
@@ -99,10 +101,10 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
99101
SmallVector<unsigned> constraints =
100102
getConstraintIndicesForArgs(getArgs(), valueToConstr);
101103

102-
// Symbol reference case for the base
104+
// Symbol reference case for the base.
105+
// The verifier for ParametricOp guarantees it is within a dialect.
103106
SymbolRefAttr symRef = getBaseType();
104-
Operation *defOp =
105-
SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
107+
Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef);
106108
if (!defOp) {
107109
emitError() << symRef << " does not refer to any existing symbol";
108110
return nullptr;

mlir/lib/Dialect/IRDL/IRDLLoading.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/IRDL/IRDLLoading.h"
1414
#include "mlir/Dialect/IRDL/IR/IRDL.h"
1515
#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
16+
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
1617
#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
1718
#include "mlir/IR/Attributes.h"
1819
#include "mlir/IR/BuiltinOps.h"
@@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
523524
// For `irdl.parametric`, we get directly the base from the operation.
524525
if (auto params = dyn_cast<ParametricOp>(op)) {
525526
SymbolRefAttr symRef = params.getBaseType();
526-
Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
527+
Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
527528
assert(defOp && "symbol reference should refer to an existing operation");
528529
paramIrdlOps.insert(defOp);
529530
return false;

mlir/lib/Dialect/IRDL/IRDLSymbols.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- IRDLSymbols.cpp - IRDL-related symbol logic --------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
10+
#include "mlir/Dialect/IRDL/IR/IRDL.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::irdl;
14+
15+
static Operation *lookupDialectOp(Operation *source) {
16+
Operation *dialectOp = source;
17+
while (dialectOp && !isa<DialectOp>(dialectOp))
18+
dialectOp = dialectOp->getParentOp();
19+
20+
if (!dialectOp)
21+
llvm_unreachable("symbol lookup near dialect must originate from "
22+
"within a dialect definition");
23+
24+
return dialectOp;
25+
}
26+
27+
Operation *
28+
mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
29+
Operation *source, SymbolRefAttr symbol) {
30+
return symbolTable.lookupNearestSymbolFrom(
31+
lookupDialectOp(source)->getParentOp(), symbol);
32+
}
33+
34+
Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source,
35+
SymbolRefAttr symbol) {
36+
return SymbolTable::lookupNearestSymbolFrom(
37+
lookupDialectOp(source)->getParentOp(), symbol);
38+
}

mlir/test/Dialect/IRDL/cmath.irdl.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ module {
1919

2020
// CHECK: irdl.operation @norm {
2121
// CHECK: %[[v0:[^ ]*]] = irdl.any
22-
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
22+
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]>
2323
// CHECK: irdl.operands(%[[v1]])
2424
// CHECK: irdl.results(%[[v0]])
2525
// CHECK: }
2626
irdl.operation @norm {
2727
%0 = irdl.any
28-
%1 = irdl.parametric @complex<%0>
28+
%1 = irdl.parametric @cmath::@complex<%0>
2929
irdl.operands(%1)
3030
irdl.results(%0)
3131
}
@@ -34,15 +34,15 @@ module {
3434
// CHECK: %[[v0:[^ ]*]] = irdl.is f32
3535
// CHECK: %[[v1:[^ ]*]] = irdl.is f64
3636
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
37-
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
37+
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]>
3838
// CHECK: irdl.operands(%[[v3]], %[[v3]])
3939
// CHECK: irdl.results(%[[v3]])
4040
// CHECK: }
4141
irdl.operation @mul {
4242
%0 = irdl.is f32
4343
%1 = irdl.is f64
4444
%2 = irdl.any_of(%0, %1)
45-
%3 = irdl.parametric @complex<%2>
45+
%3 = irdl.parametric @cmath::@complex<%2>
4646
irdl.operands(%3, %3)
4747
irdl.results(%3)
4848
}

mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
irdl.dialect @testd {
77
// CHECK: irdl.type @self_referencing {
88
// CHECK: %[[v0:[^ ]*]] = irdl.any
9-
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
9+
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]>
1010
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
1111
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
1212
// CHECK: irdl.parameters(%[[v3]])
1313
// CHECK: }
1414
irdl.type @self_referencing {
1515
%0 = irdl.any
16-
%1 = irdl.parametric @self_referencing<%0>
16+
%1 = irdl.parametric @testd::@self_referencing<%0>
1717
%2 = irdl.is i32
1818
%3 = irdl.any_of(%1, %2)
1919
irdl.parameters(%3)
@@ -22,27 +22,27 @@ irdl.dialect @testd {
2222

2323
// CHECK: irdl.type @type1 {
2424
// CHECK: %[[v0:[^ ]*]] = irdl.any
25-
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
25+
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]>
2626
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
2727
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
2828
// CHECK: irdl.parameters(%[[v3]])
2929
irdl.type @type1 {
3030
%0 = irdl.any
31-
%1 = irdl.parametric @type2<%0>
31+
%1 = irdl.parametric @testd::@type2<%0>
3232
%2 = irdl.is i32
3333
%3 = irdl.any_of(%1, %2)
3434
irdl.parameters(%3)
3535
}
3636

3737
// CHECK: irdl.type @type2 {
3838
// CHECK: %[[v0:[^ ]*]] = irdl.any
39-
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
39+
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]>
4040
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
4141
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
4242
// CHECK: irdl.parameters(%[[v3]])
4343
irdl.type @type2 {
4444
%0 = irdl.any
45-
%1 = irdl.parametric @type1<%0>
45+
%1 = irdl.parametric @testd::@type1<%0>
4646
%2 = irdl.is i32
4747
%3 = irdl.any_of(%1, %2)
4848
irdl.parameters(%3)

mlir/test/Dialect/IRDL/invalid.irdl.mlir

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
// Testing invalid IRDL IRs
44

5-
func.func private @foo()
6-
75
irdl.dialect @testd {
86
irdl.type @type {
97
// expected-error@+1 {{symbol '@foo' not found}}
@@ -44,15 +42,12 @@ irdl.dialect @testd {
4442

4543
// -----
4644

45+
func.func private @not_a_type_or_attr()
46+
4747
irdl.dialect @invalid_parametric {
4848
irdl.operation @foo {
4949
// expected-error@+1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}}
5050
%param = irdl.parametric @not_a_type_or_attr<>
5151
irdl.results(%param)
5252
}
53-
54-
irdl.operation @not_a_type_or_attr {
55-
%param = irdl.is i1
56-
irdl.results(%param)
57-
}
5853
}

mlir/test/Dialect/IRDL/testd.irdl.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,20 @@ irdl.dialect @testd {
7676
}
7777

7878
// CHECK: irdl.operation @dyn_type_base {
79-
// CHECK: %[[v1:[^ ]*]] = irdl.base @parametric
79+
// CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric
8080
// CHECK: irdl.results(%[[v1]])
8181
// CHECK: }
8282
irdl.operation @dyn_type_base {
83-
%0 = irdl.base @parametric
83+
%0 = irdl.base @testd::@parametric
8484
irdl.results(%0)
8585
}
8686

8787
// CHECK: irdl.operation @dyn_attr_base {
88-
// CHECK: %[[v1:[^ ]*]] = irdl.base @parametric_attr
88+
// CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr
8989
// CHECK: irdl.attributes {"attr1" = %[[v1]]}
9090
// CHECK: }
9191
irdl.operation @dyn_attr_base {
92-
%0 = irdl.base @parametric_attr
92+
%0 = irdl.base @testd::@parametric_attr
9393
irdl.attributes {"attr1" = %0}
9494
}
9595

@@ -115,14 +115,14 @@ irdl.dialect @testd {
115115
// CHECK: %[[v0:[^ ]*]] = irdl.is i32
116116
// CHECK: %[[v1:[^ ]*]] = irdl.is i64
117117
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
118-
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
118+
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]>
119119
// CHECK: irdl.results(%[[v3]])
120120
// CHECK: }
121121
irdl.operation @dynparams {
122122
%0 = irdl.is i32
123123
%1 = irdl.is i64
124124
%2 = irdl.any_of(%0, %1)
125-
%3 = irdl.parametric @parametric<%2>
125+
%3 = irdl.parametric @testd::@parametric<%2>
126126
irdl.results(%3)
127127
}
128128

0 commit comments

Comments
 (0)