Skip to content

Commit 36d936a

Browse files
[mlir][IR] Improve error message when return type could not be inferred (#112336)
Print an error such as the following one before terminating program execution. ``` mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir:26:8: remark: location of op %0 = sparse_tensor.convert %arg0 : tensor<?xi32> to tensor<?xi32, #SparseVector> ^ LLVM ERROR: Failed to infer result type(s): "sparse_tensor.positions"(...) {} : (index) -> ( ??? ) (stack trace follows) ```
1 parent ae68d53 commit 36d936a

File tree

4 files changed

+25
-1
lines changed

4 files changed

+25
-1
lines changed

mlir/include/mlir/Interfaces/InferTypeOpInterface.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,10 @@ inferReturnTensorTypes(ArrayRef<ShapedTypeComponents> retComponents,
244244
/// Verifies that the inferred result types match the actual result types for
245245
/// the op. Precondition: op implements InferTypeOpInterface.
246246
LogicalResult verifyInferredResultTypes(Operation *op);
247+
248+
/// Report a fatal error indicating that the result types could not be
249+
/// inferred.
250+
void reportFatalInferReturnTypesError(OperationState &state);
247251
} // namespace detail
248252

249253
namespace OpTrait {

mlir/lib/Interfaces/InferTypeOpInterface.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,17 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
247247

248248
return result;
249249
}
250+
251+
void mlir::detail::reportFatalInferReturnTypesError(OperationState &state) {
252+
std::string buffer;
253+
llvm::raw_string_ostream os(buffer);
254+
os << "Failed to infer result type(s):\n";
255+
os << "\"" << state.name << "\"(...) ";
256+
os << state.attributes.getDictionary(state.location.getContext());
257+
os << " : (";
258+
llvm::interleaveComma(state.operands, os,
259+
[&](Value val) { os << val.getType(); });
260+
os << ") -> ( ??? )";
261+
emitRemark(state.location, "location of op");
262+
llvm::report_fatal_error(llvm::StringRef(buffer));
263+
}

mlir/test/mlir-tblgen/op-decl-and-defs.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,11 @@ def NS_FOp : NS_Op<"op_with_all_types_constraint",
208208
// CHECK-LABEL: class FOp :
209209
// CHECK: static ::llvm::LogicalResult inferReturnTypes
210210

211+
// DEFS: void FOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a) {
212+
// DEFS: if (::mlir::succeeded(FOp::inferReturnTypes(odsBuilder.getContext(),
213+
// DEFS: else
214+
// DEFS: ::mlir::detail::reportFatalInferReturnTypesError(odsState);
215+
211216
def NS_GOp : NS_Op<"op_with_fixed_return_type", []> {
212217
let arguments = (ins AnyType:$a);
213218
let results = (outs I32:$b);

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2503,7 +2503,8 @@ void OpEmitter::genSeparateArgParamBuilder() {
25032503
{1}.regions, inferredReturnTypes)))
25042504
{1}.addTypes(inferredReturnTypes);
25052505
else
2506-
::llvm::report_fatal_error("Failed to infer result type(s).");)",
2506+
::mlir::detail::reportFatalInferReturnTypesError({1});
2507+
)",
25072508
opClass.getClassName(), builderOpState);
25082509
return;
25092510
}

0 commit comments

Comments
 (0)