@@ -562,6 +562,57 @@ static LogicalResult verifyConvOpErrorIf(T op) {
562
562
return success ();
563
563
}
564
564
565
+ // Verify whether same type and shape of the given two types.
566
+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, Type type1,
567
+ StringRef name1, Type type2,
568
+ StringRef name2) {
569
+ auto shapeType1 = dyn_cast<ShapedType>(type1);
570
+ auto shapeType2 = dyn_cast<ShapedType>(type2);
571
+ if (!shapeType1 || !shapeType2)
572
+ return failure ();
573
+
574
+ auto elemType1 = shapeType1.getElementType ();
575
+ auto elemType2 = shapeType2.getElementType ();
576
+ if (elemType1 != elemType2)
577
+ return op->emitOpError ()
578
+ << " require same element type for " << name1 << " (" << elemType1
579
+ << " ) and " << name2 << " (" << elemType2 << " )" ;
580
+
581
+ if (failed (verifyCompatibleShape (type1, type2)))
582
+ return op->emitOpError ()
583
+ << " require same shapes for " << name1 << " (" << type1 << " ) and "
584
+ << name2 << " (" << type2 << " )" ;
585
+
586
+ return success ();
587
+ }
588
+
589
+ // Verify whether same length, type, and shape of the given two tensor lists.
590
+ static LogicalResult errorIfTypeOrShapeMismatch (Operation *op, ValueRange list1,
591
+ StringRef name1,
592
+ ValueRange list2,
593
+ StringRef name2) {
594
+ if (list1.size () != list2.size ())
595
+ return op->emitOpError ()
596
+ << " require same number of values in " << name1 << " ("
597
+ << list1.size () << " ) and " << name2 << " (" << list2.size () << " )" ;
598
+
599
+ for (auto [type1, type2] :
600
+ llvm::zip_equal (list1.getTypes (), list2.getTypes ())) {
601
+ if (errorIfTypeOrShapeMismatch (op, type1, name1, type2, name2).failed ())
602
+ return failure ();
603
+ }
604
+
605
+ return success ();
606
+ }
607
+
608
+ static inline LogicalResult errorIfShapeNotSizeOne (Operation *op, Type type) {
609
+ ShapeAdaptor shapeAdaptor (type);
610
+ if (!shapeAdaptor.hasRank () || !shapeAdaptor.hasStaticShape ())
611
+ return success ();
612
+
613
+ return shapeAdaptor.getNumElements () == 1 ? success () : failure ();
614
+ }
615
+
565
616
// verify that inType and outType have same element types
566
617
template <typename T>
567
618
static LogicalResult verifySameElementTypes (T op, Type inType, Type outType) {
@@ -3437,6 +3488,84 @@ void IfOp::print(OpAsmPrinter &p) {
3437
3488
p.printOptionalAttrDict ((*this )->getAttrs ());
3438
3489
}
3439
3490
3491
+ LogicalResult IfOp::verify () {
3492
+ if (errorIfTypeOrShapeMismatch (*this , getThenGraph ().front ().getArguments (),
3493
+ " 'then_graph' arguments" , getInputList (),
3494
+ " 'input_list'" )
3495
+ .failed ())
3496
+ return failure ();
3497
+
3498
+ if (errorIfTypeOrShapeMismatch (*this , getElseGraph ().front ().getArguments (),
3499
+ " 'else_graph' arguments" , getInputList (),
3500
+ " 'input_list'" )
3501
+ .failed ())
3502
+ return failure ();
3503
+
3504
+ auto thenYield = cast<tosa::YieldOp>(getThenGraph ().front ().getTerminator ());
3505
+ if (errorIfTypeOrShapeMismatch (*this , thenYield.getInputs (),
3506
+ " 'then_graph' results" , getOutputList (),
3507
+ " 'output_list'" )
3508
+ .failed ())
3509
+ return failure ();
3510
+
3511
+ auto elseYield = cast<tosa::YieldOp>(getElseGraph ().front ().getTerminator ());
3512
+ if (errorIfTypeOrShapeMismatch (*this , elseYield.getInputs (),
3513
+ " 'else_graph' results" , getOutputList (),
3514
+ " 'output_list'" )
3515
+ .failed ())
3516
+ return failure ();
3517
+
3518
+ auto condType = getCondition ().getType ();
3519
+ if (errorIfShapeNotSizeOne (*this , condType).failed ())
3520
+ return emitOpError () << " 'condition' must be a size 1 tensor, got "
3521
+ << condType;
3522
+
3523
+ return success ();
3524
+ }
3525
+
3526
+ LogicalResult WhileOp::verify () {
3527
+ if (errorIfTypeOrShapeMismatch (*this , getInputList (), " 'input_list'" ,
3528
+ getOutputList (), " 'output_list'" )
3529
+ .failed ())
3530
+ return failure ();
3531
+
3532
+ if (errorIfTypeOrShapeMismatch (*this , getCondGraph ().front ().getArguments (),
3533
+ " 'cond_graph' arguments" , getInputList (),
3534
+ " 'input_list'" )
3535
+ .failed ())
3536
+ return failure ();
3537
+
3538
+ if (errorIfTypeOrShapeMismatch (*this , getBodyGraph ().front ().getArguments (),
3539
+ " 'body_graph' arguments" , getInputList (),
3540
+ " 'input_list'" )
3541
+ .failed ())
3542
+ return failure ();
3543
+
3544
+ auto bodyYield = cast<tosa::YieldOp>(getBodyGraph ().front ().getTerminator ());
3545
+ if (errorIfTypeOrShapeMismatch (*this , bodyYield.getInputs (),
3546
+ " 'body_graph' results" , getInputList (),
3547
+ " 'input_list'" )
3548
+ .failed ())
3549
+ return failure ();
3550
+
3551
+ // Condition block output must be a single element tensor with a single bool
3552
+ // value.
3553
+ auto condYield = cast<tosa::YieldOp>(getCondGraph ().front ().getTerminator ());
3554
+ if (condYield.getInputs ().size () != 1 )
3555
+ return emitOpError () << " require 'cond_graph' only have one result" ;
3556
+
3557
+ auto condOutType = condYield.getInputs ()[0 ].getType ();
3558
+ if (errorIfShapeNotSizeOne (*this , condOutType).failed ())
3559
+ return emitOpError () << " 'cond_graph' result must be a size 1 tensor, got "
3560
+ << condOutType;
3561
+
3562
+ if (!getElementTypeOrSelf (condOutType).isInteger (1 ))
3563
+ return emitOpError () << " 'cond_graph' result must be a boolean tensor, got "
3564
+ << condOutType;
3565
+
3566
+ return success ();
3567
+ }
3568
+
3440
3569
LogicalResult ReverseOp::verify () {
3441
3570
if (verifySameElementTypes (*this , /* inType = */ getInput1 ().getType (),
3442
3571
/* outType = */ getOutput ().getType ())
0 commit comments