Closed
Description
Add tests that cover the concerns raised in #11.
Testing asyncio safety will require adding a test-time dependency on pytest-asyncio. The test file may look as follows:
import dpctl
import dpctl.memory as dpmem
import random
import asyncio
async def _task1(selector="opencl:gpu", n_iters=12):
    """Hold a ``dpctl.device_context`` across ``await`` points and verify it.

    Creates a queue from *selector*, enters ``dpctl.device_context`` for it,
    and repeatedly copies a payload into a USM shared allocation while
    yielding to the event loop with a random delay. The queue reported by
    ``dpctl.get_current_queue()`` must keep targeting a device with the same
    backend as *selector*'s queue. With the default selectors of the two
    tasks (OpenCL here, Level Zero in ``_task2``), a context leaking from the
    other task shows up as a backend mismatch.

    Args:
        selector: SYCL filter selector string used to build the queue.
        n_iters: number of copy / yield rounds.

    Raises:
        AssertionError: if the current queue's backend changes while inside
            this task's device context.
    """
    q = dpctl.SyclQueue(selector)
    expected_backend = q.sycl_device.backend
    abc = b"abcdefghijklmnopqrstuvwxyz"
    m = dpmem.MemoryUSMShared(len(abc))
    with dpctl.device_context(q):
        for _ in range(n_iters):
            cd = dpctl.get_current_queue().sycl_device
            assert cd.backend == expected_backend
            m.copy_from_host(abc)
            # Random delay so this task interleaves unpredictably with others.
            await asyncio.sleep(0.1 * random.random())
        # The loop checks before each await; re-check once more so the switch
        # back from the final sleep is verified too.
        cd = dpctl.get_current_queue().sycl_device
        assert cd.backend == expected_backend
async def _task2(selector="level_zero:gpu", n_iters=12):
    """Hold a ``dpctl.device_context`` across ``await`` points and verify it.

    Creates a queue from *selector*, enters ``dpctl.device_context`` for it,
    and repeatedly copies a zero-filled payload into a USM shared allocation
    while yielding to the event loop with a random delay. The queue reported
    by ``dpctl.get_current_queue()`` must keep targeting a device with the
    same backend as *selector*'s queue. With the default selectors of the two
    tasks (Level Zero here, OpenCL in ``_task1``), a context leaking from the
    other task shows up as a backend mismatch.

    Args:
        selector: SYCL filter selector string used to build the queue.
        n_iters: number of copy / yield rounds.

    Raises:
        AssertionError: if the current queue's backend changes while inside
            this task's device context.
    """
    q = dpctl.SyclQueue(selector)
    expected_backend = q.sycl_device.backend
    host_data = b"\x00" * 10
    # Size the allocation from the payload so the two cannot drift apart.
    m = dpmem.MemoryUSMShared(len(host_data))
    with dpctl.device_context(q):
        for _ in range(n_iters):
            cd = dpctl.get_current_queue().sycl_device
            assert cd.backend == expected_backend
            m.copy_from_host(host_data)
            # Random delay so this task interleaves unpredictably with others.
            await asyncio.sleep(0.1 * random.random())
        # The loop checks before each await; re-check once more so the switch
        # back from the final sleep is verified too.
        cd = dpctl.get_current_queue().sycl_device
        assert cd.backend == expected_backend
async def test_asyncio_safety():
    """Run ``_task1`` and ``_task2`` concurrently on one event loop.

    Both tasks are scheduled before either is awaited, so their device
    contexts are active at the same time and interleave at every ``await``.
    The tasks are then awaited in creation order; an assertion failure in
    either propagates out of this coroutine.

    NOTE(review): under pytest-asyncio's strict mode an ``async def`` test
    needs an ``@pytest.mark.asyncio`` marker (or auto mode) to be collected.
    Confirm against the project's pytest configuration.
    """
    jobs = [asyncio.create_task(coro()) for coro in (_task1, _task2)]
    for job in jobs:
        await job
    print("done")
if __name__ == "__main__":
    # Run the check directly as a script, without pytest or pytest-asyncio.
    asyncio.run(test_asyncio_safety())
The multiprocessing test may look as follows; to turn it into an actual test, assertions must be added:
import dpctl
import dpctl.memory as dpmem
import multiprocessing as mp
def compute_works(i):
    """Round-trip a small buffer through USM shared memory on CPU sub-device *i*.

    Partitions the CPU device with ``partition=(4, 4, 4)`` (three sub-devices),
    builds a context spanning all of them, and creates an in-order queue on
    ``sd[i]``. A 32-byte pattern that depends on *i* is copied into USM
    shared memory and read back; the result must equal what was written.

    Args:
        i: index of the sub-device to use (0, 1 or 2 for this partition).

    Returns:
        The bytes read back from the USM allocation.

    Raises:
        AssertionError: if the bytes read back differ from those written.
    """
    d = dpctl.SyclDevice("cpu")
    # NOTE(review): partitioning into (4, 4, 4) presumably needs a CPU device
    # with at least 12 compute units; confirm on the CI hardware.
    sd = d.create_sub_devices(partition=(4, 4, 4))
    ctx = dpctl.SyclContext(sd)
    q = dpctl.SyclQueue(ctx, sd[i], property="in_order")
    mem = dpmem.MemoryUSMShared(32, queue=q)
    # A fill byte unique to each worker, so data coming from another worker
    # would be caught, not just a zeroed or stale buffer.
    host = bytes([ord("a") + i]) * 32
    mem.copy_from_host(host)
    # bytes() normalizes whatever buffer-protocol object copy_to_host returns.
    result = bytes(mem.copy_to_host())
    print((i, "->", result), flush=True)
    assert result == host, "worker {}: expected {!r}, got {!r}".format(
        i, host, result
    )
    return result
def rt_warnings(i):
    """Report the CPU device seen by a worker and install its queue as global.

    Inside ``dpctl.device_context("cpu")``, prints a greeting with the
    device's name and ``max_compute_units``, then calls
    ``dpctl.set_global_queue`` with the context's queue.

    Args:
        i: worker id, echoed in the greeting.

    Returns:
        *i*, unchanged.
    """
    # NOTE(review): the returned queue is never used; the call may be kept
    # deliberately (e.g. for its side effects in a fresh process). Confirm
    # before removing.
    _outer_queue = dpctl.get_current_queue()
    with dpctl.device_context("cpu") as ctx_queue:
        device = ctx_queue.sycl_device
        greeting = "Hello from {}, using {} with {} EUs".format(
            i, device.name, device.max_compute_units
        )
        print(greeting)
        dpctl.set_global_queue(ctx_queue)
    return i
if __name__ == "__main__":
    worker_ids = [0, 1, 2]
    with mp.Pool(3) as p:
        # rt_warnings returns its argument, so the mapped results must be
        # exactly the worker ids, in order.
        results = p.map(rt_warnings, worker_ids)
    assert results == worker_ids, results
    print("Execution of rt_warnings finished")
    print("--" * 30)
    # A separate pool: compute_works runs in new worker processes, not the
    # ones that already ran rt_warnings. Pool.map re-raises in the parent any
    # exception (including AssertionError) raised in a worker.
    with mp.Pool(3) as p:
        p.map(compute_works, worker_ids)
    print("Execution of compute_works finished")
Metadata
Assignees
Labels
No labels