Skip to content

Commit 24ea94a

Browse files
committed
[mlir][sparse][python] migrate more code from boilerplate into proper numpy land
The boilerplate was setting up some arrays for testing. To fully illustrate python - MLIR potential, however, this data should also come from numpy land. Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D108336
1 parent 02d1175 commit 24ea94a

File tree

1 file changed

+31
-24
lines changed

1 file changed

+31
-24
lines changed

mlir/test/python/dialects/sparse_tensor/test_SpMM.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,19 @@ def spMxM(*args):
5555
def boilerplate(attr: st.EncodingAttr):
5656
"""Returns boilerplate main method.
5757
58-
This method sets up a boilerplate main method that calls the generated
59-
sparse kernel. For convenience, this part is purely done as string input.
58+
This method sets up a boilerplate main method that takes three tensors
59+
(a, b, c), converts the first tensor a into s sparse tensor, and then
60+
calls the sparse kernel for matrix multiplication. For convenience,
61+
this part is purely done as string input.
6062
"""
6163
return f"""
62-
func @main(%c: tensor<3x2xf64>) -> tensor<3x2xf64>
64+
func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
6365
attributes {{ llvm.emit_c_interface }} {{
64-
%0 = constant dense<[ [ 1.1, 0.0, 0.0, 1.4 ],
65-
[ 0.0, 0.0, 0.0, 0.0 ],
66-
[ 0.0, 0.0, 3.3, 0.0 ]]> : tensor<3x4xf64>
67-
%a = sparse_tensor.convert %0 : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
68-
%b = constant dense<[ [ 1.0, 2.0 ],
69-
[ 4.0, 3.0 ],
70-
[ 5.0, 6.0 ],
71-
[ 8.0, 7.0 ]]> : tensor<4x2xf64>
72-
%1 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
66+
%a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
67+
%0 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
7368
tensor<4x2xf64>,
7469
tensor<3x2xf64>) -> tensor<3x2xf64>
75-
return %1 : tensor<3x2xf64>
70+
return %0 : tensor<3x2xf64>
7671
}}
7772
"""
7873

@@ -83,25 +78,34 @@ def build_compile_and_run_SpMM(attr: st.EncodingAttr, support_lib: str,
8378
module = build_SpMM(attr)
8479
func = str(module.operation.regions[0].blocks[0].operations[0].operation)
8580
module = ir.Module.parse(func + boilerplate(attr))
81+
8682
# Compile.
8783
compiler(module)
8884
engine = execution_engine.ExecutionEngine(
8985
module, opt_level=0, shared_libs=[support_lib])
90-
# Set up numpy input, invoke the kernel, and get numpy output.
86+
87+
# Set up numpy input and buffer for output.
88+
a = np.array(
89+
[[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
90+
np.float64)
91+
b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
92+
c = np.zeros((3, 2), np.float64)
93+
out = np.zeros((3, 2), np.float64)
94+
95+
mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
96+
mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
97+
mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
98+
mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(out)))
99+
100+
# Invoke the kernel and get numpy output.
91101
# Built-in bufferization uses in-out buffers.
92102
# TODO: replace with inplace comprehensive bufferization.
93-
Cin = np.zeros((3, 2), np.double)
94-
Cout = np.zeros((3, 2), np.double)
95-
Cin_memref_ptr = ctypes.pointer(
96-
ctypes.pointer(rt.get_ranked_memref_descriptor(Cin)))
97-
Cout_memref_ptr = ctypes.pointer(
98-
ctypes.pointer(rt.get_ranked_memref_descriptor(Cout)))
99-
engine.invoke('main', Cout_memref_ptr, Cin_memref_ptr)
100-
Cresult = rt.ranked_memref_to_numpy(Cout_memref_ptr[0])
103+
engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
101104

102105
# Sanity check on computed result.
103-
expected = [[12.3, 12.0], [0.0, 0.0], [16.5, 19.8]]
104-
if np.allclose(Cresult, expected):
106+
expected = np.matmul(a, b);
107+
c = rt.ranked_memref_to_numpy(mem_out[0])
108+
if np.allclose(c, expected):
105109
pass
106110
else:
107111
quit(f'FAILURE')
@@ -132,7 +136,10 @@ def __call__(self, module: ir.Module):
132136
# CHECK: Passed 72 tests
133137
@run
134138
def testSpMM():
139+
# Obtain path to runtime support library.
135140
support_lib = os.getenv('SUPPORT_LIB')
141+
assert os.path.exists(support_lib), f'{support_lib} does not exist'
142+
136143
with ir.Context() as ctx, ir.Location.unknown():
137144
count = 0
138145
# Fixed compiler optimization strategy.

0 commit comments

Comments
 (0)