Skip to content

[mlir][spirv] Add folding for [I|Logical][Not]Equal #74194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
%5 = spirv.IEqual %2, %3 : vector<4xi32>
```
}];

let hasFolder = 1;
}

// -----
Expand All @@ -395,6 +397,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",

```
}];

let hasFolder = 1;
}

// -----
Expand Down Expand Up @@ -501,6 +505,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
%2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
```
}];

let hasFolder = 1;
}

// -----
Expand Down Expand Up @@ -557,7 +563,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
%2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
```
}];
let hasFolder = true;

let hasFolder = 1;
}

// -----
Expand Down
77 changes: 75 additions & 2 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,19 +662,52 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
return Attribute();
}

//===----------------------------------------------------------------------===//
// spirv.LogicalEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
// x == x -> true
if (getOperand1() == getOperand2()) {
auto trueAttr = BoolAttr::get(getContext(), true);
if (isa<IntegerType>(getType()))
return trueAttr;
if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, trueAttr);
}

return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
});
}

//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
if (std::optional<bool> rhs =
getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
// x && false = x
// x != false -> x
if (!rhs.value())
return getOperand1();
}

return Attribute();
// x == x -> false
if (getOperand1() == getOperand2()) {
auto falseAttr = BoolAttr::get(getContext(), false);
if (isa<IntegerType>(getType()))
return falseAttr;
if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, falseAttr);
}

return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
});
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -709,6 +742,46 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}

//===----------------------------------------------------------------------===//
// spirv.IEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
// x == x -> true
if (getOperand1() == getOperand2()) {
auto trueAttr = BoolAttr::get(getContext(), true);
if (isa<IntegerType>(getType()))
return trueAttr;
if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, trueAttr);
}

return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
});
}

//===----------------------------------------------------------------------===//
// spirv.INotEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
// x == x -> false
if (getOperand1() == getOperand2()) {
auto falseAttr = BoolAttr::get(getContext(), false);
if (isa<IntegerType>(getType()))
return falseAttr;
if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, falseAttr);
}

return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
});
}

//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogical
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
// CHECK-LABEL: @logical_equal_scalar
spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
%0 = spirv.LogicalEqual %arg0, %arg0 : i1
%0 = spirv.LogicalEqual %arg0, %arg1 : i1
spirv.Return
}

// CHECK-LABEL: @logical_equal_vector
spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
%0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
%0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1>
spirv.Return
}

Expand All @@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
// CHECK-LABEL: @logical_not_equal_scalar
spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
%0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
%0 = spirv.LogicalNotEqual %arg0, %arg1 : i1
spirv.Return
}

// CHECK-LABEL: @logical_not_equal_vector
spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
%0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
%0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
spirv.Return
}

Expand Down Expand Up @@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
// CHECK-LABEL: @logical_and_scalar
spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.and %{{.*}}, %{{.*}} : i1
%0 = spirv.LogicalAnd %arg0, %arg0 : i1
%0 = spirv.LogicalAnd %arg0, %arg1 : i1
spirv.Return
}

// CHECK-LABEL: @logical_and_vector
spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
%0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
%0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1>
spirv.Return
}

Expand All @@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
// CHECK-LABEL: @logical_or_scalar
spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.or %{{.*}}, %{{.*}} : i1
%0 = spirv.LogicalOr %arg0, %arg0 : i1
%0 = spirv.LogicalOr %arg0, %arg1 : i1
spirv.Return
}

// CHECK-LABEL: @logical_or_vector
spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
%0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
%0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1>
spirv.Return
}
165 changes: 165 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,48 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
spirv.ReturnValue %3 : vector<3xi1>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.LogicalEqual
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @logical_equal_same
func.func @logical_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>

%0 = spirv.LogicalEqual %arg0, %arg0 : i1
%1 = spirv.LogicalEqual %arg1, %arg1 : vector<3xi1>
// CHECK: return %[[CTRUE]], %[[CVTRUE]]
return %0, %1 : i1, vector<3xi1>
}

// CHECK-LABEL: @const_fold_scalar_logical_equal
func.func @const_fold_scalar_logical_equal() -> (i1, i1) {
%true = spirv.Constant true
%false = spirv.Constant false

// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
%0 = spirv.LogicalEqual %true, %false : i1
%1 = spirv.LogicalEqual %false, %false : i1

// CHECK: return %[[CFALSE]], %[[CTRUE]]
return %0, %1 : i1, i1
}

// CHECK-LABEL: @const_fold_vector_logical_equal
func.func @const_fold_vector_logical_equal() -> vector<3xi1> {
%cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
%cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>

// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
%0 = spirv.LogicalEqual %cv0, %cv1 : vector<3xi1>

// CHECK: return %[[RET]]
return %0 : vector<3xi1>
}

// -----

Expand All @@ -1064,6 +1106,43 @@ func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
spirv.ReturnValue %0 : vector<4xi1>
}

// CHECK-LABEL: @logical_not_equal_same
func.func @logical_not_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
%0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
%1 = spirv.LogicalNotEqual %arg1, %arg1 : vector<3xi1>

// CHECK: return %[[CFALSE]], %[[CVFALSE]]
return %0, %1 : i1, vector<3xi1>
}

// CHECK-LABEL: @const_fold_scalar_logical_not_equal
func.func @const_fold_scalar_logical_not_equal() -> (i1, i1) {
%true = spirv.Constant true
%false = spirv.Constant false

// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
%0 = spirv.LogicalNotEqual %true, %false : i1
%1 = spirv.LogicalNotEqual %false, %false : i1

// CHECK: return %[[CTRUE]], %[[CFALSE]]
return %0, %1 : i1, i1
}

// CHECK-LABEL: @const_fold_vector_logical_not_equal
func.func @const_fold_vector_logical_not_equal() -> vector<3xi1> {
%cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
%cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>

// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
%0 = spirv.LogicalNotEqual %cv0, %cv1 : vector<3xi1>

// CHECK: return %[[RET]]
return %0 : vector<3xi1>
}

// -----

func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
Expand Down Expand Up @@ -1139,6 +1218,92 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3

// -----

//===----------------------------------------------------------------------===//
// spirv.IEqual
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @iequal_same
func.func @iequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
%0 = spirv.IEqual %arg0, %arg0 : i32
%1 = spirv.IEqual %arg1, %arg1 : vector<3xi32>

// CHECK: return %[[CTRUE]], %[[CVTRUE]]
return %0, %1 : i1, vector<3xi1>
}

// CHECK-LABEL: @const_fold_scalar_iequal
func.func @const_fold_scalar_iequal() -> (i1, i1) {
%c5 = spirv.Constant 5 : i32
%c6 = spirv.Constant 6 : i32

// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
%0 = spirv.IEqual %c5, %c6 : i32
%1 = spirv.IEqual %c5, %c5 : i32

// CHECK: return %[[CFALSE]], %[[CTRUE]]
return %0, %1 : i1, i1
}

// CHECK-LABEL: @const_fold_vector_iequal
func.func @const_fold_vector_iequal() -> vector<3xi1> {
%cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>

// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
%0 = spirv.IEqual %cv0, %cv1 : vector<3xi32>

// CHECK: return %[[RET]]
return %0 : vector<3xi1>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.INotEqual
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @inotequal_same
func.func @inotequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
%0 = spirv.INotEqual %arg0, %arg0 : i32
%1 = spirv.INotEqual %arg1, %arg1 : vector<3xi32>

// CHECK: return %[[CFALSE]], %[[CVFALSE]]
return %0, %1 : i1, vector<3xi1>
}

// CHECK-LABEL: @const_fold_scalar_inotequal
func.func @const_fold_scalar_inotequal() -> (i1, i1) {
%c5 = spirv.Constant 5 : i32
%c6 = spirv.Constant 6 : i32

// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
%0 = spirv.INotEqual %c5, %c6 : i32
%1 = spirv.INotEqual %c5, %c5 : i32

// CHECK: return %[[CTRUE]], %[[CFALSE]]
return %0, %1 : i1, i1
}

// CHECK-LABEL: @const_fold_vector_inotequal
func.func @const_fold_vector_inotequal() -> vector<3xi1> {
%cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>

// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
%0 = spirv.INotEqual %cv0, %cv1 : vector<3xi32>

// CHECK: return %[[RET]]
return %0 : vector<3xi1>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.LeftShiftLogical
//===----------------------------------------------------------------------===//
Expand Down