@@ -670,22 +670,19 @@ OpFoldResult
670
670
spirv::LogicalEqualOp::fold (spirv::LogicalEqualOp::FoldAdaptor adaptor) {
671
671
// x == x -> true
672
672
if (getOperand1 () == getOperand2 ()) {
673
- auto type = getType ( );
674
- if (isa<IntegerType>(type )) {
675
- return BoolAttr::get ( getContext (), true ) ;
673
+ auto trueAttr = BoolAttr::get ( getContext (), true );
674
+ if (isa<IntegerType>(getType () )) {
675
+ return trueAttr ;
676
676
}
677
- if (isa<VectorType>(type)) {
678
- auto vtType = cast<ShapedType>(type);
679
- auto element = BoolAttr::get (getContext (), true );
680
- return DenseElementsAttr::get (vtType, element);
677
+ if (auto vecTy = dyn_cast<VectorType>(getType ())) {
678
+ return SplatElementsAttr::get (vecTy, trueAttr);
681
679
}
682
680
}
683
681
684
- return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands (),
685
- [](const APInt &a, const APInt &b) {
686
- APInt zero = APInt::getZero (1 );
687
- return a == b ? (zero + 1 ) : zero;
688
- });
682
+ return constFoldBinaryOp<IntegerAttr>(
683
+ adaptor.getOperands (), [](const APInt &a, const APInt &b) {
684
+ return a == b ? APInt::getAllOnes (1 ) : APInt::getZero (1 );
685
+ });
689
686
}
690
687
691
688
// ===----------------------------------------------------------------------===//
@@ -702,22 +699,19 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
702
699
703
700
// x == x -> false
704
701
if (getOperand1 () == getOperand2 ()) {
705
- auto type = getType ( );
706
- if (isa<IntegerType>(type )) {
707
- return BoolAttr::get ( getContext (), false ) ;
702
+ auto falseAttr = BoolAttr::get ( getContext (), false );
703
+ if (isa<IntegerType>(getType () )) {
704
+ return falseAttr ;
708
705
}
709
- if (isa<VectorType>(type)) {
710
- auto vtType = cast<ShapedType>(type);
711
- auto element = BoolAttr::get (getContext (), false );
712
- return DenseElementsAttr::get (vtType, element);
706
+ if (auto vecTy = dyn_cast<VectorType>(getType ())) {
707
+ return SplatElementsAttr::get (vecTy, falseAttr);
713
708
}
714
709
}
715
710
716
- return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands (),
717
- [](const APInt &a, const APInt &b) {
718
- APInt zero = APInt::getZero (1 );
719
- return a == b ? zero : (zero + 1 );
720
- });
711
+ return constFoldBinaryOp<IntegerAttr>(
712
+ adaptor.getOperands (), [](const APInt &a, const APInt &b) {
713
+ return a == b ? APInt::getZero (1 ) : APInt::getAllOnes (1 );
714
+ });
721
715
}
722
716
723
717
// ===----------------------------------------------------------------------===//
@@ -759,22 +753,19 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
759
753
OpFoldResult spirv::IEqualOp::fold (spirv::IEqualOp::FoldAdaptor adaptor) {
760
754
// x == x -> true
761
755
if (getOperand1 () == getOperand2 ()) {
762
- auto type = getType ( );
763
- if (isa<IntegerType>(type )) {
764
- return BoolAttr::get ( getContext (), true ) ;
756
+ auto trueAttr = BoolAttr::get ( getContext (), true );
757
+ if (isa<IntegerType>(getType () )) {
758
+ return trueAttr ;
765
759
}
766
- if (isa<VectorType>(type)) {
767
- auto vtType = cast<ShapedType>(type);
768
- auto element = BoolAttr::get (getContext (), true );
769
- return DenseElementsAttr::get (vtType, element);
760
+ if (auto vecTy = dyn_cast<VectorType>(getType ())) {
761
+ return SplatElementsAttr::get (vecTy, trueAttr);
770
762
}
771
763
}
772
764
773
- return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands (), getType (),
774
- [](const APInt &a, const APInt &b) {
775
- APInt zero = APInt::getZero (1 );
776
- return a == b ? (zero + 1 ) : zero;
777
- });
765
+ return constFoldBinaryOp<IntegerAttr>(
766
+ adaptor.getOperands (), getType (), [](const APInt &a, const APInt &b) {
767
+ return a == b ? APInt::getAllOnes (1 ) : APInt::getZero (1 );
768
+ });
778
769
}
779
770
780
771
// ===----------------------------------------------------------------------===//
@@ -784,22 +775,19 @@ OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
784
775
OpFoldResult spirv::INotEqualOp::fold (spirv::INotEqualOp::FoldAdaptor adaptor) {
785
776
// x == x -> false
786
777
if (getOperand1 () == getOperand2 ()) {
787
- auto type = getType ( );
788
- if (isa<IntegerType>(type )) {
789
- return BoolAttr::get ( getContext (), false ) ;
778
+ auto falseAttr = BoolAttr::get ( getContext (), false );
779
+ if (isa<IntegerType>(getType () )) {
780
+ return falseAttr ;
790
781
}
791
- if (isa<VectorType>(type)) {
792
- auto vtType = cast<ShapedType>(type);
793
- auto element = BoolAttr::get (getContext (), false );
794
- return DenseElementsAttr::get (vtType, element);
782
+ if (auto vecTy = dyn_cast<VectorType>(getType ())) {
783
+ return SplatElementsAttr::get (vecTy, falseAttr);
795
784
}
796
785
}
797
786
798
- return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands (), getType (),
799
- [](const APInt &a, const APInt &b) {
800
- APInt zero = APInt::getZero (1 );
801
- return a == b ? zero : (zero + 1 );
802
- });
787
+ return constFoldBinaryOp<IntegerAttr>(
788
+ adaptor.getOperands (), getType (), [](const APInt &a, const APInt &b) {
789
+ return a == b ? APInt::getZero (1 ) : APInt::getAllOnes (1 );
790
+ });
803
791
}
804
792
805
793
// ===----------------------------------------------------------------------===//
0 commit comments