Skip to content

[mlir][irdl] Lookup symbols near dialects instead of locally #92819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions mlir/include/mlir/Dialect/IRDL/IRDLSymbols.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- IRDLSymbols.h - IRDL-related symbol logic ----------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Manages lookup logic for IRDL dialect-absolute symbols.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
#define MLIR_DIALECT_IRDL_IRDLSYMBOLS_H

#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"

namespace mlir {
namespace irdl {

/// Looks up a symbol from the symbol table containing the source operation's
/// dialect definition operation. The source operation must be nested within an
/// IRDL dialect definition operation. This exploits SymbolTableCollection for
/// better symbol table lookup.
Operation *lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
Operation *source, SymbolRefAttr symbol);

/// Looks up a symbol from the symbol table containing the source operation's
/// dialect definition operation. The source operation must be nested within an
/// IRDL dialect definition operation.
Operation *lookupSymbolNearDialect(Operation *source, SymbolRefAttr symbol);

} // namespace irdl
} // namespace mlir

#endif // MLIR_DIALECT_IRDL_IRDLSYMBOLS_H
1 change: 1 addition & 0 deletions mlir/lib/Dialect/IRDL/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRIRDL
IR/IRDL.cpp
IR/IRDLOps.cpp
IRDLLoading.cpp
IRDLSymbols.cpp
IRDLVerifiers.cpp

DEPENDS
Expand Down
7 changes: 6 additions & 1 deletion mlir/lib/Dialect/IRDL/IR/IRDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -132,10 +133,14 @@ LogicalResult BaseOp::verify() {
return success();
}

/// Finds whether the provided symbol is an IRDL type or attribute definition.
/// The source operation must be within a DialectOp.
static LogicalResult
checkSymbolIsTypeOrAttribute(SymbolTableCollection &symbolTable,
Operation *source, SymbolRefAttr symbol) {
Operation *targetOp = symbolTable.lookupNearestSymbolFrom(source, symbol);
Operation *targetOp =
irdl::lookupSymbolNearDialect(symbolTable, source, symbol);

if (!targetOp)
return source->emitOpError() << "symbol '" << symbol << "' not found";

Expand Down
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/IRDL/IR/IRDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/IR/ValueRange.h"
#include <optional>

Expand Down Expand Up @@ -47,8 +48,9 @@ std::unique_ptr<Constraint> BaseOp::getVerifier(
// Case where the input is a symbol reference.
// This corresponds to the case where the base is an IRDL type or attribute.
if (auto baseRef = getBaseRef()) {
// The verifier for BaseOp guarantees it is within a dialect.
Operation *defOp =
SymbolTable::lookupNearestSymbolFrom(getOperation(), baseRef.value());
irdl::lookupSymbolNearDialect(getOperation(), baseRef.value());

// Type case.
if (auto typeOp = dyn_cast<TypeOp>(defOp)) {
Expand Down Expand Up @@ -99,10 +101,10 @@ std::unique_ptr<Constraint> ParametricOp::getVerifier(
SmallVector<unsigned> constraints =
getConstraintIndicesForArgs(getArgs(), valueToConstr);

// Symbol reference case for the base
// Symbol reference case for the base.
// The verifier for ParametricOp guarantees it is within a dialect.
SymbolRefAttr symRef = getBaseType();
Operation *defOp =
SymbolTable::lookupNearestSymbolFrom(getOperation(), symRef);
Operation *defOp = irdl::lookupSymbolNearDialect(getOperation(), symRef);
if (!defOp) {
emitError() << symRef << " does not refer to any existing symbol";
return nullptr;
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/IRDL/IRDLLoading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Dialect/IRDL/IRDLLoading.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/IRDL/IR/IRDLInterfaces.h"
#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
Expand Down Expand Up @@ -523,7 +524,7 @@ static bool getBases(Operation *op, SmallPtrSet<TypeID, 4> &paramIds,
// For `irdl.parametric`, we get directly the base from the operation.
if (auto params = dyn_cast<ParametricOp>(op)) {
SymbolRefAttr symRef = params.getBaseType();
Operation *defOp = SymbolTable::lookupNearestSymbolFrom(op, symRef);
Operation *defOp = irdl::lookupSymbolNearDialect(op, symRef);
assert(defOp && "symbol reference should refer to an existing operation");
paramIrdlOps.insert(defOp);
return false;
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/IRDL/IRDLSymbols.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===- IRDLSymbols.cpp - IRDL-related symbol logic --------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/IRDL/IRDLSymbols.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"

using namespace mlir;
using namespace mlir::irdl;

static Operation *lookupDialectOp(Operation *source) {
Operation *dialectOp = source;
while (dialectOp && !isa<DialectOp>(dialectOp))
dialectOp = dialectOp->getParentOp();

if (!dialectOp)
llvm_unreachable("symbol lookup near dialect must originate from "
"within a dialect definition");

return dialectOp;
}

Operation *
mlir::irdl::lookupSymbolNearDialect(SymbolTableCollection &symbolTable,
Operation *source, SymbolRefAttr symbol) {
return symbolTable.lookupNearestSymbolFrom(
lookupDialectOp(source)->getParentOp(), symbol);
}

Operation *mlir::irdl::lookupSymbolNearDialect(Operation *source,
SymbolRefAttr symbol) {
return SymbolTable::lookupNearestSymbolFrom(
lookupDialectOp(source)->getParentOp(), symbol);
}
8 changes: 4 additions & 4 deletions mlir/test/Dialect/IRDL/cmath.irdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ module {

// CHECK: irdl.operation @norm {
// CHECK: %[[v0:[^ ]*]] = irdl.any
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @complex<%[[v0]]>
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v0]]>
// CHECK: irdl.operands(%[[v1]])
// CHECK: irdl.results(%[[v0]])
// CHECK: }
irdl.operation @norm {
%0 = irdl.any
%1 = irdl.parametric @complex<%0>
%1 = irdl.parametric @cmath::@complex<%0>
irdl.operands(%1)
irdl.results(%0)
}
Expand All @@ -34,15 +34,15 @@ module {
// CHECK: %[[v0:[^ ]*]] = irdl.is f32
// CHECK: %[[v1:[^ ]*]] = irdl.is f64
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @complex<%[[v2]]>
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @cmath::@complex<%[[v2]]>
// CHECK: irdl.operands(%[[v3]], %[[v3]])
// CHECK: irdl.results(%[[v3]])
// CHECK: }
irdl.operation @mul {
%0 = irdl.is f32
%1 = irdl.is f64
%2 = irdl.any_of(%0, %1)
%3 = irdl.parametric @complex<%2>
%3 = irdl.parametric @cmath::@complex<%2>
irdl.operands(%3, %3)
irdl.results(%3)
}
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/Dialect/IRDL/cyclic-types.irdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
irdl.dialect @testd {
// CHECK: irdl.type @self_referencing {
// CHECK: %[[v0:[^ ]*]] = irdl.any
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @self_referencing<%[[v0]]>
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@self_referencing<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
// CHECK: }
irdl.type @self_referencing {
%0 = irdl.any
%1 = irdl.parametric @self_referencing<%0>
%1 = irdl.parametric @testd::@self_referencing<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
Expand All @@ -22,27 +22,27 @@ irdl.dialect @testd {

// CHECK: irdl.type @type1 {
// CHECK: %[[v0:[^ ]*]] = irdl.any
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @type2<%[[v0]]>
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type2<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
irdl.type @type1 {
%0 = irdl.any
%1 = irdl.parametric @type2<%0>
%1 = irdl.parametric @testd::@type2<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
}

// CHECK: irdl.type @type2 {
// CHECK: %[[v0:[^ ]*]] = irdl.any
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @type1<%[[v0]]>
// CHECK: %[[v1:[^ ]*]] = irdl.parametric @testd::@type1<%[[v0]]>
// CHECK: %[[v2:[^ ]*]] = irdl.is i32
// CHECK: %[[v3:[^ ]*]] = irdl.any_of(%[[v1]], %[[v2]])
// CHECK: irdl.parameters(%[[v3]])
irdl.type @type2 {
%0 = irdl.any
%1 = irdl.parametric @type1<%0>
%1 = irdl.parametric @testd::@type1<%0>
%2 = irdl.is i32
%3 = irdl.any_of(%1, %2)
irdl.parameters(%3)
Expand Down
9 changes: 2 additions & 7 deletions mlir/test/Dialect/IRDL/invalid.irdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

// Testing invalid IRDL IRs

func.func private @foo()

irdl.dialect @testd {
irdl.type @type {
// expected-error@+1 {{symbol '@foo' not found}}
Expand Down Expand Up @@ -44,15 +42,12 @@ irdl.dialect @testd {

// -----

func.func private @not_a_type_or_attr()

irdl.dialect @invalid_parametric {
irdl.operation @foo {
// expected-error@+1 {{symbol '@not_a_type_or_attr' does not refer to a type or attribute definition}}
%param = irdl.parametric @not_a_type_or_attr<>
irdl.results(%param)
}

irdl.operation @not_a_type_or_attr {
%param = irdl.is i1
irdl.results(%param)
}
}
12 changes: 6 additions & 6 deletions mlir/test/Dialect/IRDL/testd.irdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,20 @@ irdl.dialect @testd {
}

// CHECK: irdl.operation @dyn_type_base {
// CHECK: %[[v1:[^ ]*]] = irdl.base @parametric
// CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric
// CHECK: irdl.results(%[[v1]])
// CHECK: }
irdl.operation @dyn_type_base {
%0 = irdl.base @parametric
%0 = irdl.base @testd::@parametric
irdl.results(%0)
}

// CHECK: irdl.operation @dyn_attr_base {
// CHECK: %[[v1:[^ ]*]] = irdl.base @parametric_attr
// CHECK: %[[v1:[^ ]*]] = irdl.base @testd::@parametric_attr
// CHECK: irdl.attributes {"attr1" = %[[v1]]}
// CHECK: }
irdl.operation @dyn_attr_base {
%0 = irdl.base @parametric_attr
%0 = irdl.base @testd::@parametric_attr
irdl.attributes {"attr1" = %0}
}

Expand All @@ -115,14 +115,14 @@ irdl.dialect @testd {
// CHECK: %[[v0:[^ ]*]] = irdl.is i32
// CHECK: %[[v1:[^ ]*]] = irdl.is i64
// CHECK: %[[v2:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]])
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @parametric<%[[v2]]>
// CHECK: %[[v3:[^ ]*]] = irdl.parametric @testd::@parametric<%[[v2]]>
// CHECK: irdl.results(%[[v3]])
// CHECK: }
irdl.operation @dynparams {
%0 = irdl.is i32
%1 = irdl.is i64
%2 = irdl.any_of(%0, %1)
%3 = irdl.parametric @parametric<%2>
%3 = irdl.parametric @testd::@parametric<%2>
irdl.results(%3)
}

Expand Down
Loading