Skip to content

Commit e17c913

Browse files
[mlir][python] Add T.tf32 and missing tests for tf32 (#116725)
1 parent 6fe94c3 commit e17c913

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

mlir/python/mlir/extras/types.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Float8E4M3Type,
2222
Float8E5M2Type,
2323
Float8E8M0FNUType,
24+
FloatTF32Type,
2425
FunctionType,
2526
IndexType,
2627
IntegerType,
@@ -70,6 +71,7 @@ def ui(width):
7071

7172
f16 = lambda: F16Type.get()
7273
f32 = lambda: F32Type.get()
74+
tf32 = lambda: FloatTF32Type.get()
7375
f64 = lambda: F64Type.get()
7476
bf16 = lambda: BF16Type.get()
7577

mlir/test/python/ir/builtin_types.py

+5
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ def testTypeIDs():
639639
(BF16Type, BF16Type.get()),
640640
(F16Type, F16Type.get()),
641641
(F32Type, F32Type.get()),
642+
(FloatTF32Type, FloatTF32Type.get()),
642643
(F64Type, F64Type.get()),
643644
(NoneType, NoneType.get()),
644645
(ComplexType, ComplexType.get(f32)),
@@ -668,6 +669,7 @@ def testTypeIDs():
668669
# CHECK: BF16Type(bf16)
669670
# CHECK: F16Type(f16)
670671
# CHECK: F32Type(f32)
672+
# CHECK: FloatTF32Type(tf32)
671673
# CHECK: F64Type(f64)
672674
# CHECK: NoneType(none)
673675
# CHECK: ComplexType(complex<f32>)
@@ -734,6 +736,9 @@ def print_downcasted(typ):
734736
# CHECK: F32Type
735737
# CHECK: F32Type(f32)
736738
print_downcasted(F32Type.get())
739+
# CHECK: FloatTF32Type
740+
# CHECK: FloatTF32Type(tf32)
741+
print_downcasted(FloatTF32Type.get())
737742
# CHECK: F64Type
738743
# CHECK: F64Type(f64)
739744
print_downcasted(F64Type.get())

0 commit comments

Comments
 (0)