Skip to content

[MLIR][test] Fixup for checking for ml_dtypes #123240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 17, 2025

Conversation

kwk
Copy link
Contributor

@kwk kwk commented Jan 16, 2025

In order to optionally run some checks that depend on the ml_dtypes python module we have to remove the CHECK lines for those tests or they will be required and missed in the test output.

I've changed to use asserts as recommended in 1.

@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2025

@llvm/pr-subscribers-mlir

Author: Konrad Kleine (kwk)

Changes

In order to optionally run some checks that depend on the ml_dtypes python module we have to remove the CHECK lines for those tests or they will be required and missed in the test output.

I've changed to use asserts as recommended in 1.


Full diff: https://github.com/llvm/llvm-project/pull/123240.diff

1 Files Affected:

  • (modified) mlir/test/python/execution_engine.py (+6-8)
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index e3f41815800d58..cab6b69a01f4cc 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -536,7 +536,6 @@ def testComplexUnrankedMemrefAdd():
 
 
 # Test bf16 memrefs
-# CHECK-LABEL: TEST: testBF16Memref
 def testBF16Memref():
     with Context():
         module = Module.parse(
@@ -566,9 +565,9 @@ def testBF16Memref():
         execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
 
         # test to-numpy utility
-        # CHECK: [0.5]
-        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        log(npout)
+        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        assert len(x) == 1
+        assert x[0] == 0.5
 
 
 if HAS_ML_DTYPES:
@@ -576,7 +575,6 @@ def testBF16Memref():
 
 
 # Test f8E5M2 memrefs
-# CHECK-LABEL: TEST: testF8E5M2Memref
 def testF8E5M2Memref():
     with Context():
         module = Module.parse(
@@ -606,9 +604,9 @@ def testF8E5M2Memref():
         execution_engine.invoke("main", arg1_memref_ptr, arg2_memref_ptr)
 
         # test to-numpy utility
-        # CHECK: [0.5]
-        npout = ranked_memref_to_numpy(arg2_memref_ptr[0])
-        log(npout)
+        x = ranked_memref_to_numpy(arg2_memref_ptr[0])
+        assert len(x) == 1
+        assert x[0] == 0.5
 
 
 if HAS_ML_DTYPES:



if HAS_ML_DTYPES:
run(testBF16Memref)


# Test f8E5M2 memrefs
# CHECK-LABEL: TEST: testF8E5M2Memref
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry why not the

if HAS_ML_DTYPES:
    run(testF8E5M2Memref)
else:
    print("TEST: testF8E5M2Memref")

I suggested?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally I had this even before you suggested it but it didn't work. Today when I put the prints (I used log()) back in in order to show the error it did work. What can I say other than sorry and thank you. I'll update the PR shortly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@makslevental I've update the PR and would be super happy if you could take a look again. Thank you for accepting my PRs and helping me fix them btw. This is very welcoming.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ofc ofc no problem

@kwk kwk force-pushed the fixup_for_checking_for_ml_dtypes branch from 2fe0c70 to e9e99f7 Compare January 17, 2025 08:53
In order to optionally run some checks that depend on the `ml_dtypes`
python module we have to remove the `CHECK` lines for those tests or
they will be required and missed in the test output.

I've changed to use asserts as recommended in [1].

[1]: llvm#123061 (comment)
@kwk kwk force-pushed the fixup_for_checking_for_ml_dtypes branch from e9e99f7 to bc2e254 Compare January 17, 2025 10:32
Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Hopefully it works on your end now :)

@kwk kwk merged commit ba44d7b into llvm:main Jan 17, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants