Skip to content

[mlir][tosa] Add error if and level checks for COND_IF & WHILE_LOOP #136194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025

Conversation

tatwaichong
Copy link
Contributor

Error if checks: verify whether the same length and type between input list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.

@llvmbot
Copy link
Member

llvmbot commented Apr 17, 2025

@llvm/pr-subscribers-mlir-tosa

Author: TatWai Chong (tatwaichong)

Changes

Error if checks: verify whether the same length and type between input list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.


Patch is 50.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136194.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+130)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+42)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+224-79)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+349)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c94edad62cac7..70702647bb50a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   );
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   );
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
 }
 
 include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..365ea458d25d5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -428,6 +428,51 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+// Verify whether same length and type of block arguments and tensor list.
+static LogicalResult errorIfTensorListShapeMismatch(Operation *op,
+                                                    ValueRange blocksArgs,
+                                                    StringRef blockName,
+                                                    ValueRange tensorList,
+                                                    StringRef listName) {
+  if (blocksArgs.size() != tensorList.size())
+    return op->emitOpError() << "require same number of values in " << blockName
+                             << " (" << blocksArgs.size() << ") and "
+                             << listName << " (" << tensorList.size() << ")";
+
+  for (auto [bbArgType, opArgType] :
+       llvm::zip_equal(blocksArgs.getTypes(), tensorList.getTypes())) {
+    ShapeAdaptor bbShapeAdaptor(bbArgType);
+    ShapeAdaptor opShapeAdaptor(opArgType);
+
+    if (!bbShapeAdaptor.hasRank() || !opShapeAdaptor.hasRank())
+      continue;
+
+    if (!bbShapeAdaptor.hasStaticShape() || !opShapeAdaptor.hasStaticShape())
+      continue;
+
+    if (bbArgType != opArgType)
+      return op->emitOpError()
+             << "require same shapes for " << blockName << " (" << bbArgType
+             << ") and " << listName << " (" << opArgType << ")";
+  }
+
+  return success();
+}
+
+static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
+  ShapeAdaptor shapeAdaptor(type);
+
+  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
+    return success();
+
+  SmallVector<int64_t> shape;
+  shapeAdaptor.getDims(shape);
+  if (!llvm::all_of(shape, [](int64_t dim) { return dim == 1; }))
+    return failure();
+
+  return success();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3321,6 +3366,91 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+LogicalResult IfOp::verify() {
+  if (getThenGraph().empty() || getElseGraph().empty())
+    return emitOpError("require `then_graph` and `else_graph` not be empty");
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getThenGraph().front().getArguments(),
+          "`then_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getElseGraph().front().getArguments(),
+          "`else_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, thenYield.getInputs(),
+                                     "`then_graph` results", getOutputList(),
+                                     "`output_list`")
+          .failed())
+    return failure();
+
+  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, elseYield.getInputs(),
+                                     "`else_graph` results", getOutputList(),
+                                     "`output_list`")
+          .failed())
+    return failure();
+
+  auto condType = getCondition().getType();
+  if (errorIfShapeNotSizeOne(*this, condType).failed())
+    return emitOpError() << "`condition` must be a size 1 tensor, got "
+                         << condType;
+
+  return success();
+}
+
+LogicalResult WhileOp::verify() {
+  if (getCondGraph().empty() || getBodyGraph().empty())
+    return emitOpError(
+        "`cond_graph` and `body_graph` regions must not be empty");
+
+  if (errorIfTensorListShapeMismatch(*this, getInputList(), "`input_list`",
+                                     getOutputList(), "`output_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getCondGraph().front().getArguments(),
+          "`cond_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getBodyGraph().front().getArguments(),
+          "`body_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, bodyYield.getInputs(),
+                                     "`body_graph` results", getInputList(),
+                                     "`input_list`")
+          .failed())
+    return failure();
+
+  // Condition block output must be a single element tensor with a single bool
+  // value.
+  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (condYield.getInputs().size() != 1)
+    return emitOpError() << "require `cond_graph` only have one result";
+
+  auto condOutType = condYield.getInputs()[0].getType();
+  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
+    return emitOpError() << "`cond_graph` result must be a size 1 tensor, got "
+                         << condOutType;
+
+  if (!getElementTypeOrSelf(condOutType).isInteger(1))
+    return emitOpError() << "`cond_graph` result must be a boolean tensor, got "
+                         << condOutType;
+
+  return success();
+}
+
 LogicalResult ReverseOp::verify() {
   if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                              /* outType = */ getOutput().getType())
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index ef9d27f8df0ad..de31198faff2a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -449,6 +449,42 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
+  // Preform depth-first search in Tosa IR structure to find the maximum nesting
+  // depth. Tosa nesting_depth starts at 0 and increase by one each time a new
+  // nested `region` is encountered.
+
+  static int32_t getMaxNestedDepth(Operation *op) {
+    int32_t depth = 0;
+    for (Region &region : op->getRegions())
+      depth = std::max(depth, getMaxNestedDepth(region));
+    return depth;
+  }
+
+  static int32_t getMaxNestedDepth(Block &block) {
+    int32_t depth = 0;
+    for (Operation &op : block.getOperations())
+      depth = std::max(depth, getMaxNestedDepth(&op));
+    return depth;
+  }
+
+  static int32_t getMaxNestedDepth(Region &region) {
+    int32_t depth = 0;
+    for (Block &block : region.getBlocks())
+      depth = std::max(depth, getMaxNestedDepth(block));
+    // Increase the nested depth.
+    return depth + 1;
+  }
+
+  bool levelCheckMaxNesting(Operation *op) {
+    int32_t maxNestedDepth = getMaxNestedDepth(op);
+    if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
+      op->emitOpError() << "failed level check: " << maxNestedDepth
+                        << " >= MAX_NESTING";
+      return false;
+    }
+    return true;
+  }
+
   bool levelCheckListSize(Operation *op) {
     if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
       return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -750,6 +786,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
     return failure();
   }
 
+  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
+    if (!levelCheckMaxNesting(op)) {
+      return failure();
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b48f614770fcb..5cad01d85825d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1327,33 +1327,42 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
   %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
   %1 = "tosa.cond_if"(%arg0,   // condition
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0) ({
+  ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+   ):
+    "tosa.yield"(%64) : (tensor<1xi32>) -> ()
   },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+   ):
+    "tosa.yield"(%01) : (tensor<1xi32>) -> ()
   }) : (
-                    tensor<i1>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  ) -> tensor<1xi32>
-
+       tensor<i1>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+     ) -> tensor<1xi32>
   return
 }
 
@@ -1361,27 +1370,54 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
 
 // CHECK-LABEL: test_if_tensor_list_size_outputs
 func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
-  %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
 
   // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
-  %r:65 = "tosa.cond_if"(%arg0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  %r:65 = "tosa.cond_if"(%arg0, %cst_0) ({
+  ^bb0(%0: tensor<1xi32>):
+    "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
   },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
-  }) : (tensor<i1>) -> (
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  )
-
+  ^bb0(%0: tensor<1xi32>):
+    "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
+  }) : (tensor<i1>, tensor<1xi32>) -> (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 17, 2025

@llvm/pr-subscribers-mlir

Author: TatWai Chong (tatwaichong)

Changes

Error if checks: verify whether the same length and type between input list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.


Patch is 50.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136194.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+130)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+42)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+224-79)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+349)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index c94edad62cac7..70702647bb50a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   );
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   );
 
   let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
 }
 
 include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8b4f6ef0d0980..365ea458d25d5 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -428,6 +428,51 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+// Verify whether same length and type of block arguments and tensor list.
+static LogicalResult errorIfTensorListShapeMismatch(Operation *op,
+                                                    ValueRange blocksArgs,
+                                                    StringRef blockName,
+                                                    ValueRange tensorList,
+                                                    StringRef listName) {
+  if (blocksArgs.size() != tensorList.size())
+    return op->emitOpError() << "require same number of values in " << blockName
+                             << " (" << blocksArgs.size() << ") and "
+                             << listName << " (" << tensorList.size() << ")";
+
+  for (auto [bbArgType, opArgType] :
+       llvm::zip_equal(blocksArgs.getTypes(), tensorList.getTypes())) {
+    ShapeAdaptor bbShapeAdaptor(bbArgType);
+    ShapeAdaptor opShapeAdaptor(opArgType);
+
+    if (!bbShapeAdaptor.hasRank() || !opShapeAdaptor.hasRank())
+      continue;
+
+    if (!bbShapeAdaptor.hasStaticShape() || !opShapeAdaptor.hasStaticShape())
+      continue;
+
+    if (bbArgType != opArgType)
+      return op->emitOpError()
+             << "require same shapes for " << blockName << " (" << bbArgType
+             << ") and " << listName << " (" << opArgType << ")";
+  }
+
+  return success();
+}
+
+static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
+  ShapeAdaptor shapeAdaptor(type);
+
+  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
+    return success();
+
+  SmallVector<int64_t> shape;
+  shapeAdaptor.getDims(shape);
+  if (!llvm::all_of(shape, [](int64_t dim) { return dim == 1; }))
+    return failure();
+
+  return success();
+}
+
 // verify that inType and outType have same element types
 template <typename T>
 static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3321,6 +3366,91 @@ void IfOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict((*this)->getAttrs());
 }
 
+LogicalResult IfOp::verify() {
+  if (getThenGraph().empty() || getElseGraph().empty())
+    return emitOpError("require `then_graph` and `else_graph` not be empty");
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getThenGraph().front().getArguments(),
+          "`then_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getElseGraph().front().getArguments(),
+          "`else_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, thenYield.getInputs(),
+                                     "`then_graph` results", getOutputList(),
+                                     "`output_list`")
+          .failed())
+    return failure();
+
+  auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, elseYield.getInputs(),
+                                     "`else_graph` results", getOutputList(),
+                                     "`output_list`")
+          .failed())
+    return failure();
+
+  auto condType = getCondition().getType();
+  if (errorIfShapeNotSizeOne(*this, condType).failed())
+    return emitOpError() << "`condition` must be a size 1 tensor, got "
+                         << condType;
+
+  return success();
+}
+
+LogicalResult WhileOp::verify() {
+  if (getCondGraph().empty() || getBodyGraph().empty())
+    return emitOpError(
+        "`cond_graph` and `body_graph` regions must not be empty");
+
+  if (errorIfTensorListShapeMismatch(*this, getInputList(), "`input_list`",
+                                     getOutputList(), "`output_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getCondGraph().front().getArguments(),
+          "`cond_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  if (errorIfTensorListShapeMismatch(
+          *this, getBodyGraph().front().getArguments(),
+          "`body_graph` arguments", getInputList(), "`input_list`")
+          .failed())
+    return failure();
+
+  auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
+  if (errorIfTensorListShapeMismatch(*this, bodyYield.getInputs(),
+                                     "`body_graph` results", getInputList(),
+                                     "`input_list`")
+          .failed())
+    return failure();
+
+  // Condition block output must be a single element tensor with a single bool
+  // value.
+  auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
+  if (condYield.getInputs().size() != 1)
+    return emitOpError() << "require `cond_graph` only have one result";
+
+  auto condOutType = condYield.getInputs()[0].getType();
+  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
+    return emitOpError() << "`cond_graph` result must be a size 1 tensor, got "
+                         << condOutType;
+
+  if (!getElementTypeOrSelf(condOutType).isInteger(1))
+    return emitOpError() << "`cond_graph` result must be a boolean tensor, got "
+                         << condOutType;
+
+  return success();
+}
+
 LogicalResult ReverseOp::verify() {
   if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                              /* outType = */ getOutput().getType())
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index ef9d27f8df0ad..de31198faff2a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -449,6 +449,42 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
+  // Preform depth-first search in Tosa IR structure to find the maximum nesting
+  // depth. Tosa nesting_depth starts at 0 and increase by one each time a new
+  // nested `region` is encountered.
+
+  static int32_t getMaxNestedDepth(Operation *op) {
+    int32_t depth = 0;
+    for (Region &region : op->getRegions())
+      depth = std::max(depth, getMaxNestedDepth(region));
+    return depth;
+  }
+
+  static int32_t getMaxNestedDepth(Block &block) {
+    int32_t depth = 0;
+    for (Operation &op : block.getOperations())
+      depth = std::max(depth, getMaxNestedDepth(&op));
+    return depth;
+  }
+
+  static int32_t getMaxNestedDepth(Region &region) {
+    int32_t depth = 0;
+    for (Block &block : region.getBlocks())
+      depth = std::max(depth, getMaxNestedDepth(block));
+    // Increase the nested depth.
+    return depth + 1;
+  }
+
+  bool levelCheckMaxNesting(Operation *op) {
+    int32_t maxNestedDepth = getMaxNestedDepth(op);
+    if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
+      op->emitOpError() << "failed level check: " << maxNestedDepth
+                        << " >= MAX_NESTING";
+      return false;
+    }
+    return true;
+  }
+
   bool levelCheckListSize(Operation *op) {
     if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
       return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -750,6 +786,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
     return failure();
   }
 
+  if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
+    if (!levelCheckMaxNesting(op)) {
+      return failure();
+    }
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b48f614770fcb..5cad01d85825d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -1327,33 +1327,42 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
   %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: inputs}}
   %1 = "tosa.cond_if"(%arg0,   // condition
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0, %0, %0, %0, %0, %0, %0, %0,
-                  %0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                  %0, %0, %0, %0, %0) ({
+  ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+   ):
+    "tosa.yield"(%64) : (tensor<1xi32>) -> ()
   },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  ^bb0(%00: tensor<1xi32>, %01: tensor<1xi32>, %02: tensor<1xi32>, %03: tensor<1xi32>, %04: tensor<1xi32>, %05: tensor<1xi32>, %06: tensor<1xi32>, %07: tensor<1xi32>, %08: tensor<1xi32>, %09: tensor<1xi32>,
+       %10: tensor<1xi32>, %11: tensor<1xi32>, %12: tensor<1xi32>, %13: tensor<1xi32>, %14: tensor<1xi32>, %15: tensor<1xi32>, %16: tensor<1xi32>, %17: tensor<1xi32>, %18: tensor<1xi32>, %19: tensor<1xi32>,
+       %20: tensor<1xi32>, %21: tensor<1xi32>, %22: tensor<1xi32>, %23: tensor<1xi32>, %24: tensor<1xi32>, %25: tensor<1xi32>, %26: tensor<1xi32>, %27: tensor<1xi32>, %28: tensor<1xi32>, %29: tensor<1xi32>,
+       %30: tensor<1xi32>, %31: tensor<1xi32>, %32: tensor<1xi32>, %33: tensor<1xi32>, %34: tensor<1xi32>, %35: tensor<1xi32>, %36: tensor<1xi32>, %37: tensor<1xi32>, %38: tensor<1xi32>, %39: tensor<1xi32>,
+       %40: tensor<1xi32>, %41: tensor<1xi32>, %42: tensor<1xi32>, %43: tensor<1xi32>, %44: tensor<1xi32>, %45: tensor<1xi32>, %46: tensor<1xi32>, %47: tensor<1xi32>, %48: tensor<1xi32>, %49: tensor<1xi32>,
+       %50: tensor<1xi32>, %51: tensor<1xi32>, %52: tensor<1xi32>, %53: tensor<1xi32>, %54: tensor<1xi32>, %55: tensor<1xi32>, %56: tensor<1xi32>, %57: tensor<1xi32>, %58: tensor<1xi32>, %59: tensor<1xi32>,
+       %60: tensor<1xi32>, %61: tensor<1xi32>, %62: tensor<1xi32>, %63: tensor<1xi32>, %64: tensor<1xi32>
+   ):
+    "tosa.yield"(%01) : (tensor<1xi32>) -> ()
   }) : (
-                    tensor<i1>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  ) -> tensor<1xi32>
-
+       tensor<i1>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+       tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+     ) -> tensor<1xi32>
   return
 }
 
@@ -1361,27 +1370,54 @@ func.func @test_if_tensor_list_size(%arg0 : tensor<i1>) {
 
 // CHECK-LABEL: test_if_tensor_list_size_outputs
 func.func @test_if_tensor_list_size_outputs(%arg0 : tensor<i1>) {
-  %0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %cst_0 = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
 
   // expected-error@+1 {{'tosa.cond_if' op failed level check for MAX_TENSOR_LIST_SIZE: outputs}}
-  %r:65 = "tosa.cond_if"(%arg0) ({
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
+  %r:65 = "tosa.cond_if"(%arg0, %cst_0) ({
+  ^bb0(%0: tensor<1xi32>):
+    "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
   },  {
-  ^bb0(%arg3: tensor<1xi32>):
-    "tosa.yield"(%arg3) : (tensor<1xi32>) -> ()
-  }) : (tensor<i1>) -> (
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
-                    tensor<1xi32>
-                  )
-
+  ^bb0(%0: tensor<1xi32>):
+    "tosa.yield"(%0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0, %0, %0, %0, %0, %0,
+                 %0, %0, %0, %0, %0
+                ) : (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>
+    ) -> ()
+  }) : (tensor<i1>, tensor<1xi32>) -> (
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>,
+      ...
[truncated]

@Tai78641
Copy link
Contributor

LGTM

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Had a few suggestions for improving some of the checks, otherwise it LGTM!

@tatwaichong tatwaichong force-pushed the control_flow_error_if branch 3 times, most recently from d8a3f5e to f10baf6 Compare April 27, 2025 21:10
Copy link

github-actions bot commented Apr 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@tatwaichong tatwaichong force-pushed the control_flow_error_if branch from f10baf6 to 21c0050 Compare April 27, 2025 21:14
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates, had a question, otherwise LGTM!

Error if checks: verify whether the same length and type between
input list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.
@tatwaichong tatwaichong force-pushed the control_flow_error_if branch from 21c0050 to 5b8f507 Compare April 28, 2025 17:57
@lhutton1 lhutton1 merged commit 3d47bc9 into llvm:main Apr 29, 2025
11 checks passed
gizmondo pushed a commit to gizmondo/llvm-project that referenced this pull request Apr 29, 2025
…lvm#136194)

Error if checks: verify whether the same length and type between input
list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.
@tatwaichong tatwaichong deleted the control_flow_error_if branch April 29, 2025 17:06
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136194)

Error if checks: verify whether the same length and type between input
list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136194)

Error if checks: verify whether the same length and type between input
list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136194)

Error if checks: verify whether the same length and type between input
list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…lvm#136194)

Error if checks: verify whether the same length and type between input
list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.
Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
…lvm#136194)

Error if checks: verify whether the same length and type between input
list, output list, and control-flow blocks.

Level_checks: verify whether the nested depth exceeds MAX_NESTING.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants