Skip to content

Commit b1904a5

Browse files
[mlir][complex] Allow integer element types in complex.constant ops
The op used to support only float element types. This was inconsistent with `ConstantOp::isBuildableWith`, which allows integer element types. The complex type allows any float/integer element type. Note: The other complex dialect ops do not support non-float element types yet. The purpose of this change to fix `Tensor/canonicalize.mlir`, which is currently failing when verifying the IR after each pattern application (llvm#74270). ``` within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: error: 'complex.constant' op result #0 must be complex type with floating-point elements, but got 'complex<i32>' %complex1 = tensor.extract %c1[] : tensor<complex<i32>> ^ within split at mlir/test/Dialect/Tensor/canonicalize.mlir:231 offset :8:15: note: see current operation: %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32> "func.func"() <{function_type = () -> tensor<3xcomplex<i32>>, sym_name = "extract_from_elements_complex_i"}> ({ %0 = "complex.constant"() <{value = [1 : i32, 2 : i32]}> : () -> complex<i32> %1 = "arith.constant"() <{value = dense<(3,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>> %2 = "arith.constant"() <{value = dense<(1,2)> : tensor<complex<i32>>}> : () -> tensor<complex<i32>> %3 = "tensor.extract"(%1) : (tensor<complex<i32>>) -> complex<i32> %4 = "tensor.from_elements"(%0, %3, %0) : (complex<i32>, complex<i32>, complex<i32>) -> tensor<3xcomplex<i32>> "func.return"(%4) : (tensor<3xcomplex<i32>>) -> () }) : () -> () ```
1 parent c4cebe5 commit b1904a5

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def ConstantOp : Complex_Op<"constant", [
145145
}];
146146

147147
let arguments = (ins ArrayAttr:$value);
148-
let results = (outs Complex<AnyFloat>:$complex);
148+
let results = (outs AnyComplex:$complex);
149149

150150
let assemblyFormat = "$value attr-dict `:` type($complex)";
151151
let hasFolder = 1;

mlir/lib/Dialect/Complex/IR/ComplexOps.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ LogicalResult ConstantOp::verify() {
5858
}
5959

6060
auto complexEltTy = getType().getElementType();
61-
auto re = llvm::dyn_cast<FloatAttr>(arrayAttr[0]);
62-
auto im = llvm::dyn_cast<FloatAttr>(arrayAttr[1]);
63-
if (!re || !im)
64-
return emitOpError("requires attribute's elements to be float attributes");
61+
if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
62+
!isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
63+
return emitOpError(
64+
"requires attribute's elements to be float or integer attributes");
65+
auto re = llvm::dyn_cast<TypedAttr>(arrayAttr[0]);
66+
auto im = llvm::dyn_cast<TypedAttr>(arrayAttr[1]);
6567
if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
6668
return emitOpError()
6769
<< "requires attribute's element types (" << re.getType() << ", "

mlir/test/Dialect/Complex/ops.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ func.func @ops(%f: f32) {
1111
// CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
1212
%cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>
1313

14+
// CHECK: complex.constant [true, false] : complex<i1>
15+
%cst_i1 = complex.constant [1 : i1, 0 : i1] : complex<i1>
16+
1417
// CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
1518
%complex = complex.create %f, %f : complex<f32>
1619

0 commit comments

Comments
 (0)