Skip to content

[mlir][irdl] Lookup symbols near dialects instead of locally #92819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2024

Conversation

Moxinilian
Copy link
Member

Because symbols cannot refer to operations outside of their symbol tables, it was impossible to refer to operations outside of the dialect currently being defined. This PR modifies the lookup logic to happen relative to the symbol table containing the dialect-defining operations. This is a bit of hack but should unblock the situation here.

@llvmbot
Copy link
Member

llvmbot commented May 20, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-irdl

Author: Théo Degioanni (Moxinilian)

Changes

Because symbols cannot refer to operations outside of their symbol tables, it was impossible to refer to operations outside of the dialect currently being defined. This PR modifies the lookup logic to happen relative to the symbol table containing the dialect-defining operations. This is a bit of hack but should unblock the situation here.


Full diff: https://github.com/llvm/llvm-project/pull/92819.diff

10 Files Affected:

  • (added) mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h (+37)
  • (modified) mlir/lib/Dialect/IRDL/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/IRDL/IR/IRDL.cpp (+6-1)
  • (modified) mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp (+6-4)
  • (modified) mlir/lib/Dialect/IRDL/IRDLLoading.cpp (+2-1)
  • (added) mlir/lib/Dialect/IRDL/IRDLSymbols.cpp (+38)
  • (modified) mlir/test/Dialect/IRDL/cmath.irdl.mlir (+4-4)
  • (modified) mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir (+6-6)
  • (modified) mlir/test/Dialect/IRDL/invalid.irdl.mlir (+2-7)
  • (modified) mlir/test/Dialect/IRDL/testd.irdl.mlir (+6-6)
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
new file mode 100644
index 0000000000000..4b7292c054ec2
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
@@ -0,0 +1,37 @@
+//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages lookup logic for IRDL dialect-absolute symbols.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
+
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+
+namespace mlir {
+namespace irdl {
+
+/// Looks up a symbol from the symbol table containing the source operation's
+/// dialect definition operation. The source operation must be nested within an
+/// IRDL dialect definition operation. This exploits SymbolTableCollection for
+/// better symbol table lookup.
+Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+                                   Operation *source, SymbolRefAttr symbol);
+
+/// Looks up a symbol from the symbol table containing the source operation's
+/// dialect definition operation. The source operation must be nested within an
+/// IRDL dialect definition operation.
+Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index d25760e5d29bc..db4b98ef5308e 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
   IR/IRDLOps.cpp
   IRDLLoading.cpp
+  IRDLSymbols.cpp
   IRDLVerifiers.cpp
 
   DEPENDS
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
index e4728f55b49d7..1f5584fa30c27 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDL.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() {
   return success();
 }
 
+/// Finds whether the provided symbol is an IRDL type or attribute definition.
+/// The source operation must be within a DialectOp.
 static LogicalResult
 checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
                              Operation *source, SymbolRefAttr symbol) {
-  Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol);
+  Operation *targetOp =
+      irdl::lookupSymbolNearDialect(symbolTable, source, symbol);
+
   if (!targetOp)
     return source->emitOpError() << "symbol '" << symbol << "' not found";
 
diff --git a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
index 0895306b8bce1..7ec3aa2741023 100644
--- a/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
+++ b/mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/IR/ValueRange.h"
 #include <optional>
 
@@ -47,8 +48,9 @@ std::unique_ptr<Constraint> BaseOp::getVerifier(
   // Case where the input is a symbol reference.
   // This corresponds to the case where the base is an IRDL type or attribute.
   if (auto baseRef = getBaseRef()) {
+    // The verifier for BaseOp guarantees it is within a dialect.
     Operation *defOp =
-        SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
+        irdl::lookupSymbolNearDialect(getOperation(), baseRef.value());
 
     // Type case.
     if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
@@ -99,10 +101,10 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
   SmallVector<unsigned> constraints =
       getConstraintIndicesForArgs(getArgs(), valueToConstr);
 
-  // Symbol reference case for the base
+  // Symbol reference case for the base.
+  // The verifier for ParametricOp guarantees it is within a dialect.
   SymbolRefAttr symRef = getBaseType();
-  Operation *defOp =
-      SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
+  Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef);
   if (!defOp) {
     emitError() << symRef << " does not refer to any existing symbol";
     return nullptr;
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
index 5df2b45d8037b..5f623e8845d10 100644
--- a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/Dialect/IRDL/IR/IRDL.h"
 #include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
 #include "mlir/Dialect/IRDL/IRDLVerifiers.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
   // For `irdl.parametric`, we get directly the base from the operation.
   if (auto params = dyn_cast<ParametricOp>(op)) {
     SymbolRefAttr symRef = params.getBaseType();
-    Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
+    Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
     assert(defOp && "symbol reference should refer to an existing operation");
     paramIrdlOps.insert(defOp);
     return false;
diff --git a/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
new file mode 100644
index 0000000000000..ff2136df364d9
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
@@ -0,0 +1,38 @@
+//===- IRDLSymbols.cpp - IRDL-related symbol logic --------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLSymbols.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+static Operation *lookupDialectOp(Operation *source) {
+  Operation *dialectOp = source;
+  while (dialectOp && !isa<DialectOp>(dialectOp))
+    dialectOp = dialectOp->getParentOp();
+
+  if (!dialectOp)
+    llvm_unreachable("symbol lookup near dialect must originate from "
+                     "within a dialect definition");
+
+  return dialectOp;
+}
+
+Operation *
+mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
+                                    Operation *source, SymbolRefAttr symbol) {
+  return symbolTable.lookupNearestSymbolFrom(
+      lookupDialectOp(source)->getParentOp(), symbol);
+}
+
+Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source,
+                                               SymbolRefAttr symbol) {
+  return SymbolTable::lookupNearestSymbolFrom(
+      lookupDialectOp(source)->getParentOp(), symbol);
+}
diff --git a/mlir/test/Dialect/IRDL/cmath.irdl.mlir b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
index 997af08d24733..0b7e220ceb90c 100644
--- a/mlir/test/Dialect/IRDL/cmath.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cmath.irdl.mlir
@@ -19,13 +19,13 @@ module {
 
     // CHECK: irdl.operation @norm {
     // CHECK:   %[[v0:[^ ]*]] = irdl.any
-    // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
+    // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]>
     // CHECK:   irdl.operands(%[[v1]])
     // CHECK:   irdl.results(%[[v0]])
     // CHECK: }
     irdl.operation @norm {
       %0 = irdl.any
-      %1 = irdl.parametric @complex<%0>
+      %1 = irdl.parametric @cmath::@complex<%0>
       irdl.operands(%1)
       irdl.results(%0)
     }
@@ -34,7 +34,7 @@ module {
     // CHECK:   %[[v0:[^ ]*]] = irdl.is f32
     // CHECK:   %[[v1:[^ ]*]] = irdl.is f64
     // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
-    // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
+    // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]>
     // CHECK:   irdl.operands(%[[v3]], %[[v3]])
     // CHECK:   irdl.results(%[[v3]])
     // CHECK: }
@@ -42,7 +42,7 @@ module {
       %0 = irdl.is f32
       %1 = irdl.is f64
       %2 = irdl.any_of(%0, %1)
-      %3 = irdl.parametric @complex<%2>
+      %3 = irdl.parametric @cmath::@complex<%2>
       irdl.operands(%3, %3)
       irdl.results(%3)
     }
diff --git a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
index db8dfc5cb36ca..cbcc248bf00b1 100644
--- a/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
@@ -6,14 +6,14 @@
 irdl.dialect @testd {
   // CHECK:   irdl.type @self_referencing {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   // CHECK: }
   irdl.type @self_referencing {
     %0 = irdl.any
-    %1 = irdl.parametric @self_referencing<%0>
+    %1 = irdl.parametric @testd::@self_referencing<%0>
     %2 = irdl.is i32
     %3 = irdl.any_of(%1, %2)
     irdl.parameters(%3)
@@ -22,13 +22,13 @@ irdl.dialect @testd {
 
   // CHECK:   irdl.type @type1 {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   irdl.type @type1 {
     %0 = irdl.any
-    %1 = irdl.parametric @type2<%0>
+    %1 = irdl.parametric @testd::@type2<%0>
     %2 = irdl.is i32
     %3 = irdl.any_of(%1, %2)
     irdl.parameters(%3)
@@ -36,13 +36,13 @@ irdl.dialect @testd {
 
   // CHECK:   irdl.type @type2 {
   // CHECK:   %[[v0:[^ ]*]] = irdl.any
-  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
+  // CHECK:   %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]>
   // CHECK:   %[[v2:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
   // CHECK:   irdl.parameters(%[[v3]])
   irdl.type @type2 {
       %0 = irdl.any
-      %1 = irdl.parametric @type1<%0>
+      %1 = irdl.parametric @testd::@type1<%0>
       %2 = irdl.is i32
       %3 = irdl.any_of(%1, %2)
       irdl.parameters(%3)
diff --git a/mlir/test/Dialect/IRDL/invalid.irdl.mlir b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
index f207d31cf158b..93ad619358750 100644
--- a/mlir/test/Dialect/IRDL/invalid.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/invalid.irdl.mlir
@@ -2,8 +2,6 @@
 
 // Testing invalid IRDL IRs
 
-func.func private @foo()
-
 irdl.dialect @testd {
   irdl.type @type {
     // expected-error@+1 {{symbol '@foo' not found}}
@@ -44,15 +42,12 @@ irdl.dialect @testd {
 
 // -----
 
+func.func private @not_a_type_or_attr()
+
 irdl.dialect @invalid_parametric {
   irdl.operation @foo {
     // expected-error@+1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}}
     %param = irdl.parametric @not_a_type_or_attr<>
     irdl.results(%param)
   }
-
-  irdl.operation @not_a_type_or_attr {
-    %param = irdl.is i1
-    irdl.results(%param)
-  }
 }
diff --git a/mlir/test/Dialect/IRDL/testd.irdl.mlir b/mlir/test/Dialect/IRDL/testd.irdl.mlir
index f828d95bdb81d..aeb1a83747ecc 100644
--- a/mlir/test/Dialect/IRDL/testd.irdl.mlir
+++ b/mlir/test/Dialect/IRDL/testd.irdl.mlir
@@ -76,20 +76,20 @@ irdl.dialect @testd {
   }
 
   // CHECK: irdl.operation @dyn_type_base {
-  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @testd::@parametric
   // CHECK:   irdl.results(%[[v1]])
   // CHECK: }
   irdl.operation @dyn_type_base {
-    %0 = irdl.base @parametric
+    %0 = irdl.base @testd::@parametric
     irdl.results(%0)
   }
 
   // CHECK: irdl.operation @dyn_attr_base {
-  // CHECK:   %[[v1:[^ ]*]] = irdl.base @parametric_attr
+  // CHECK:   %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr
   // CHECK:   irdl.attributes {"attr1" = %[[v1]]}
   // CHECK: }
   irdl.operation @dyn_attr_base {
-    %0 = irdl.base @parametric_attr
+    %0 = irdl.base @testd::@parametric_attr
     irdl.attributes {"attr1" = %0}
   }
 
@@ -115,14 +115,14 @@ irdl.dialect @testd {
   // CHECK:   %[[v0:[^ ]*]] = irdl.is i32
   // CHECK:   %[[v1:[^ ]*]] = irdl.is i64
   // CHECK:   %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
-  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
+  // CHECK:   %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]>
   // CHECK:   irdl.results(%[[v3]])
   // CHECK: }
   irdl.operation @dynparams {
     %0 = irdl.is i32
     %1 = irdl.is i64
     %2 = irdl.any_of(%0, %1)
-    %3 = irdl.parametric @parametric<%2>
+    %3 = irdl.parametric @testd::@parametric<%2>
     irdl.results(%3)
   }
 

@Moxinilian Moxinilian force-pushed the irdl-symbols-hack branch from cba009f to f6ccbe2 Compare May 20, 2024 21:26
Copy link
Contributor

@math-fehr math-fehr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delay! Very nice!

@Moxinilian Moxinilian merged commit b86a9c5 into llvm:main May 31, 2024
10 checks passed
@Moxinilian Moxinilian deleted the irdl-symbols-hack branch May 31, 2024 08:15
keith added a commit that referenced this pull request May 31, 2024
keith added a commit that referenced this pull request May 31, 2024
I missed this since it was still broken because of another patch
#93996
@jackalcooper
Copy link
Contributor

hi, glad to see the new progress in the development of IRDL!
After this PR, what is the recommended way to get a symbol like @cmath::@complex? I was using the CAPI mlirFlatSymbolRefAttrGet to parse the a "complex" string ref.

@Moxinilian
Copy link
Member Author

Moxinilian commented Jun 4, 2024

The SymbolRefAttr is what you are looking for. FlatSymbolRefAttr, as its name implies, only models a flat symbol reference, with no nesting. In contrast, you can define nesting in SymbolRefAttr.

I hope this is exposed in the C API. If not, it should not be too hard to add it? I need to check.

@jackalcooper
Copy link
Contributor

The SymbolRefAttr is what you are looking for. FlatSymbolRefAttr, as its name implies, only models a flat symbol reference, with no nesting. In contrast, you can define nesting in SymbolRefAttr.

I hope this is exposed in the C API. If not, it should not be too hard to add it? I need to check.

I've found it! Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants