Skip to content

Commit d8a3f5e

Browse files
committed
[mlir][tosa] Add error if and level checks for COND_IF & WHILE_LOOP
Error if checks: verify whether the same length and type between input list, output list, and control-flow blocks. Level_checks: verify whether the nested depth exceeds MAX_NESTING.
1 parent 5eca2dd commit d8a3f5e

File tree

5 files changed

+745
-79
lines changed

5 files changed

+745
-79
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2559,6 +2559,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
25592559
);
25602560

25612561
let hasCustomAssemblyFormat = 1;
2562+
let hasVerifier = 1;
25622563
}
25632564

25642565
//===----------------------------------------------------------------------===//
@@ -2597,6 +2598,7 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
25972598
);
25982599

25992600
let hasCustomAssemblyFormat = 1;
2601+
let hasVerifier = 1;
26002602
}
26012603

26022604
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
572572
return success();
573573
}
574574

575+
// Verify whether same type and shape of the given two types.
576+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
577+
StringRef name1, Type type2,
578+
StringRef name2) {
579+
auto shapeType1 = dyn_cast<ShapedType>(type1);
580+
auto shapeType2 = dyn_cast<ShapedType>(type2);
581+
if (!shapeType1 || !shapeType2)
582+
return failure();
583+
584+
auto elemType1 = shapeType1.getElementType();
585+
auto elemType2 = shapeType2.getElementType();
586+
if (elemType1 != elemType2)
587+
return op->emitOpError()
588+
<< "require same element type for " << name1 << " (" << elemType1
589+
<< ") and " << name2 << " (" << elemType2 << ")";
590+
591+
if (failed(verifyCompatibleShape(type1, type2)))
592+
return op->emitOpError()
593+
<< "require same shapes for " << name1 << " (" << type1 << ") and "
594+
<< name2 << " (" << type2 << ")";
595+
596+
return success();
597+
}
598+
599+
// Verify whether same length, type, and shape of the given two tensor lists.
600+
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
601+
StringRef name1,
602+
ValueRange list2,
603+
StringRef name2) {
604+
if (list1.size() != list2.size())
605+
return op->emitOpError()
606+
<< "require same number of values in " << name1 << " ("
607+
<< list1.size() << ") and " << name2 << " (" << list2.size() << ")";
608+
609+
for (auto [type1, type2] :
610+
llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
611+
if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
612+
return failure();
613+
}
614+
615+
return success();
616+
}
617+
618+
static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
619+
ShapeAdaptor shapeAdaptor(type);
620+
if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
621+
return success();
622+
623+
return shapeAdaptor.getNumElements() == 1 ? success() : failure();
624+
}
625+
575626
// verify that inType and outType have same element types
576627
template <typename T>
577628
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
@@ -3397,6 +3448,84 @@ void IfOp::print(OpAsmPrinter &p) {
33973448
p.printOptionalAttrDict((*this)->getAttrs());
33983449
}
33993450

3451+
LogicalResult IfOp::verify() {
3452+
if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
3453+
"'then_graph' arguments", getInputList(),
3454+
"'input_list'")
3455+
.failed())
3456+
return failure();
3457+
3458+
if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
3459+
"'else_graph' arguments", getInputList(),
3460+
"'input_list'")
3461+
.failed())
3462+
return failure();
3463+
3464+
auto thenYield = cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
3465+
if (errorIfTypeOrShapeMismatch(*this, thenYield.getInputs(),
3466+
"'then_graph' results", getOutputList(),
3467+
"'output_list'")
3468+
.failed())
3469+
return failure();
3470+
3471+
auto elseYield = cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
3472+
if (errorIfTypeOrShapeMismatch(*this, elseYield.getInputs(),
3473+
"'else_graph' results", getOutputList(),
3474+
"'output_list'")
3475+
.failed())
3476+
return failure();
3477+
3478+
auto condType = getCondition().getType();
3479+
if (errorIfShapeNotSizeOne(*this, condType).failed())
3480+
return emitOpError() << "'condition' must be a size 1 tensor, got "
3481+
<< condType;
3482+
3483+
return success();
3484+
}
3485+
3486+
LogicalResult WhileOp::verify() {
3487+
if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
3488+
getOutputList(), "'output_list'")
3489+
.failed())
3490+
return failure();
3491+
3492+
if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
3493+
"'cond_graph' arguments", getInputList(),
3494+
"'input_list'")
3495+
.failed())
3496+
return failure();
3497+
3498+
if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
3499+
"'body_graph' arguments", getInputList(),
3500+
"'input_list'")
3501+
.failed())
3502+
return failure();
3503+
3504+
auto bodyYield = cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
3505+
if (errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
3506+
"'body_graph' results", getInputList(),
3507+
"'input_list'")
3508+
.failed())
3509+
return failure();
3510+
3511+
// Condition block output must be a single element tensor with a single bool
3512+
// value.
3513+
auto condYield = cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
3514+
if (condYield.getInputs().size() != 1)
3515+
return emitOpError() << "require 'cond_graph' only have one result";
3516+
3517+
auto condOutType = condYield.getInputs()[0].getType();
3518+
if (errorIfShapeNotSizeOne(*this, condOutType).failed())
3519+
return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
3520+
<< condOutType;
3521+
3522+
if (!getElementTypeOrSelf(condOutType).isInteger(1))
3523+
return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
3524+
<< condOutType;
3525+
3526+
return success();
3527+
}
3528+
34003529
LogicalResult ReverseOp::verify() {
34013530
if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
34023531
/* outType = */ getOutput().getType())

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,42 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
449449
return true;
450450
}
451451

452+
// Preform depth-first search in Tosa IR structure to find the maximum nesting
453+
// depth. Tosa nesting_depth starts at 0 and increase by one each time a new
454+
// nested `region` is encountered.
455+
456+
static int32_t getMaxNestedDepth(Operation *op) {
457+
int32_t depth = 0;
458+
for (Region &region : op->getRegions())
459+
depth = std::max(depth, getMaxNestedDepth(region));
460+
return depth;
461+
}
462+
463+
static int32_t getMaxNestedDepth(Block &block) {
464+
int32_t depth = 0;
465+
for (Operation &op : block.getOperations())
466+
depth = std::max(depth, getMaxNestedDepth(&op));
467+
return depth;
468+
}
469+
470+
static int32_t getMaxNestedDepth(Region &region) {
471+
int32_t depth = 0;
472+
for (Block &block : region.getBlocks())
473+
depth = std::max(depth, getMaxNestedDepth(block));
474+
// Increase the nested depth.
475+
return depth + 1;
476+
}
477+
478+
bool levelCheckMaxNesting(Operation *op) {
479+
int32_t maxNestedDepth = getMaxNestedDepth(op);
480+
if (maxNestedDepth >= tosaLevel.MAX_NESTING) {
481+
op->emitOpError() << "failed level check: " << maxNestedDepth
482+
<< " >= MAX_NESTING";
483+
return false;
484+
}
485+
return true;
486+
}
487+
452488
bool levelCheckListSize(Operation *op) {
453489
if (auto concat = dyn_cast<tosa::ConcatOp>(op)) {
454490
return levelCheckListSize(op, concat.getInput1().size(), "input1");
@@ -750,6 +786,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
750786
return failure();
751787
}
752788

789+
if (isa<tosa::IfOp>(op) || isa<tosa::WhileOp>(op)) {
790+
if (!levelCheckMaxNesting(op)) {
791+
return failure();
792+
}
793+
}
794+
753795
return success();
754796
}
755797

0 commit comments

Comments
 (0)