Skip to content

Add tests for asyncio safety and Python parallelism safety #410

Closed
@oleksandr-pavlyk

Description

@oleksandr-pavlyk

Add tests covering the concerns raised in #11.

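Testing asyncio safety requires adding a test-time dependency on pytest-asyncio (the `async def` test must be marked with `@pytest.mark.asyncio` unless pytest-asyncio is configured with `asyncio_mode = auto`). The test file may look as follows: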

import dpctl
import dpctl.memory as dpmem
import random
import asyncio

async def _task1():
    """Exercise ``dpctl.device_context`` from an asyncio task on an OpenCL GPU.

    Enters a device context for an ``opencl:gpu`` queue and, for 12
    iterations separated by short random sleeps (which yield to the event
    loop so other tasks can run), asserts that the current queue reports the
    same backend as the queue this task entered. Each iteration also copies
    host bytes into a USM-shared allocation bound to that queue.

    Raises:
        AssertionError: if the current queue's backend does not match the
            backend of the queue entered by this task (e.g. another task's
            context leaked in).

    NOTE(review): queue construction presumably fails on machines without an
    OpenCL GPU; consider skipping the test in that case.
    """
    q = dpctl.SyclQueue("opencl:gpu")
    abc = b"abcdefghijklmnopqrstuvwxyz"
    # Bind the allocation to ``q`` so the copies below run on the device under
    # test instead of on whatever queue dpctl selects by default.
    m = dpmem.MemoryUSMShared(len(abc), queue=q)
    with dpctl.device_context(q):
        for _ in range(12):
            cd = dpctl.get_current_queue().sycl_device
            assert cd.backend == q.sycl_device.backend
            m.copy_from_host(abc)
            # Yield for a random short interval so this task interleaves with
            # the other task while both device contexts are active.
            await asyncio.sleep(0.1 * random.random())


async def _task2():
    """Exercise ``dpctl.device_context`` from an asyncio task on a Level Zero GPU.

    Counterpart of the OpenCL task: enters a device context for a
    ``level_zero:gpu`` queue and, for 12 iterations separated by short random
    sleeps, asserts that the current queue reports the same backend as the
    queue this task entered. Each iteration also copies zero bytes into a
    USM-shared allocation bound to that queue.

    Raises:
        AssertionError: if the current queue's backend does not match the
            backend of the queue entered by this task.

    NOTE(review): queue construction presumably fails on machines without a
    Level Zero GPU; consider skipping the test in that case.
    """
    q = dpctl.SyclQueue("level_zero:gpu")
    host_data = b"\x00" * 10
    # Size the buffer from the payload and bind it to ``q`` so the copies run
    # on the device under test rather than on the default queue.
    m = dpmem.MemoryUSMShared(len(host_data), queue=q)
    with dpctl.device_context(q):
        for _ in range(12):
            cd = dpctl.get_current_queue().sycl_device
            assert cd.backend == q.sycl_device.backend
            m.copy_from_host(host_data)
            # Yield for a random short interval so this task interleaves with
            # the other task while both device contexts are active.
            await asyncio.sleep(0.1 * random.random())


async def test_asyncio_safety():
    """Run both device-context tasks concurrently and surface any failure.

    ``_task1`` and ``_task2`` each enter a device context for a different
    backend and interleave via ``asyncio.sleep``. Each one asserts that the
    current queue stays consistent with its own context.

    Both tasks are always awaited to completion (``return_exceptions=True``).
    A failure in one therefore never leaves the other running unobserved or
    drops its exception. The first exception found is re-raised afterwards.

    NOTE(review): with pytest-asyncio in strict mode this coroutine also
    needs the ``@pytest.mark.asyncio`` marker to be collected and run.
    """
    outcomes = await asyncio.gather(_task1(), _task2(), return_exceptions=True)
    for outcome in outcomes:
        if isinstance(outcome, BaseException):
            raise outcome

    print("done")



if __name__ == '__main__':
    # Allow running the check directly without pytest: asyncio.run creates a
    # fresh event loop, drives the coroutine to completion, then closes it.
    asyncio.run(test_asyncio_safety())

The test for multiprocessing may look as follows; to turn it into a real test, assertions must be added:

import dpctl
import dpctl.memory as dpmem
import multiprocessing as mp


def compute_works(i):
    """Round-trip data through USM-shared memory on CPU sub-device ``i``.

    Partitions the CPU device with ``partition=(4, 4, 4)`` (presumably three
    sub-devices of 4 compute units each, so a CPU with at least 12 compute
    units is needed; TODO confirm). Builds a context over the sub-devices and
    an in-order queue on ``sd[i]``, copies 32 host bytes into a USM-shared
    allocation on that queue, reads them back and checks they match.

    Parameters
    ----------
    i : int
        Index into the sub-device list (0, 1 or 2 for the partition used
        here). Also identifies the worker in the printed output.

    Returns
    -------
    int
        ``i`` unchanged, so a caller of ``Pool.map`` can confirm that every
        worker completed.

    Raises
    ------
    AssertionError
        If the bytes read back from device memory differ from those written.
    """
    d = dpctl.SyclDevice("cpu")
    sd = d.create_sub_devices(partition=(4, 4, 4))
    ctx = dpctl.SyclContext(sd)
    q = dpctl.SyclQueue(ctx, sd[i], property="in_order")
    mem = dpmem.MemoryUSMShared(32, queue=q)
    host = b" " * 32
    mem.copy_from_host(host)
    # bytes(...) normalizes the buffer object returned by copy_to_host
    # (bytearray or uint8 array, depending on dpctl version) for comparison.
    round_trip = bytes(mem.copy_to_host())
    assert round_trip == host, (i, round_trip)
    print((i, "->", round_trip), flush=True)
    return i


def rt_warnings(i):
    """Query and set dpctl queues from inside a ``multiprocessing`` worker.

    Meant to be mapped over a ``multiprocessing.Pool``. Judging by the name,
    it is presumably a reproducer for runtime warnings raised when dpctl
    queue state is touched in pool workers (TODO confirm). It contains no
    assertions yet.

    Parameters
    ----------
    i : int
        Worker identifier; used only in the printed greeting and returned.

    Returns
    -------
    int
        ``i`` unchanged.
    """
    # The result is unused. NOTE(review): kept presumably for the side effect
    # of querying dpctl's current queue in this worker process; confirm
    # before removing.
    cq = dpctl.get_current_queue()
    with dpctl.device_context("cpu") as q:
        print(
            "Hello from {}, using {} with {} EUs".format(
                i, q.sycl_device.name, q.sycl_device.max_compute_units
            )
        )
        # Sets the global queue while a device context is still active.
        dpctl.set_global_queue(q)
    return i


if __name__ == "__main__":
    # Each stage maps one worker over indices 0..2 in a fresh 3-process pool,
    # then reports completion; a separator line is printed between stages.
    stages = (
        (rt_warnings, "Execution of rt_warnings finished"),
        (compute_works, "Execution of compute_works finished"),
    )
    for stage_no, (worker, done_msg) in enumerate(stages):
        if stage_no > 0:
            print("--" * 30)
        with mp.Pool(3) as pool:
            pool.map(worker, range(3))
        print(done_msg)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions