@@ -55,24 +55,19 @@ def spMxM(*args):
 def boilerplate(attr: st.EncodingAttr):
   """Returns boilerplate main method.
 
-  This method sets up a boilerplate main method that calls the generated
-  sparse kernel. For convenience, this part is purely done as string input.
+  This method sets up a boilerplate main method that takes three tensors
+  (a, b, c), converts the first tensor a into a sparse tensor, and then
+  calls the sparse kernel for matrix multiplication. For convenience,
+  this part is purely done as string input.
   """
   return f"""
-func @main(%c: tensor<3x2xf64>) -> tensor<3x2xf64>
+func @main(%ad: tensor<3x4xf64>, %b: tensor<4x2xf64>, %c: tensor<3x2xf64>) -> tensor<3x2xf64>
   attributes {{ llvm.emit_c_interface }} {{
-  %0 = constant dense<[ [ 1.1, 0.0, 0.0, 1.4 ],
-                        [ 0.0, 0.0, 0.0, 0.0 ],
-                        [ 0.0, 0.0, 3.3, 0.0 ]]> : tensor<3x4xf64>
-  %a = sparse_tensor.convert %0 : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
-  %b = constant dense<[ [ 1.0, 2.0 ],
-                        [ 4.0, 3.0 ],
-                        [ 5.0, 6.0 ],
-                        [ 8.0, 7.0 ]]> : tensor<4x2xf64>
-  %1 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
+  %a = sparse_tensor.convert %ad : tensor<3x4xf64> to tensor<3x4xf64, {attr}>
+  %0 = call @spMxM(%a, %b, %c) : (tensor<3x4xf64, {attr}>,
                                    tensor<4x2xf64>,
                                    tensor<3x2xf64>) -> tensor<3x2xf64>
-  return %1 : tensor<3x2xf64>
+  return %0 : tensor<3x2xf64>
 }}
 """
 
@@ -83,25 +78,34 @@ def build_compile_and_run_SpMM(attr: st.EncodingAttr, support_lib: str,
   module = build_SpMM(attr)
   func = str(module.operation.regions[0].blocks[0].operations[0].operation)
   module = ir.Module.parse(func + boilerplate(attr))
+
   # Compile.
   compiler(module)
   engine = execution_engine.ExecutionEngine(
       module, opt_level=0, shared_libs=[support_lib])
-  # Set up numpy input, invoke the kernel, and get numpy output.
+
+  # Set up numpy input and buffer for output.
+  a = np.array(
+      [[1.1, 0.0, 0.0, 1.4], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 3.3, 0.0]],
+      np.float64)
+  b = np.array([[1.0, 2.0], [4.0, 3.0], [5.0, 6.0], [8.0, 7.0]], np.float64)
+  c = np.zeros((3, 2), np.float64)
+  out = np.zeros((3, 2), np.float64)
+
+  mem_a = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(a)))
+  mem_b = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(b)))
+  mem_c = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(c)))
+  mem_out = ctypes.pointer(ctypes.pointer(rt.get_ranked_memref_descriptor(out)))
+
+  # Invoke the kernel and get numpy output.
   # Built-in bufferization uses in-out buffers.
   # TODO: replace with inplace comprehensive bufferization.
-  Cin = np.zeros((3, 2), np.double)
-  Cout = np.zeros((3, 2), np.double)
-  Cin_memref_ptr = ctypes.pointer(
-      ctypes.pointer(rt.get_ranked_memref_descriptor(Cin)))
-  Cout_memref_ptr = ctypes.pointer(
-      ctypes.pointer(rt.get_ranked_memref_descriptor(Cout)))
-  engine.invoke('main', Cout_memref_ptr, Cin_memref_ptr)
-  Cresult = rt.ranked_memref_to_numpy(Cout_memref_ptr[0])
+  engine.invoke('main', mem_out, mem_a, mem_b, mem_c)
 
   # Sanity check on computed result.
-  expected = [[12.3, 12.0], [0.0, 0.0], [16.5, 19.8]]
-  if np.allclose(Cresult, expected):
+  expected = np.matmul(a, b)
+  c = rt.ranked_memref_to_numpy(mem_out[0])
+  if np.allclose(c, expected):
     pass
   else:
     quit(f'FAILURE')
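Editor's note on the hunk above: the repeated ctypes.pointer(ctypes.pointer(...)) pattern is the calling convention ExecutionEngine.invoke uses for entry points built with llvm.emit_c_interface. Each ranked tensor travels as a pointer to a pointer to its memref descriptor, and because main now returns a tensor, the result slot (mem_out) is passed first, followed by the arguments in declaration order. The sketch below factors that wrapping into a small helper; it is not part of the patch, and the name as_memref_arg is made up for illustration.

    import ctypes

    import numpy as np
    from mlir import runtime as rt


    def as_memref_arg(array: np.ndarray):
      """Wraps a numpy array for ExecutionEngine.invoke.

      Returns a pointer to a pointer to the ranked memref descriptor of array,
      the form expected by llvm.emit_c_interface entry points. The caller must
      keep array alive until the invocation completes.
      """
      return ctypes.pointer(
          ctypes.pointer(rt.get_ranked_memref_descriptor(array)))


    # Usage mirroring the call in the patch (result buffer first, then inputs):
    #   engine.invoke('main', as_memref_arg(out), as_memref_arg(a),
    #                 as_memref_arg(b), as_memref_arg(c))

Note also that np.matmul(a, b) reproduces the previously hard-coded expected values [[12.3, 12.0], [0.0, 0.0], [16.5, 19.8]], so the sanity check is unchanged in substance.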
@@ -132,7 +136,10 @@ def __call__(self, module: ir.Module):
 # CHECK: Passed 72 tests
 @run
 def testSpMM():
+  # Obtain path to runtime support library.
   support_lib = os.getenv('SUPPORT_LIB')
+  assert os.path.exists(support_lib), f'{support_lib} does not exist'
+
   with ir.Context() as ctx, ir.Location.unknown():
     count = 0
     # Fixed compiler optimization strategy.
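For context on the new assert: SUPPORT_LIB must point at the shared library that provides the MLIR sparse tensor runtime support routines. When the test runs under lit, the harness sets this variable; the sketch below shows one way to drive the test by hand. The library name libmlir_c_runner_utils.so, its path, and the test file name are assumptions about a typical build layout, not values taken from the patch.

    import os
    import subprocess

    # Standalone run outside lit. Adjust both placeholder paths to your build
    # tree; the file name test_SpMM.py is inferred from the testSpMM entry
    # point and may differ.
    env = dict(os.environ)
    env['SUPPORT_LIB'] = '/path/to/llvm-build/lib/libmlir_c_runner_utils.so'
    subprocess.run(['python', 'test_SpMM.py'], env=env, check=True)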