@@ -572,6 +572,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
572
572
return success ();
573
573
}
574
574
575
+ // Verify whether same type and shape of the given two types.
576
+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, Type type1,
577
+ StringRef name1, Type type2,
578
+ StringRef name2) {
579
+ auto shapeType1 = dyn_cast<ShapedType>(type1);
580
+ auto shapeType2 = dyn_cast<ShapedType>(type2);
581
+ if (!shapeType1 || !shapeType2)
582
+ return failure ();
583
+
584
+ auto elemType1 = shapeType1.getElementType ();
585
+ auto elemType2 = shapeType2.getElementType ();
586
+ if (elemType1 != elemType2)
587
+ return op->emitOpError ()
588
+ << " require same element type for " << name1 << " (" << elemType1
589
+ << " ) and " << name2 << " (" << elemType2 << " )" ;
590
+
591
+ if (failed (verifyCompatibleShape (type1, type2)))
592
+ return op->emitOpError ()
593
+ << " require same shapes for " << name1 << " (" << type1 << " ) and "
594
+ << name2 << " (" << type2 << " )" ;
595
+
596
+ return success ();
597
+ }
598
+
599
+ // Verify whether same length, type, and shape of the given two tensor lists.
600
+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, ValueRange list1,
601
+ StringRef name1,
602
+ ValueRange list2,
603
+ StringRef name2) {
604
+ if (list1.size () != list2.size ())
605
+ return op->emitOpError ()
606
+ << " require same number of values in " << name1 << " ("
607
+ << list1.size () << " ) and " << name2 << " (" << list2.size () << " )" ;
608
+
609
+ for (auto [type1, type2] :
610
+ llvm::zip_equal (list1.getTypes (), list2.getTypes ())) {
611
+ if (errorIfTypeOrShapeMismatch (op, type1, name1, type2, name2).failed ())
612
+ return failure ();
613
+ }
614
+
615
+ return success ();
616
+ }
617
+
618
+ static inline LogicalResult errorIfShapeNotSizeOne (Operation *op, Type type) {
619
+ ShapeAdaptor shapeAdaptor (type);
620
+ if (!shapeAdaptor.hasRank () || !shapeAdaptor.hasStaticShape ())
621
+ return success ();
622
+
623
+ return shapeAdaptor.getNumElements () == 1 ? success () : failure ();
624
+ }
625
+
575
626
// verify that inType and outType have same element types
576
627
template <typename T>
577
628
static LogicalResult verifySameElementTypes (T op, Type inType, Type outType) {
@@ -3397,6 +3448,84 @@ void IfOp::print(OpAsmPrinter &p) {
3397
3448
p.printOptionalAttrDict ((*this )->getAttrs ());
3398
3449
}
3399
3450
3451
+ LogicalResult IfOp::verify () {
3452
+ if (errorIfTypeOrShapeMismatch (*this , getThenGraph ().front ().getArguments (),
3453
+ " 'then_graph' arguments" , getInputList (),
3454
+ " 'input_list'" )
3455
+ .failed ())
3456
+ return failure ();
3457
+
3458
+ if (errorIfTypeOrShapeMismatch (*this , getElseGraph ().front ().getArguments (),
3459
+ " 'else_graph' arguments" , getInputList (),
3460
+ " 'input_list'" )
3461
+ .failed ())
3462
+ return failure ();
3463
+
3464
+ auto thenYield = cast<tosa::YieldOp>(getThenGraph ().front ().getTerminator ());
3465
+ if (errorIfTypeOrShapeMismatch (*this , thenYield.getInputs (),
3466
+ " 'then_graph' results" , getOutputList (),
3467
+ " 'output_list'" )
3468
+ .failed ())
3469
+ return failure ();
3470
+
3471
+ auto elseYield = cast<tosa::YieldOp>(getElseGraph ().front ().getTerminator ());
3472
+ if (errorIfTypeOrShapeMismatch (*this , elseYield.getInputs (),
3473
+ " 'else_graph' results" , getOutputList (),
3474
+ " 'output_list'" )
3475
+ .failed ())
3476
+ return failure ();
3477
+
3478
+ auto condType = getCondition ().getType ();
3479
+ if (errorIfShapeNotSizeOne (*this , condType).failed ())
3480
+ return emitOpError () << " 'condition' must be a size 1 tensor, got "
3481
+ << condType;
3482
+
3483
+ return success ();
3484
+ }
3485
+
3486
+ LogicalResult WhileOp::verify () {
3487
+ if (errorIfTypeOrShapeMismatch (*this , getInputList (), " 'input_list'" ,
3488
+ getOutputList (), " 'output_list'" )
3489
+ .failed ())
3490
+ return failure ();
3491
+
3492
+ if (errorIfTypeOrShapeMismatch (*this , getCondGraph ().front ().getArguments (),
3493
+ " 'cond_graph' arguments" , getInputList (),
3494
+ " 'input_list'" )
3495
+ .failed ())
3496
+ return failure ();
3497
+
3498
+ if (errorIfTypeOrShapeMismatch (*this , getBodyGraph ().front ().getArguments (),
3499
+ " 'body_graph' arguments" , getInputList (),
3500
+ " 'input_list'" )
3501
+ .failed ())
3502
+ return failure ();
3503
+
3504
+ auto bodyYield = cast<tosa::YieldOp>(getBodyGraph ().front ().getTerminator ());
3505
+ if (errorIfTypeOrShapeMismatch (*this , bodyYield.getInputs (),
3506
+ " 'body_graph' results" , getInputList (),
3507
+ " 'input_list'" )
3508
+ .failed ())
3509
+ return failure ();
3510
+
3511
+ // Condition block output must be a single element tensor with a single bool
3512
+ // value.
3513
+ auto condYield = cast<tosa::YieldOp>(getCondGraph ().front ().getTerminator ());
3514
+ if (condYield.getInputs ().size () != 1 )
3515
+ return emitOpError () << " require 'cond_graph' only have one result" ;
3516
+
3517
+ auto condOutType = condYield.getInputs ()[0 ].getType ();
3518
+ if (errorIfShapeNotSizeOne (*this , condOutType).failed ())
3519
+ return emitOpError () << " 'cond_graph' result must be a size 1 tensor, got "
3520
+ << condOutType;
3521
+
3522
+ if (!getElementTypeOrSelf (condOutType).isInteger (1 ))
3523
+ return emitOpError () << " 'cond_graph' result must be a boolean tensor, got "
3524
+ << condOutType;
3525
+
3526
+ return success ();
3527
+ }
3528
+
3400
3529
LogicalResult ReverseOp::verify () {
3401
3530
if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
3402
3531
/* outType = */ getOutput ().getType ())
0 commit comments