Skip to content

Commit 990837f

Browse files
authored
[mlir][arith][tensor] Disable index type for bitcast (#121455)
Fixes #121397.
1 parent df3bc54 commit 990837f

File tree

8 files changed

+73
-32
lines changed

8 files changed

+73
-32
lines changed

flang/include/flang/Optimizer/Dialect/FIRTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ def IsBaseBoxTypePred
579579
def fir_BaseBoxType : Type<IsBaseBoxTypePred, "fir.box or fir.class type">;
580580

581581
// Generalized FIR and standard dialect types representing intrinsic types
582-
def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
582+
def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerOrIndexLike.predicate,
583583
AnySignedInteger.predicate, AnyUnsignedInteger.predicate,
584584
fir_IntegerType.predicate, fir_UnsignedType.predicate]>, "any integer">;
585585
def AnyLogicalLike : TypeConstraint<Or<[BoolLike.predicate,

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
5151
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
5252
Arith_BinaryOp<mnemonic, traits #
5353
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
54-
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
55-
Results<(outs SignlessIntegerLike:$result)>;
54+
Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
55+
Results<(outs SignlessIntegerOrIndexLike:$result)>;
5656

5757
// Base class for integer binary operations without undefined behavior.
5858
class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
@@ -155,11 +155,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
155155
Arith_BinaryOp<mnemonic, traits #
156156
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
157157
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
158-
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
158+
Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
159159
DefaultValuedAttr<
160160
Arith_IntegerOverflowAttr,
161161
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
162-
Results<(outs SignlessIntegerLike:$result)> {
162+
Results<(outs SignlessIntegerOrIndexLike:$result)> {
163163

164164
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
165165
attr-dict `:` type($result) }];
@@ -198,7 +198,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
198198
// However, it is necessary to allow arith.constant to return vectors/tensors
199199
// of strings and signed/unsigned integers (for now) as an artefact of
200200
// splitting the Standard dialect.
201-
let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
201+
let results = (outs /*SignlessIntegerOrIndexOrFloatLike*/AnyType:$result);
202202

203203
let extraClassDeclaration = [{
204204
/// Whether the constant op can be constructed with a particular value and
@@ -288,8 +288,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
288288
```
289289
}];
290290

291-
let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
292-
let results = (outs SignlessIntegerLike:$sum, BoolLike:$overflow);
291+
let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
292+
let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
293293
let assemblyFormat = [{
294294
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
295295
}];
@@ -429,8 +429,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
429429
```
430430
}];
431431

432-
let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
433-
let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
432+
let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
433+
let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
434434

435435
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
436436

@@ -472,8 +472,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
472472
```
473473
}];
474474

475-
let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
476-
let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
475+
let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
476+
let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);
477477

478478
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
479479

@@ -1350,7 +1350,7 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {
13501350

13511351
// Index cast can convert between memrefs of signless integers and indices too.
13521352
def IndexCastTypeConstraint : TypeConstraint<Or<[
1353-
SignlessIntegerLike.predicate,
1353+
SignlessIntegerOrIndexLike.predicate,
13541354
MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
13551355
"signless-integer-like or memref of signless-integer">;
13561356

@@ -1392,11 +1392,10 @@ def Arith_IndexCastUIOp
13921392
// BitcastOp
13931393
//===----------------------------------------------------------------------===//
13941394

1395-
// Bitcast can convert between memrefs of signless integers, indices, and
1396-
// floats too.
1395+
// Bitcast can convert between memrefs of signless integers and floats.
13971396
def BitcastTypeConstraint : TypeConstraint<Or<[
13981397
SignlessIntegerOrFloatLike.predicate,
1399-
MemRefOf<[AnySignlessInteger, Index, AnyFloat]>.predicate]>,
1398+
MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
14001399
"signless-integer-or-float-like or memref of signless-integer or float">;
14011400

14021401
def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
@@ -1496,8 +1495,8 @@ def Arith_CmpIOp
14961495
}];
14971496

14981497
let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
1499-
SignlessIntegerLikeOfAnyRank:$lhs,
1500-
SignlessIntegerLikeOfAnyRank:$rhs);
1498+
SignlessIntegerOrIndexLikeOfAnyRank:$lhs,
1499+
SignlessIntegerOrIndexLikeOfAnyRank:$rhs);
15011500

15021501
let hasFolder = 1;
15031502
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ class Math_Op<string mnemonic, list<Trait> traits = []> :
2828
// tensor thereof.
2929
class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
3030
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
31-
let arguments = (ins SignlessIntegerLike:$operand);
32-
let results = (outs SignlessIntegerLike:$result);
31+
let arguments = (ins SignlessIntegerOrIndexLike:$operand);
32+
let results = (outs SignlessIntegerOrIndexLike:$result);
3333

3434
let assemblyFormat = "$operand attr-dict `:` type($result)";
3535
}
@@ -55,8 +55,8 @@ class Math_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
5555
// type, vector or tensor thereof.
5656
class Math_IntegerBinaryOp<string mnemonic, list<Trait> traits = []> :
5757
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
58-
let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
59-
let results = (outs SignlessIntegerLike:$result);
58+
let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
59+
let results = (outs SignlessIntegerOrIndexLike:$result);
6060

6161
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
6262
}
@@ -976,7 +976,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
976976
```
977977
}];
978978

979-
let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs,
979+
let arguments = (ins FloatLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
980980
DefaultValuedAttr<Arith_FastMathAttr,
981981
"::mlir::arith::FastMathFlags::none">:$fastmath);
982982
let results = (outs FloatLike:$result);

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@ def Tensor_BitcastOp : Tensor_Op<"bitcast", [
7575
```
7676
}];
7777

78-
let arguments = (ins AnyTensor:$source);
79-
let results = (outs AnyTensor:$dest);
78+
let arguments = (ins TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
79+
AnySignedInteger, AnyFloat]>:$source);
80+
let results = (outs TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
81+
AnySignedInteger, AnyFloat]>:$dest);
8082
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
8183

8284
let hasCanonicalizer = 1;

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -908,21 +908,31 @@ def BoolLike : TypeOrContainer<I1, "bool-like">;
908908

909909
def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;
910910

911+
// Type constraint for signless-integer-like types: signless integers,
912+
// vectors of signless integers or tensors of signless integers.
913+
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
914+
AnySignlessInteger, "signless-integer">;
915+
911916
// Type constraint for signless-integer-like types: signless integers, indices,
912917
// vectors of signless integers or indices, tensors of signless integers.
913-
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
918+
def SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
914919
AnySignlessIntegerOrIndex, "signless-integer-like">;
915920

916-
def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
921+
def SignlessIntegerOrIndexLikeOfAnyRank : TypeOrContainerOfAnyRank<
917922
AnySignlessIntegerOrIndex,
918923
"signless-integer-like">;
919924

920925
// Type constraint for float-like types: floats, vectors or tensors thereof.
921926
def FloatLike : TypeOrContainer<AnyFloat, "floating-point-like">;
922927

923-
// Type constraint for signless-integer-like or float-like types.
928+
// Type constraint for signless-integer-or-index-like or float-like types.
924929
def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
925930
SignlessIntegerLike.predicate, FloatLike.predicate]>,
926931
"signless-integer-like or floating-point-like">;
927932

933+
// Type constraint for signless-integer-or-index-like or float-like types.
934+
def SignlessIntegerOrIndexOrFloatLike : TypeConstraint<Or<[
935+
SignlessIntegerOrIndexLike.predicate, FloatLike.predicate]>,
936+
"signless-integer-or-index-like or floating-point-like">;
937+
928938
#endif // COMMON_TYPE_CONSTRAINTS_TD

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,10 +1740,8 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
17401740
if (!areValidCastInputsAndOutputs(inputs, outputs))
17411741
return false;
17421742

1743-
auto srcType =
1744-
getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1745-
auto dstType =
1746-
getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1743+
auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
1744+
auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
17471745
if (!srcType || !dstType)
17481746
return false;
17491747

mlir/test/Dialect/Arith/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,3 +853,19 @@ func.func @select_tensor_encoding(
853853
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
854854
return %0 : tensor<8xi32, "foo">
855855
}
856+
857+
// -----
858+
859+
func.func @bitcast_index_0(%arg0 : i64) -> index {
860+
// expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
861+
%0 = arith.bitcast %arg0 : i64 to index
862+
return %0 : index
863+
}
864+
865+
// -----
866+
867+
func.func @bitcast_index_1(%arg0 : index) -> i64 {
868+
// expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
869+
%0 = arith.bitcast %arg0 : index to i64
870+
return %0 : i64
871+
}

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,3 +807,19 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
807807
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
808808
return %0 : tensor<?x?xf32>
809809
}
810+
811+
// -----
812+
813+
func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
814+
// expected-error @+1 {{'tensor.bitcast' op result #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
815+
%0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
816+
return %0 : tensor<?xindex>
817+
}
818+
819+
// -----
820+
821+
func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
822+
// expected-error @+1 {{'tensor.bitcast' op operand #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
823+
%0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
824+
return %0 : tensor<?xi64>
825+
}

0 commit comments

Comments
 (0)