Skip to content

[mlir][arith][tensor] Disable index type for bitcast #121455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 24, 2025

Conversation

jacquesguan
Copy link
Contributor

@jacquesguan jacquesguan commented Jan 2, 2025

Use kInternalStorageBitWidth as the bit width of index type. Fixes #121397.

@llvmbot
Copy link
Member

llvmbot commented Jan 2, 2025

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-math
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir-arith

Author: Jianjian Guan (jacquesguan)

Changes

Use kInternalStorageBitWidth as the bit width of index type.


Full diff: https://github.com/llvm/llvm-project/pull/121455.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+12-1)
  • (modified) mlir/test/Dialect/Arith/ops.mlir (+6)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..6c4aee3aad94fe 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1723,7 +1723,18 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!srcType || !dstType)
     return false;
 
-  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+  unsigned srcWidth, dstWidth;
+  if (auto indexTy = dyn_cast<IndexType>(srcType))
+    srcWidth = IndexType::kInternalStorageBitWidth;
+  else
+    srcWidth = srcType.getIntOrFloatBitWidth();
+
+  if (auto indexTy = dyn_cast<IndexType>(dstType))
+    dstWidth = IndexType::kInternalStorageBitWidth;
+  else
+    dstWidth = dstType.getIntOrFloatBitWidth();
+
+  return srcWidth == dstWidth;
 }
 
 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a517..46cb1993a3b789 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -954,6 +954,12 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
   return %0 : vector<[8]xi32>
 }
 
+// CHECK-LABEL: test_bitcast_index
+func.func @test_bitcast_index(%arg0 : i64) -> index {
+  %0 = arith.bitcast %arg0 : i64 to index
+  return %0 : index
+}
+
 // CHECK-LABEL: test_cmpi
 func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
   %0 = arith.cmpi ne, %arg0, %arg1 : i64

@llvmbot
Copy link
Member

llvmbot commented Jan 2, 2025

@llvm/pr-subscribers-mlir

Author: Jianjian Guan (jacquesguan)

Changes

Use kInternalStorageBitWidth as the bit width of index type.


Full diff: https://github.com/llvm/llvm-project/pull/121455.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+12-1)
  • (modified) mlir/test/Dialect/Arith/ops.mlir (+6)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index d8b314a3fa43c0..6c4aee3aad94fe 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1723,7 +1723,18 @@ bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (!srcType || !dstType)
     return false;
 
-  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+  unsigned srcWidth, dstWidth;
+  if (auto indexTy = dyn_cast<IndexType>(srcType))
+    srcWidth = IndexType::kInternalStorageBitWidth;
+  else
+    srcWidth = srcType.getIntOrFloatBitWidth();
+
+  if (auto indexTy = dyn_cast<IndexType>(dstType))
+    dstWidth = IndexType::kInternalStorageBitWidth;
+  else
+    dstWidth = dstType.getIntOrFloatBitWidth();
+
+  return srcWidth == dstWidth;
 }
 
 OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a517..46cb1993a3b789 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -954,6 +954,12 @@ func.func @test_bitcast_scalable_vector1(%arg0 : vector<[8]xf32>) -> vector<[8]x
   return %0 : vector<[8]xi32>
 }
 
+// CHECK-LABEL: test_bitcast_index
+func.func @test_bitcast_index(%arg0 : i64) -> index {
+  %0 = arith.bitcast %arg0 : i64 to index
+  return %0 : index
+}
+
 // CHECK-LABEL: test_cmpi
 func.func @test_cmpi(%arg0 : i64, %arg1 : i64) -> i1 {
   %0 = arith.cmpi ne, %arg0, %arg1 : i64

@CoTinker
Copy link
Contributor

CoTinker commented Jan 2, 2025

Thanks for your work, and I think we should fix tensor.bitcast, too.

bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think arith.bitcast should not be used with index type. The bitwidth is unknown, so a bitcast does not make sense. Can we improve the verifier and reject such ops? Note, there is a special index_cast op to convert from/to index type.

@Mogball
Copy link
Contributor

Mogball commented Jan 2, 2025

I think arith.bitcast should not be used with index type. The bitwidth is unknown, so a bitcast does not make sense. Can we improve the verifier and reject such ops? Note, there is a special index_cast op to convert from/to index type.

+1. A value of index type should be cast to a specific width, such as i32 or i64, before it can be bitcast.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think the internal storage really represents the actual bitwidth used. That only becomes apparent later on, for example when lowering to LLVM this is set as a field in the data layout used to lower index types.

EDIT: +1 to what was said above as well. Should disallow index types in bitcast operations.

@kuhar
Copy link
Member

kuhar commented Jan 2, 2025

+1. A value of index type should be cast to a specific width, such as i32 or i64, before it can be bitcast.

+1

@jacquesguan jacquesguan changed the title [mlir][arith] Support bitcast with index type [mlir][arith][tensor] Disable index type for bitcast Jan 4, 2025
@jacquesguan
Copy link
Contributor Author

Thanks for all comments, I change this PR to disable index type for arith.bitcast and tensor.bitcast.

@CoTinker
Copy link
Contributor

CoTinker commented Jan 4, 2025

And I think it's better to change the constraint of tensor.bitcast from AnyTensor to TensorOf<[AnySignlessInteger, AnyFloat]>

let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor:$dest);

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:ods labels Jan 7, 2025
@jacquesguan
Copy link
Contributor Author

And I think it's better to change the constraint of tensor.bitcast from AnyTensor to TensorOf<[AnySignlessInteger, AnyFloat]>

let arguments = (ins AnyTensor:$source);
let results = (outs AnyTensor:$dest);

Done, also add signed and unsigned integer.

@@ -908,6 +908,11 @@ def BoolLike : TypeOrContainer<I1, "bool-like">;

def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank<I1, "bool-like">;

// Type constraint for signless-integer-like types: signless integers,
// vectors of signless integers or tensors of signless integers.
def SignlessInteger : TypeOrValueSemanticsContainer<
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit confusing because there is also SignlessIntegerLike below.

I would change it as follows:

  • Rename SignlessIntegerLike to SignlessIntegerOrIndexLike.
  • Rename SignlessIntegerLikeOfAnyRank to SignlessIntegerOrIndexLikeOfAnyRank
  • Call this new definition SignlessIntegerLike instead of SignlessInteger.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for comment, I rename these constraints.

@jacquesguan jacquesguan force-pushed the mlir-fix-arith-bitcast branch from 230be88 to 91bf8e2 Compare January 8, 2025 06:25
Copy link
Contributor

@CoTinker CoTinker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is LGTM now, and please resolve the tests failure.

Use kInternalStorageBitWidth as the bit width of index type.
@jacquesguan jacquesguan force-pushed the mlir-fix-arith-bitcast branch from 91bf8e2 to 5767f51 Compare January 24, 2025 08:11
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jan 24, 2025
@jacquesguan jacquesguan merged commit 990837f into llvm:main Jan 24, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category mlir:arith mlir:core MLIR Core Infrastructure mlir:math mlir:ods mlir:tensor mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] arith.bitcast crashes on verify when a type is index
7 participants