Skip to content

Commit ba44d7b

Browse files
authored
[MLIR][test] Fixup for checking for ml_dtypes (#123240)
In order to optionally run some checks that depend on the `ml_dtypes` python module we have to remove the `CHECK` lines for those tests or they will be required and missed in the test output. I've changed to use asserts as recommended in [1]. [1]: #123061 (comment)
1 parent 63b0ab8 commit ba44d7b

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

mlir/test/python/execution_engine.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -566,13 +566,15 @@ def testBF16Memref():
566566
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
567567

568568
# test to-numpy utility
569-
# CHECK: [0.5]
570-
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
571-
log(npout)
569+
x = ranked_memref_to_numpy(arg2_memref_ptr[0])
570+
assert len(x) == 1
571+
assert x[0] == 0.5
572572

573573

574574
if HAS_ML_DTYPES:
575575
run(testBF16Memref)
576+
else:
577+
log("TEST: testBF16Memref")
576578

577579

578580
# Test f8E5M2 memrefs
@@ -606,13 +608,15 @@ def testF8E5M2Memref():
606608
execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
607609

608610
# test to-numpy utility
609-
# CHECK: [0.5]
610-
npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
611-
log(npout)
611+
x = ranked_memref_to_numpy(arg2_memref_ptr[0])
612+
assert len(x) == 1
613+
assert x[0] == 0.5
612614

613615

614616
if HAS_ML_DTYPES:
615617
run(testF8E5M2Memref)
618+
else:
619+
log("TEST: testF8E5M2Memref")
616620

617621

618622
# Test addition of two 2d_memref

0 commit comments

Comments
 (0)