-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][spirv] Add folding for [I|Logical][Not]Equal #74194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
1ef8988
to
73ff53c
Compare
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Finn Plummer (inbelic) ChangesFull diff: https://github.com/llvm/llvm-project/pull/74194.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
index cf38c15d20dc3..0053cd5fc9448 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
@@ -473,6 +473,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -506,6 +508,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -644,6 +648,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
%2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
```
}];
+
+ let hasFolder = 1;
}
// -----
@@ -713,7 +719,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
%2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
```
}];
- let hasFolder = true;
+
+ let hasFolder = 1;
}
// -----
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
index 9acd982dc95af..16efe8797f4a3 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -309,6 +309,32 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult
+spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
+ // x == x -> true
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), true);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), true);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? (zero + 1) : zero;
+ });
+}
+
//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//
@@ -316,12 +342,29 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
if (std::optional<bool> rhs =
getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
- // x && false = x
+ // x != false -> x
if (!rhs.value())
return getOperand1();
}
- return Attribute();
+ // x == x -> false
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), false);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), false);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? zero : (zero + 1);
+ });
}
//===----------------------------------------------------------------------===//
@@ -356,6 +399,56 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}
+//===----------------------------------------------------------------------===//
+// spirv.IEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
+ // x == x -> true
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), true);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), true);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? (zero + 1) : zero;
+ });
+}
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqualOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
+ // x == x -> false
+ if (getOperand1() == getOperand2()) {
+ auto type = getType();
+ if (isa<IntegerType>(type)) {
+ return BoolAttr::get(getContext(), false);
+ }
+ if (isa<VectorType>(type)) {
+ auto vtType = cast<ShapedType>(type);
+ auto element = BoolAttr::get(getContext(), false);
+ return DenseElementsAttr::get(vtType, element);
+ }
+ }
+
+ return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), getType(),
+ [](const APInt &a, const APInt &b) {
+ APInt zero = APInt::getZero(1);
+ return a == b ? zero : (zero + 1);
+ });
+}
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
index 6d93480d3ed14..aab2dce980ca7 100644
--- a/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
@@ -7,14 +7,14 @@
// CHECK-LABEL: @logical_equal_scalar
spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
- %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+ %0 = spirv.LogicalEqual %arg0, %arg1 : i1
spirv.Return
}
// CHECK-LABEL: @logical_equal_vector
spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
- %0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
+ %0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1>
spirv.Return
}
@@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
// CHECK-LABEL: @logical_not_equal_scalar
spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
- %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+ %0 = spirv.LogicalNotEqual %arg0, %arg1 : i1
spirv.Return
}
// CHECK-LABEL: @logical_not_equal_vector
spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
- %0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
+ %0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
spirv.Return
}
@@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
// CHECK-LABEL: @logical_and_scalar
spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.and %{{.*}}, %{{.*}} : i1
- %0 = spirv.LogicalAnd %arg0, %arg0 : i1
+ %0 = spirv.LogicalAnd %arg0, %arg1 : i1
spirv.Return
}
// CHECK-LABEL: @logical_and_vector
spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
- %0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
+ %0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1>
spirv.Return
}
@@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
// CHECK-LABEL: @logical_or_scalar
spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.or %{{.*}}, %{{.*}} : i1
- %0 = spirv.LogicalOr %arg0, %arg0 : i1
+ %0 = spirv.LogicalOr %arg0, %arg1 : i1
spirv.Return
}
// CHECK-LABEL: @logical_or_vector
spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
- %0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
+ %0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1>
spirv.Return
}
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
index 0200805a44439..7a8e262db266a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -569,6 +569,48 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
spirv.ReturnValue %3 : vector<3xi1>
}
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.LogicalEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @logical_equal_same
+func.func @logical_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+
+ %0 = spirv.LogicalEqual %arg0, %arg0 : i1
+ %1 = spirv.LogicalEqual %arg1, %arg1 : vector<3xi1>
+ // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_equal
+func.func @const_fold_scalar_logical_equal() -> (i1, i1) {
+ %true = spirv.Constant true
+ %false = spirv.Constant false
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.LogicalEqual %true, %false : i1
+ %1 = spirv.LogicalEqual %false, %false : i1
+
+ // CHECK: return %[[CFALSE]], %[[CTRUE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_equal
+func.func @const_fold_vector_logical_equal() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+ %cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
+ %0 = spirv.LogicalEqual %cv0, %cv1 : vector<3xi1>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
// -----
@@ -585,6 +627,43 @@ func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
spirv.ReturnValue %0 : vector<4xi1>
}
+// CHECK-LABEL: @logical_not_equal_same
+func.func @logical_not_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+ %0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
+ %1 = spirv.LogicalNotEqual %arg1, %arg1 : vector<3xi1>
+
+ // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_logical_not_equal
+func.func @const_fold_scalar_logical_not_equal() -> (i1, i1) {
+ %true = spirv.Constant true
+ %false = spirv.Constant false
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.LogicalNotEqual %true, %false : i1
+ %1 = spirv.LogicalNotEqual %false, %false : i1
+
+ // CHECK: return %[[CTRUE]], %[[CFALSE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_logical_not_equal
+func.func @const_fold_vector_logical_not_equal() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
+ %cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
+ %0 = spirv.LogicalNotEqual %cv0, %cv1 : vector<3xi1>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
// -----
func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
@@ -660,6 +739,92 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3
// -----
+//===----------------------------------------------------------------------===//
+// spirv.IEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @iequal_same
+func.func @iequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
+ %0 = spirv.IEqual %arg0, %arg0 : i32
+ %1 = spirv.IEqual %arg1, %arg1 : vector<3xi32>
+
+ // CHECK: return %[[CTRUE]], %[[CVTRUE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_iequal
+func.func @const_fold_scalar_iequal() -> (i1, i1) {
+ %c5 = spirv.Constant 5 : i32
+ %c6 = spirv.Constant 6 : i32
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.IEqual %c5, %c6 : i32
+ %1 = spirv.IEqual %c5, %c5 : i32
+
+ // CHECK: return %[[CFALSE]], %[[CTRUE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_iequal
+func.func @const_fold_vector_iequal() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+ %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
+ %0 = spirv.IEqual %cv0, %cv1 : vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.INotEqual
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @inotequal_same
+func.func @inotequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ // CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
+ %0 = spirv.INotEqual %arg0, %arg0 : i32
+ %1 = spirv.INotEqual %arg1, %arg1 : vector<3xi32>
+
+ // CHECK: return %[[CFALSE]], %[[CVFALSE]]
+ return %0, %1 : i1, vector<3xi1>
+}
+
+// CHECK-LABEL: @const_fold_scalar_inotequal
+func.func @const_fold_scalar_inotequal() -> (i1, i1) {
+ %c5 = spirv.Constant 5 : i32
+ %c6 = spirv.Constant 6 : i32
+
+ // CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
+ // CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
+ %0 = spirv.INotEqual %c5, %c6 : i32
+ %1 = spirv.INotEqual %c5, %c5 : i32
+
+ // CHECK: return %[[CTRUE]], %[[CFALSE]]
+ return %0, %1 : i1, i1
+}
+
+// CHECK-LABEL: @const_fold_vector_inotequal
+func.func @const_fold_vector_inotequal() -> vector<3xi1> {
+ %cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
+ %cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>
+
+ // CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
+ %0 = spirv.INotEqual %cv0, %cv1 : vector<3xi32>
+
+ // CHECK: return %[[RET]]
+ return %0 : vector<3xi1>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.selection
//===----------------------------------------------------------------------===//
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, just some coding style suggestions here
cb66d1e
to
f7f5ea0
Compare
Nice, thanks. Let me know if it looks good and I can squash and try my new commit privileges :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just one cosmetic suggestion for single-line if statements. Feel free to change and merge without asking for re-approval.
if (isa<IntegerType>(getType())) { | ||
return trueAttr; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: we don't need braces around single-line if statements. also elsewhere
Add missing constant propogation folder for [I|Logical][N]Eq Implement additional folding when lhs == rhs for all ops. As well as, fix test cases in logical-ops-to-llvm that failed due to introduced folding. This helps for readability of lowered code into SPIR-V. Part of work for llvm#70704
- fix coding style
f7f5ea0
to
250c692
Compare
No description provided.