Skip to content

Commit f6ccbe2

Browse files
committed
lookup symbols near dialects instead of locally
1 parent 1eb7f05 commit f6ccbe2

File tree

10 files changed

+108
-29
lines changed

10 files changed

+108
-29
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Manages lookup logic for IRDL dialect-absolute symbols.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
14+
#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
15+
16+
#include "mlir/IR/Operation.h"
17+
#include "mlir/IR/SymbolTable.h"
18+
19+
namespace mlir {
20+
namespace irdl {
21+
22+
/// Looks up a symbol from the symbol table containing the source operation's
23+
/// dialect definition operation. The source operation must be nested within an
24+
/// IRDL dialect definition operation. This exploits SymbolTableCollection for
25+
/// better symbol table lookup.
26+
Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
27+
Operation *source, SymbolRefAttr symbol);
28+
29+
/// Looks up a symbol from the symbol table containing the source operation's
30+
/// dialect definition operation. The source operation must be nested within an
31+
/// IRDL dialect definition operation.
32+
Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol);
33+
34+
} // namespace irdl
35+
} // namespace mlir
36+
37+
#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H

mlir/lib/Dialect/IRDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL
22
IR/IRDL.cpp
33
IR/IRDLOps.cpp
44
IRDLLoading.cpp
5+
IRDLSymbols.cpp
56
IRDLVerifiers.cpp
67

78
DEPENDS

mlir/lib/Dialect/IRDL/IR/IRDL.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/IRDL/IR/IRDL.h"
10+
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
1011
#include "mlir/IR/Builders.h"
1112
#include "mlir/IR/BuiltinAttributes.h"
1213
#include "mlir/IR/Diagnostics.h"
@@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() {
132133
return success();
133134
}
134135

136+
/// Finds whether the provided symbol is an IRDL type or attribute definition.
137+
/// The source operation must be within a DialectOp.
135138
static LogicalResult
136139
checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
137140
Operation *source, SymbolRefAttr symbol) {
138-
Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol);
141+
Operation *targetOp =
142+
irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
143+
139144
if (!targetOp)
140145
return source->emitOpError() << "symbol '" << symbol << "' not found";
141146

mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/IRDL/IR/IRDL.h"
10+
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
1011
#include "mlir/IR/ValueRange.h"
1112
#include <optional>
1213

@@ -47,8 +48,9 @@ std::unique_ptr<Constraint> BaseOp::getVerifier(
4748
// Case where the input is a symbol reference.
4849
// This corresponds to the case where the base is an IRDL type or attribute.
4950
if (auto baseRef = getBaseRef()) {
51+
// The verifier for BaseOp guarantees it is within a dialect.
5052
Operation *defOp =
51-
SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
53+
irdl::lookupSymbolNearDialect(getOperation(), baseRef.value());
5254

5355
// Type case.
5456
if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
@@ -99,10 +101,10 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
99101
SmallVector<unsigned> constraints =
100102
getConstraintIndicesForArgs(getArgs(), valueToConstr);
101103

102-
// Symbol reference case for the base
104+
// Symbol reference case for the base.
105+
// The verifier for ParametricOp guarantees it is within a dialect.
103106
SymbolRefAttr symRef = getBaseType();
104-
Operation *defOp =
105-
SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
107+
Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef);
106108
if (!defOp) {
107109
emitError() << symRef << " does not refer to any existing symbol";
108110
return nullptr;

mlir/lib/Dialect/IRDL/IRDLLoading.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Dialect/IRDL/IRDLLoading.h"
1414
#include "mlir/Dialect/IRDL/IR/IRDL.h"
1515
#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
16+
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
1617
#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
1718
#include "mlir/IR/Attributes.h"
1819
#include "mlir/IR/BuiltinOps.h"
@@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
523524
// For `irdl.parametric`, we get directly the base from the operation.
524525
if (auto params = dyn_cast<ParametricOp>(op)) {
525526
SymbolRefAttr symRef = params.getBaseType();
526-
Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
527+
Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
527528
assert(defOp && "symbol reference should refer to an existing operation");
528529
paramIrdlOps.insert(defOp);
529530
return false;

mlir/lib/Dialect/IRDL/IRDLSymbols.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
//===- IRDLSymbols.cpp - IRDL-related symbol logic --------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
10+
#include "mlir/Dialect/IRDL/IR/IRDL.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::irdl;
14+
15+
static Operation *lookupDialectOp(Operation *source) {
16+
Operation *dialectOp = source;
17+
while (dialectOp && !isa<DialectOp>(dialectOp))
18+
dialectOp = dialectOp->getParentOp();
19+
20+
if (!dialectOp)
21+
llvm_unreachable("symbol lookup near dialect must originate from "
22+
"within a dialect definition");
23+
24+
return dialectOp;
25+
}
26+
27+
Operation *
28+
mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
29+
Operation *source, SymbolRefAttr symbol) {
30+
return symbolTable.lookupNearestSymbolFrom(
31+
lookupDialectOp(source)->getParentOp(), symbol);
32+
}
33+
34+
Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source,
35+
SymbolRefAttr symbol) {
36+
return SymbolTable::lookupNearestSymbolFrom(
37+
lookupDialectOp(source)->getParentOp(), symbol);
38+
}

mlir/test/Dialect/IRDL/cmath.irdl.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ module {
1919

2020
// CHECK: irdl.operation @norm {
2121
// CHECK: %[[v0:[^ ]*]] = irdl.any
22-
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
22+
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]>
2323
// CHECK: irdl.operands(%[[v1]])
2424
// CHECK: irdl.results(%[[v0]])
2525
// CHECK: }
2626
irdl.operation @norm {
2727
%0 = irdl.any
28-
%1 = irdl.parametric @complex<%0>
28+
%1 = irdl.parametric @cmath::@complex<%0>
2929
irdl.operands(%1)
3030
irdl.results(%0)
3131
}
@@ -34,15 +34,15 @@ module {
3434
// CHECK: %[[v0:[^ ]*]] = irdl.is f32
3535
// CHECK: %[[v1:[^ ]*]] = irdl.is f64
3636
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
37-
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
37+
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]>
3838
// CHECK: irdl.operands(%[[v3]], %[[v3]])
3939
// CHECK: irdl.results(%[[v3]])
4040
// CHECK: }
4141
irdl.operation @mul {
4242
%0 = irdl.is f32
4343
%1 = irdl.is f64
4444
%2 = irdl.any_of(%0, %1)
45-
%3 = irdl.parametric @complex<%2>
45+
%3 = irdl.parametric @cmath::@complex<%2>
4646
irdl.operands(%3, %3)
4747
irdl.results(%3)
4848
}

mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
irdl.dialect @testd {
77
// CHECK: irdl.type @self_referencing {
88
// CHECK: %[[v0:[^ ]*]] = irdl.any
9-
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
9+
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]>
1010
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
1111
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
1212
// CHECK: irdl.parameters(%[[v3]])
1313
// CHECK: }
1414
irdl.type @self_referencing {
1515
%0 = irdl.any
16-
%1 = irdl.parametric @self_referencing<%0>
16+
%1 = irdl.parametric @testd::@self_referencing<%0>
1717
%2 = irdl.is i32
1818
%3 = irdl.any_of(%1, %2)
1919
irdl.parameters(%3)
@@ -22,27 +22,27 @@ irdl.dialect @testd {
2222

2323
// CHECK: irdl.type @type1 {
2424
// CHECK: %[[v0:[^ ]*]] = irdl.any
25-
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
25+
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]>
2626
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
2727
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
2828
// CHECK: irdl.parameters(%[[v3]])
2929
irdl.type @type1 {
3030
%0 = irdl.any
31-
%1 = irdl.parametric @type2<%0>
31+
%1 = irdl.parametric @testd::@type2<%0>
3232
%2 = irdl.is i32
3333
%3 = irdl.any_of(%1, %2)
3434
irdl.parameters(%3)
3535
}
3636

3737
// CHECK: irdl.type @type2 {
3838
// CHECK: %[[v0:[^ ]*]] = irdl.any
39-
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
39+
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]>
4040
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
4141
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
4242
// CHECK: irdl.parameters(%[[v3]])
4343
irdl.type @type2 {
4444
%0 = irdl.any
45-
%1 = irdl.parametric @type1<%0>
45+
%1 = irdl.parametric @testd::@type1<%0>
4646
%2 = irdl.is i32
4747
%3 = irdl.any_of(%1, %2)
4848
irdl.parameters(%3)

mlir/test/Dialect/IRDL/invalid.irdl.mlir

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
// Testing invalid IRDL IRs
44

5-
func.func private @foo()
6-
75
irdl.dialect @testd {
86
irdl.type @type {
97
// expected-error@+1 {{symbol '@foo' not found}}
@@ -44,15 +42,12 @@ irdl.dialect @testd {
4442

4543
// -----
4644

45+
func.func private @not_a_type_or_attr()
46+
4747
irdl.dialect @invalid_parametric {
4848
irdl.operation @foo {
4949
// expected-error@+1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}}
5050
%param = irdl.parametric @not_a_type_or_attr<>
5151
irdl.results(%param)
5252
}
53-
54-
irdl.operation @not_a_type_or_attr {
55-
%param = irdl.is i1
56-
irdl.results(%param)
57-
}
5853
}

mlir/test/Dialect/IRDL/testd.irdl.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,20 @@ irdl.dialect @testd {
7676
}
7777

7878
// CHECK: irdl.operation @dyn_type_base {
79-
// CHECK: %[[v1:[^ ]*]] = irdl.base @parametric
79+
// CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric
8080
// CHECK: irdl.results(%[[v1]])
8181
// CHECK: }
8282
irdl.operation @dyn_type_base {
83-
%0 = irdl.base @parametric
83+
%0 = irdl.base @testd::@parametric
8484
irdl.results(%0)
8585
}
8686

8787
// CHECK: irdl.operation @dyn_attr_base {
88-
// CHECK: %[[v1:[^ ]*]] = irdl.base @parametric_attr
88+
// CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr
8989
// CHECK: irdl.attributes {"attr1" = %[[v1]]}
9090
// CHECK: }
9191
irdl.operation @dyn_attr_base {
92-
%0 = irdl.base @parametric_attr
92+
%0 = irdl.base @testd::@parametric_attr
9393
irdl.attributes {"attr1" = %0}
9494
}
9595

@@ -115,14 +115,14 @@ irdl.dialect @testd {
115115
// CHECK: %[[v0:[^ ]*]] = irdl.is i32
116116
// CHECK: %[[v1:[^ ]*]] = irdl.is i64
117117
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
118-
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
118+
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]>
119119
// CHECK: irdl.results(%[[v3]])
120120
// CHECK: }
121121
irdl.operation @dynparams {
122122
%0 = irdl.is i32
123123
%1 = irdl.is i64
124124
%2 = irdl.any_of(%0, %1)
125-
%3 = irdl.parametric @parametric<%2>
125+
%3 = irdl.parametric @testd::@parametric<%2>
126126
irdl.results(%3)
127127
}
128128

0 commit comments

Comments
 (0)