Skip to content

Commit 3d47bc9

Browse files
authored
[mlir][tosa] Add error if and level checks for COND_IF & WHILE_LOOP (#136194)
Error if checks: verify whether the same length and type between input list, output list, and control-flow blocks. Level_checks: verify whether the nested depth exceeds MAX_NESTING.
1 parent 252b095 commit 3d47bc9

File tree

5 files changed

+739
-80
lines changed

5 files changed

+739
-80
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
25592559
);
25602560

25612561
let hasCustomAssemblyFormat = 1;
2562+
let hasVerifier = 1;
25622563
}
25632564

25642565
//===----------------------------------------------------------------------===//
@@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
25972598
);
25982599

25992600
let hasCustomAssemblyFormat = 1;
2601+
let hasVerifier = 1;
26002602
}
26012603

26022604
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
562562
return success();
563563
}
564564

565+
// Verify whether same type and shape of the given two types.
566+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
567+
StringRef name1, Type type2,
568+
StringRef name2) {
569+
auto shapeType1 = dyn_cast<ShapedType>(type1);
570+
auto shapeType2 = dyn_cast<ShapedType>(type2);
571+
if (!shapeType1 || !shapeType2)
572+
return failure();
573+
574+
auto elemType1 = shapeType1.getElementType();
575+
auto elemType2 = shapeType2.getElementType();
576+
if (elemType1 != elemType2)
577+
return op->emitOpError()
578+
<< "require same element type for " << name1 << " (" << elemType1
579+
<< ") and " << name2 << " (" << elemType2 << ")";
580+
581+
if (failed(verifyCompatibleShape(type1, type2)))
582+
return op->emitOpError()
583+
<< "require same shapes for " << name1 << " (" << type1 << ") and "
584+
<< name2 << " (" << type2 << ")";
585+
586+
return success();
587+
}
588+
589+
// Verify whether same length, type, and shape of the given two tensor lists.
590+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
591+
StringRef name1,
592+
ValueRange list2,
593+
StringRef name2) {
594+
if (list1.size() != list2.size())
595+
return op->emitOpError()
596+
<< "require same number of values in " << name1 << " ("
597+
<< list1.size() << ") and " << name2 << " (" << list2.size() << ")";
598+
599+
for (auto [type1, type2] :
600+
llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
601+
if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
602+
return failure();
603+
}
604+
605+
return success();
606+
}
607+
608+
static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
609+
ShapeAdaptor shapeAdaptor(type);
610+
if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
611+
return success();
612+
613+
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
614+
}
615+
565616
// verify that inType and outType have same element types
566617
template <typename T>
567618
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3473,6 +3524,84 @@ void IfOp::print(OpAsmPrinter &p) {
34733524
p.printOptionalAttrDict((*this)->getAttrs());
34743525
}
34753526

3527+
LogicalResult IfOp::verify() {
3528+
if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3529+
"'then_graph' arguments", getInputList(),
3530+
"'input_list'")
3531+
.failed())
3532+
return failure();
3533+
3534+
if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3535+
"'else_graph' arguments", getInputList(),
3536+
"'input_list'")
3537+
.failed())
3538+
return failure();
3539+
3540+
auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3541+
if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3542+
"'then_graph' results", getOutputList(),
3543+
"'output_list'")
3544+
.failed())
3545+
return failure();
3546+
3547+
auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3548+
if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3549+
"'else_graph' results", getOutputList(),
3550+
"'output_list'")
3551+
.failed())
3552+
return failure();
3553+
3554+
auto condType = getCondition().getType();
3555+
if (errorIfShapeNotSizeOne(*this, condType).failed())
3556+
return emitOpError() << "'condition' must be a size 1 tensor, got "
3557+
<< condType;
3558+
3559+
return success();
3560+
}
3561+
3562+
LogicalResult WhileOp::verify() {
3563+
if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3564+
getOutputList(), "'output_list'")
3565+
.failed())
3566+
return failure();
3567+
3568+
if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3569+
"'cond_graph' arguments", getInputList(),
3570+
"'input_list'")
3571+
.failed())
3572+
return failure();
3573+
3574+
if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3575+
"'body_graph' arguments", getInputList(),
3576+
"'input_list'")
3577+
.failed())
3578+
return failure();
3579+
3580+
auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3581+
if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3582+
"'body_graph' results", getInputList(),
3583+
"'input_list'")
3584+
.failed())
3585+
return failure();
3586+
3587+
// Condition block output must be a single element tensor with a single bool
3588+
// value.
3589+
auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3590+
if (condYield.getInputs().size() != 1)
3591+
return emitOpError() << "require 'cond_graph' only have one result";
3592+
3593+
auto condOutType = condYield.getInputs()[0].getType();
3594+
if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3595+
return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3596+
<< condOutType;
3597+
3598+
if (!getElementTypeOrSelf(condOutType).isInteger(1))
3599+
return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3600+
<< condOutType;
3601+
3602+
return success();
3603+
}
3604+
34763605
LogicalResult ReverseOp::verify() {
34773606
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
34783607
/* outType = */ getOutput().getType())

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,35 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
449449
return true;
450450
}
451451

452+
// Recursively perform a bottom-up search to determine the maximum nesting
453+
// depth, starting from a specific operation and continuing up to the function
454+
// or module scope. Tosa nesting_depth starts at 0 and increments by one each
455+
// time a new nested `region` is encountered.
456+
static void getMaxNestedDepth(Operation *op, int32_t &depth) {
457+
if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
458+
return;
459+
460+
op = op->getParentOp();
461+
if (!op)
462+
return;
463+
464+
depth++;
465+
getMaxNestedDepth(op, depth);
466+
return;
467+
}
468+
469+
bool levelCheckMaxNesting(Operation *op) {
470+
int32_t maxNestedDepth = 0;
471+
getMaxNestedDepth(op, maxNestedDepth);
472+
473+
if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
474+
op->emitOpError() << "failed level check: " << maxNestedDepth
475+
<< " >= MAX_NESTING";
476+
return false;
477+
}
478+
return true;
479+
}
480+
452481
bool levelCheckListSize(Operation *op) {
453482
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
454483
return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -750,6 +779,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
750779
return failure();
751780
}
752781

782+
if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
783+
if (!levelCheckMaxNesting(op)) {
784+
return failure();
785+
}
786+
}
787+
753788
return success();
754789
}
755790

0 commit comments

Comments
 (0)