Skip to content

[mlir][arith][tensor] Disable index type for bitcast #121455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Dialect/FIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def IsBaseBoxTypePred
def fir_BaseBoxType : Type<IsBaseBoxTypePred, "fir.box or fir.class type">;

// Generalized FIR and standard dialect types representing intrinsic types
def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerOrIndexLike.predicate,
AnySignedInteger.predicate, AnyUnsignedInteger.predicate,
fir_IntegerType.predicate, fir_UnsignedType.predicate]>, "any integer">;
def AnyLogicalLike : TypeConstraint<Or<[BoolLike.predicate,
Expand Down
33 changes: 16 additions & 17 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class Arith_BinaryOp<string mnemonic, list<Trait> traits = []> :
class Arith_IntBinaryOp<string mnemonic, list<Trait> traits = []> :
Arith_BinaryOp<mnemonic, traits #
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>,
Results<(outs SignlessIntegerLike:$result)>;
Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs)>,
Results<(outs SignlessIntegerOrIndexLike:$result)>;

// Base class for integer binary operations without undefined behavior.
class Arith_TotalIntBinaryOp<string mnemonic, list<Trait> traits = []> :
Expand Down Expand Up @@ -155,11 +155,11 @@ class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = [
Arith_BinaryOp<mnemonic, traits #
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
Arguments<(ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
DefaultValuedAttr<
Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
Results<(outs SignlessIntegerLike:$result)> {
Results<(outs SignlessIntegerOrIndexLike:$result)> {

let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
attr-dict `:` type($result) }];
Expand Down Expand Up @@ -198,7 +198,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
// However, it is necessary to allow arith.constant to return vectors/tensors
// of strings and signed/unsigned integers (for now) as an artefact of
// splitting the Standard dialect.
let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
let results = (outs /*SignlessIntegerOrIndexOrFloatLike*/AnyType:$result);

let extraClassDeclaration = [{
/// Whether the constant op can be constructed with a particular value and
Expand Down Expand Up @@ -288,8 +288,8 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
```
}];

let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
let results = (outs SignlessIntegerLike:$sum, BoolLike:$overflow);
let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
let results = (outs SignlessIntegerOrIndexLike:$sum, BoolLike:$overflow);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($overflow)
}];
Expand Down Expand Up @@ -429,8 +429,8 @@ def Arith_MulSIExtendedOp : Arith_Op<"mulsi_extended", [Pure, Commutative,
```
}];

let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";

Expand Down Expand Up @@ -472,8 +472,8 @@ def Arith_MulUIExtendedOp : Arith_Op<"mului_extended", [Pure, Commutative,
```
}];

let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
let results = (outs SignlessIntegerLike:$low, SignlessIntegerLike:$high);
let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
let results = (outs SignlessIntegerOrIndexLike:$low, SignlessIntegerOrIndexLike:$high);

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";

Expand Down Expand Up @@ -1350,7 +1350,7 @@ def Arith_FPToSIOp : Arith_FToICastOp<"fptosi"> {

// Index cast can convert between memrefs of signless integers and indices too.
def IndexCastTypeConstraint : TypeConstraint<Or<[
SignlessIntegerLike.predicate,
SignlessIntegerOrIndexLike.predicate,
MemRefOf<[AnySignlessInteger, Index]>.predicate]>,
"signless-integer-like or memref of signless-integer">;

Expand Down Expand Up @@ -1392,11 +1392,10 @@ def Arith_IndexCastUIOp
// BitcastOp
//===----------------------------------------------------------------------===//

// Bitcast can convert between memrefs of signless integers, indices, and
// floats too.
// Bitcast can convert between memrefs of signless integers and floats.
def BitcastTypeConstraint : TypeConstraint<Or<[
SignlessIntegerOrFloatLike.predicate,
MemRefOf<[AnySignlessInteger, Index, AnyFloat]>.predicate]>,
MemRefOf<[AnySignlessInteger, AnyFloat]>.predicate]>,
"signless-integer-or-float-like or memref of signless-integer or float">;

def Arith_BitcastOp : Arith_CastOp<"bitcast", BitcastTypeConstraint,
Expand Down Expand Up @@ -1496,8 +1495,8 @@ def Arith_CmpIOp
}];

let arguments = (ins Arith_CmpIPredicateAttr:$predicate,
SignlessIntegerLikeOfAnyRank:$lhs,
SignlessIntegerLikeOfAnyRank:$rhs);
SignlessIntegerOrIndexLikeOfAnyRank:$lhs,
SignlessIntegerOrIndexLikeOfAnyRank:$rhs);

let hasFolder = 1;
let hasCanonicalizer = 1;
Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Math/IR/MathOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class Math_Op<string mnemonic, list<Trait> traits = []> :
// tensor thereof.
class Math_IntegerUnaryOp<string mnemonic, list<Trait> traits = []> :
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
let arguments = (ins SignlessIntegerLike:$operand);
let results = (outs SignlessIntegerLike:$result);
let arguments = (ins SignlessIntegerOrIndexLike:$operand);
let results = (outs SignlessIntegerOrIndexLike:$result);

let assemblyFormat = "$operand attr-dict `:` type($result)";
}
Expand All @@ -55,8 +55,8 @@ class Math_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
// type, vector or tensor thereof.
class Math_IntegerBinaryOp<string mnemonic, list<Trait> traits = []> :
Math_Op<mnemonic, traits # [SameOperandsAndResultType]> {
let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
let results = (outs SignlessIntegerLike:$result);
let arguments = (ins SignlessIntegerOrIndexLike:$lhs, SignlessIntegerOrIndexLike:$rhs);
let results = (outs SignlessIntegerOrIndexLike:$result);

let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}
Expand Down Expand Up @@ -976,7 +976,7 @@ def Math_FPowIOp : Math_Op<"fpowi",
```
}];

let arguments = (ins FloatLike:$lhs, SignlessIntegerLike:$rhs,
let arguments = (ins FloatLike:$lhs, SignlessIntegerOrIndexLike:$rhs,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath);
let results = (outs FloatLike:$result);
Expand Down
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ def Tensor_BitcastOp : Tensor_Op<"bitcast", [
```
}];

let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor:$dest);
let arguments = (ins TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
AnySignedInteger, AnyFloat]>:$source);
let results = (outs TensorOf<[AnySignlessInteger, AnyUnsignedInteger,
AnySignedInteger, AnyFloat]>:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";

let hasCanonicalizer = 1;
Expand Down
16 changes: 13 additions & 3 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -908,21 +908,31 @@ def BoolLike : TypeOrContainer<I1, "bool-like">;

def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;

// Type constraint for signless-integer-like types: signless integers,
// vectors of signless integers or tensors of signless integers.
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
AnySignlessInteger, "signless-integer">;

// Type constraint for signless-integer-like types: signless integers, indices,
// vectors of signless integers or indices, tensors of signless integers.
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
def SignlessIntegerOrIndexLike : TypeOrValueSemanticsContainer<
AnySignlessIntegerOrIndex, "signless-integer-like">;

def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank<
def SignlessIntegerOrIndexLikeOfAnyRank : TypeOrContainerOfAnyRank<
AnySignlessIntegerOrIndex,
"signless-integer-like">;

// Type constraint for float-like types: floats, vectors or tensors thereof.
def FloatLike : TypeOrContainer<AnyFloat, "floating-point-like">;

// Type constraint for signless-integer-like or float-like types.
// Type constraint for signless-integer-or-index-like or float-like types.
def SignlessIntegerOrFloatLike : TypeConstraint<Or<[
SignlessIntegerLike.predicate, FloatLike.predicate]>,
"signless-integer-like or floating-point-like">;

// Type constraint for signless-integer-or-index-like or float-like types.
def SignlessIntegerOrIndexOrFloatLike : TypeConstraint<Or<[
SignlessIntegerOrIndexLike.predicate, FloatLike.predicate]>,
"signless-integer-or-index-like or floating-point-like">;

#endif // COMMON_TYPE_CONSTRAINTS_TD
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1740,10 +1740,8 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (!areValidCastInputsAndOutputs(inputs, outputs))
return false;

auto srcType =
getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
auto dstType =
getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
if (!srcType || !dstType)
return false;

Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Arith/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -853,3 +853,19 @@ func.func @select_tensor_encoding(
%0 = arith.select %arg0, %arg1, %arg2 : tensor<8xi1, "bar">, tensor<8xi32, "foo">
return %0 : tensor<8xi32, "foo">
}

// -----

func.func @bitcast_index_0(%arg0 : i64) -> index {
// expected-error @+1 {{'arith.bitcast' op result #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
%0 = arith.bitcast %arg0 : i64 to index
return %0 : index
}

// -----

func.func @bitcast_index_1(%arg0 : index) -> i64 {
// expected-error @+1 {{'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got 'index'}}
%0 = arith.bitcast %arg0 : index to i64
return %0 : i64
}
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Tensor/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -807,3 +807,19 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
%0 = tensor.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// -----

func.func @bitcast_index_0(%arg0 : tensor<?xi64>) -> tensor<?xindex> {
// expected-error @+1 {{'tensor.bitcast' op result #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
%0 = tensor.bitcast %arg0 : tensor<?xi64> to tensor<?xindex>
return %0 : tensor<?xindex>
}

// -----

func.func @bitcast_index_1(%arg0 : tensor<?xindex>) -> tensor<?xi64> {
// expected-error @+1 {{'tensor.bitcast' op operand #0 must be tensor of signless integer or unsigned integer or signed integer or floating-point values, but got 'tensor<?xindex>'}}
%0 = tensor.bitcast %arg0 : tensor<?xindex> to tensor<?xi64>
return %0 : tensor<?xi64>
}
Loading