Skip to content

Commit 880f50c

Browse files
committed
review comments:
- fix coding style
1 parent d7d3026 commit 880f50c

File tree

1 file changed

+36
-48
lines changed

1 file changed

+36
-48
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 36 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -670,22 +670,19 @@ OpFoldResult
670670
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
671671
// x == x -> true
672672
if (getOperand1() == getOperand2()) {
673-
auto type = getType();
674-
if (isa<IntegerType>(type)) {
675-
return BoolAttr::get(getContext(), true);
673+
auto trueAttr = BoolAttr::get(getContext(), true);
674+
if (isa<IntegerType>(getType())) {
675+
return trueAttr;
676676
}
677-
if (isa<VectorType>(type)) {
678-
auto vtType = cast<ShapedType>(type);
679-
auto element = BoolAttr::get(getContext(), true);
680-
return DenseElementsAttr::get(vtType, element);
677+
if (auto vecTy = dyn_cast<VectorType>(getType())) {
678+
return SplatElementsAttr::get(vecTy, trueAttr);
681679
}
682680
}
683681

684-
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
685-
[](const APInt &a, const APInt &b) {
686-
APInt zero = APInt::getZero(1);
687-
return a == b ? (zero + 1) : zero;
688-
});
682+
return constFoldBinaryOp<IntegerAttr>(
683+
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
684+
return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
685+
});
689686
}
690687

691688
//===----------------------------------------------------------------------===//
@@ -702,22 +699,19 @@ OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
702699

703700
// x == x -> false
704701
if (getOperand1() == getOperand2()) {
705-
auto type = getType();
706-
if (isa<IntegerType>(type)) {
707-
return BoolAttr::get(getContext(), false);
702+
auto falseAttr = BoolAttr::get(getContext(), false);
703+
if (isa<IntegerType>(getType())) {
704+
return falseAttr;
708705
}
709-
if (isa<VectorType>(type)) {
710-
auto vtType = cast<ShapedType>(type);
711-
auto element = BoolAttr::get(getContext(), false);
712-
return DenseElementsAttr::get(vtType, element);
706+
if (auto vecTy = dyn_cast<VectorType>(getType())) {
707+
return SplatElementsAttr::get(vecTy, falseAttr);
713708
}
714709
}
715710

716-
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
717-
[](const APInt &a, const APInt &b) {
718-
APInt zero = APInt::getZero(1);
719-
return a == b ? zero : (zero + 1);
720-
});
711+
return constFoldBinaryOp<IntegerAttr>(
712+
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
713+
return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
714+
});
721715
}
722716

723717
//===----------------------------------------------------------------------===//
@@ -759,22 +753,19 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
759753
OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
760754
// x == x -> true
761755
if (getOperand1() == getOperand2()) {
762-
auto type = getType();
763-
if (isa<IntegerType>(type)) {
764-
return BoolAttr::get(getContext(), true);
756+
auto trueAttr = BoolAttr::get(getContext(), true);
757+
if (isa<IntegerType>(getType())) {
758+
return trueAttr;
765759
}
766-
if (isa<VectorType>(type)) {
767-
auto vtType = cast<ShapedType>(type);
768-
auto element = BoolAttr::get(getContext(), true);
769-
return DenseElementsAttr::get(vtType, element);
760+
if (auto vecTy = dyn_cast<VectorType>(getType())) {
761+
return SplatElementsAttr::get(vecTy, trueAttr);
770762
}
771763
}
772764

773-
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
774-
[](const APInt &a, const APInt &b) {
775-
APInt zero = APInt::getZero(1);
776-
return a == b ? (zero + 1) : zero;
777-
});
765+
return constFoldBinaryOp<IntegerAttr>(
766+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
767+
return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
768+
});
778769
}
779770

780771
//===----------------------------------------------------------------------===//
@@ -784,22 +775,19 @@ OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
784775
OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
785776
// x == x -> false
786777
if (getOperand1() == getOperand2()) {
787-
auto type = getType();
788-
if (isa<IntegerType>(type)) {
789-
return BoolAttr::get(getContext(), false);
778+
auto falseAttr = BoolAttr::get(getContext(), false);
779+
if (isa<IntegerType>(getType())) {
780+
return falseAttr;
790781
}
791-
if (isa<VectorType>(type)) {
792-
auto vtType = cast<ShapedType>(type);
793-
auto element = BoolAttr::get(getContext(), false);
794-
return DenseElementsAttr::get(vtType, element);
782+
if (auto vecTy = dyn_cast<VectorType>(getType())) {
783+
return SplatElementsAttr::get(vecTy, falseAttr);
795784
}
796785
}
797786

798-
return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
799-
[](const APInt &a, const APInt &b) {
800-
APInt zero = APInt::getZero(1);
801-
return a == b ? zero : (zero + 1);
802-
});
787+
return constFoldBinaryOp<IntegerAttr>(
788+
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
789+
return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
790+
});
803791
}
804792

805793
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)