Skip to content

Commit 21c0050

Browse files
committed
[mlir][tosa] Add error if and level checks for COND_IF & WHILE_LOOP
Error if checks: verify whether the same length and type between input list, output list, and control-flow blocks. Level_checks: verify whether the nested depth exceeds MAX_NESTING.
1 parent 6957699 commit 21c0050

File tree

5 files changed

+743
-79
lines changed

5 files changed

+743
-79
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
25592559
);
25602560

25612561
let hasCustomAssemblyFormat = 1;
2562+
let hasVerifier = 1;
25622563
}
25632564

25642565
//===----------------------------------------------------------------------===//
@@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
25972598
);
25982599

25992600
let hasCustomAssemblyFormat = 1;
2601+
let hasVerifier = 1;
26002602
}
26012603

26022604
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
562562
return success();
563563
}
564564

565+
// Verify whether same type and shape of the given two types.
566+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
567+
StringRef name1, Type type2,
568+
StringRef name2) {
569+
auto shapeType1 = dyn_cast<ShapedType>(type1);
570+
auto shapeType2 = dyn_cast<ShapedType>(type2);
571+
if (!shapeType1 || !shapeType2)
572+
return failure();
573+
574+
auto elemType1 = shapeType1.getElementType();
575+
auto elemType2 = shapeType2.getElementType();
576+
if (elemType1 != elemType2)
577+
return op->emitOpError()
578+
<< "require same element type for " << name1 << " (" << elemType1
579+
<< ") and " << name2 << " (" << elemType2 << ")";
580+
581+
if (failed(verifyCompatibleShape(type1, type2)))
582+
return op->emitOpError()
583+
<< "require same shapes for " << name1 << " (" << type1 << ") and "
584+
<< name2 << " (" << type2 << ")";
585+
586+
return success();
587+
}
588+
589+
// Verify whether same length, type, and shape of the given two tensor lists.
590+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
591+
StringRef name1,
592+
ValueRange list2,
593+
StringRef name2) {
594+
if (list1.size() != list2.size())
595+
return op->emitOpError()
596+
<< "require same number of values in " << name1 << " ("
597+
<< list1.size() << ") and " << name2 << " (" << list2.size() << ")";
598+
599+
for (auto [type1, type2] :
600+
llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
601+
if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
602+
return failure();
603+
}
604+
605+
return success();
606+
}
607+
608+
static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
609+
ShapeAdaptor shapeAdaptor(type);
610+
if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
611+
return success();
612+
613+
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
614+
}
615+
565616
// verify that inType and outType have same element types
566617
template <typename T>
567618
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3437,6 +3488,84 @@ void IfOp::print(OpAsmPrinter &p) {
34373488
p.printOptionalAttrDict((*this)->getAttrs());
34383489
}
34393490

3491+
LogicalResult IfOp::verify() {
3492+
if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3493+
"'then_graph' arguments", getInputList(),
3494+
"'input_list'")
3495+
.failed())
3496+
return failure();
3497+
3498+
if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3499+
"'else_graph' arguments", getInputList(),
3500+
"'input_list'")
3501+
.failed())
3502+
return failure();
3503+
3504+
auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3505+
if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3506+
"'then_graph' results", getOutputList(),
3507+
"'output_list'")
3508+
.failed())
3509+
return failure();
3510+
3511+
auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3512+
if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3513+
"'else_graph' results", getOutputList(),
3514+
"'output_list'")
3515+
.failed())
3516+
return failure();
3517+
3518+
auto condType = getCondition().getType();
3519+
if (errorIfShapeNotSizeOne(*this, condType).failed())
3520+
return emitOpError() << "'condition' must be a size 1 tensor, got "
3521+
<< condType;
3522+
3523+
return success();
3524+
}
3525+
3526+
LogicalResult WhileOp::verify() {
3527+
if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3528+
getOutputList(), "'output_list'")
3529+
.failed())
3530+
return failure();
3531+
3532+
if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3533+
"'cond_graph' arguments", getInputList(),
3534+
"'input_list'")
3535+
.failed())
3536+
return failure();
3537+
3538+
if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3539+
"'body_graph' arguments", getInputList(),
3540+
"'input_list'")
3541+
.failed())
3542+
return failure();
3543+
3544+
auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3545+
if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3546+
"'body_graph' results", getInputList(),
3547+
"'input_list'")
3548+
.failed())
3549+
return failure();
3550+
3551+
// Condition block output must be a single element tensor with a single bool
3552+
// value.
3553+
auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3554+
if (condYield.getInputs().size() != 1)
3555+
return emitOpError() << "require 'cond_graph' only have one result";
3556+
3557+
auto condOutType = condYield.getInputs()[0].getType();
3558+
if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3559+
return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3560+
<< condOutType;
3561+
3562+
if (!getElementTypeOrSelf(condOutType).isInteger(1))
3563+
return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3564+
<< condOutType;
3565+
3566+
return success();
3567+
}
3568+
34403569
LogicalResult ReverseOp::verify() {
34413570
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
34423571
/* outType = */ getOutput().getType())

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,39 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
449449
return true;
450450
}
451451

452+
// Recursively perform a bottom-up search to determine the maximum nesting
453+
// depth, starting from a specific operation and continuing up to the function
454+
// or module scope. Tosa nesting_depth starts at 0 and increments by one each
455+
// time a new nested `region` is encountered.
456+
static void getMaxNestedDepth(Operation *op, int32_t &depth) {
457+
if (isa<mlir::func::FuncOp>(op) || isa<ModuleOp>(op))
458+
return;
459+
460+
Block *block = op->getBlock();
461+
if (!block)
462+
return;
463+
464+
Region *region = block->getParent();
465+
if (!region)
466+
return;
467+
468+
depth++;
469+
getMaxNestedDepth(region->getParentOp(), depth);
470+
return;
471+
}
472+
473+
bool levelCheckMaxNesting(Operation *op) {
474+
int32_t maxNestedDepth = 0;
475+
getMaxNestedDepth(op, maxNestedDepth);
476+
477+
if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
478+
op->emitOpError() << "failed level check: " << maxNestedDepth
479+
<< " >= MAX_NESTING";
480+
return false;
481+
}
482+
return true;
483+
}
484+
452485
bool levelCheckListSize(Operation *op) {
453486
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
454487
return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -750,6 +783,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
750783
return failure();
751784
}
752785

786+
if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
787+
if (!levelCheckMaxNesting(op)) {
788+
return failure();
789+
}
790+
}
791+
753792
return success();
754793
}
755794

0 commit comments

Comments
 (0)