🚀 The feature, motivation and pitch
Goal
Allow users to call ExecuTorch kernels (and potentially delegates) directly from the Python runtime. Support dynamic kernel registration/deregistration and inspection of kernel metadata (op schema, selected dtype/dim order).
Motivation
- Simplify the kernel development workflow.
- Leverage the existing PyTorch op unit test framework for kernel test coverage.
- Enable microbenchmarks for individual kernels.
API
```python
import torch
from typing import Callable
from executorch.runtime import Verification, Runtime, Program, Method

et_runtime: Runtime = Runtime.get()

# New API to retrieve a kernel
op: Callable = et_runtime.operator_registry.aten.add.out
# Equivalently, look it up by its fully qualified name:
op: Callable = et_runtime.operator_registry.get_kernel("aten::add.out")

a = torch.ones([2, 2])
b = torch.ones([2, 2])
c = torch.empty_like(a)
op(a, b, out=c)  # Calls into the ExecuTorch kernel instead of the PyTorch kernel.

# Print the op schema
print(op._schema)

# No kernel is registered for this overload
et_runtime.operator_registry.aten.add.Tensor  # Raises an exception
```
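The attribute-style lookup `operator_registry.aten.add.out` and `get_kernel("aten::add.out")` are meant to resolve to the same kernel. Below is a minimal pure-Python sketch of how the attribute chain could map onto a name-keyed lookup, assuming a dict-backed registry. `OperatorRegistry`, `KernelNotFoundError`, `_Namespace`, `_OverloadPacket` and `register` are hypothetical names for illustration, not the ExecuTorch API, and the toy kernel operates on Python lists instead of tensors.

```python
from typing import Callable, Dict, List


class KernelNotFoundError(KeyError):
    """Raised when no kernel is registered under a fully qualified name (hypothetical)."""


class _OverloadPacket:
    """Resolves `<namespace>.<op>.<overload>` to a registered kernel."""

    def __init__(self, registry: "OperatorRegistry", qualified_op: str) -> None:
        self._registry = registry
        self._qualified_op = qualified_op  # e.g. "aten::add"

    def __getattr__(self, overload: str) -> Callable:
        if overload.startswith("_"):
            raise AttributeError(overload)
        return self._registry.get_kernel(f"{self._qualified_op}.{overload}")


class _Namespace:
    """Resolves `<namespace>.<op>` to an overload packet."""

    def __init__(self, registry: "OperatorRegistry", namespace: str) -> None:
        self._registry = registry
        self._namespace = namespace  # e.g. "aten"

    def __getattr__(self, op_name: str) -> _OverloadPacket:
        if op_name.startswith("_"):
            raise AttributeError(op_name)
        return _OverloadPacket(self._registry, f"{self._namespace}::{op_name}")


class OperatorRegistry:
    """Hypothetical dict-backed registry keyed by names like "aten::add.out"."""

    def __init__(self) -> None:
        self._kernels: Dict[str, Callable] = {}

    def register(self, name: str, fn: Callable) -> Callable:
        self._kernels[name] = fn
        return fn

    def get_kernel(self, name: str) -> Callable:
        try:
            return self._kernels[name]
        except KeyError:
            raise KernelNotFoundError(name) from None

    def __contains__(self, name: str) -> bool:
        return name in self._kernels

    def __getattr__(self, namespace: str) -> _Namespace:
        # Only called when normal attribute lookup fails, e.g. for "aten".
        if namespace.startswith("_"):
            raise AttributeError(namespace)
        return _Namespace(self, namespace)


# Usage with a toy out-variant kernel operating on Python lists.
def add_out(a: List[float], b: List[float], *, out: List[float]) -> List[float]:
    for i in range(len(out)):
        out[i] = a[i] + b[i]
    return out


registry = OperatorRegistry()
registry.register("aten::add.out", add_out)
c = [0, 0]
registry.aten.add.out([1, 2], [3, 4], out=c)  # c is now [4, 6]
```

Building the attribute chain lazily keeps the registry as the single source of truth: `registry.aten.add.Tensor` fails in `get_kernel` exactly as the string lookup would.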
Dynamically register a custom kernel
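```python
import torch
from typing import Callable
from executorch.runtime import Verification, Runtime, Program, Method

et_runtime: Runtime = Runtime.get()

# Register a custom kernel.
# "<kernel string>" stands for the custom kernel source code; kernel_key
# (placeholder, not defined here) selects the dtype/dim order it serves.
op: Callable = et_runtime.operator_registry.register_kernel(
    "aten::add.out",
    "<kernel string>",
    kernel_key,
)
```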
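Dynamic registration also implies deregistration and a lookup keyed by more than the op name, since the goal mentions selected dtype/dim order. A minimal pure-Python sketch, assuming kernels are stored under an (op name, kernel key) pair and that a keyless entry acts as a fallback; `KernelKey`, `FALLBACK`, `DynamicKernelRegistry`, `is_registered`, `deregister_kernel`, the `"Float"` dtype label and `custom::scale.out` are all hypothetical. Unlike the proposed `register_kernel`, this sketch takes an already-built Python callable and skips compiling `"<kernel string>"`.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass(frozen=True)
class KernelKey:
    """Hypothetical key: one (dtype name, dim order) pair per tensor argument.

    An empty key marks a fallback kernel that accepts any dtype/dim order.
    """
    tensor_meta: Tuple[Tuple[str, Tuple[int, ...]], ...] = ()


FALLBACK = KernelKey()


class DynamicKernelRegistry:
    def __init__(self) -> None:
        self._kernels: Dict[Tuple[str, KernelKey], Callable] = {}

    def register_kernel(self, name: str, fn: Callable, key: KernelKey = FALLBACK) -> Callable:
        if (name, key) in self._kernels:
            raise ValueError(f"{name} is already registered for {key}")
        self._kernels[(name, key)] = fn
        return fn

    def deregister_kernel(self, name: str, key: KernelKey = FALLBACK) -> None:
        # Raises KeyError if nothing was registered under (name, key).
        del self._kernels[(name, key)]

    def is_registered(self, name: str, key: KernelKey = FALLBACK) -> bool:
        return (name, key) in self._kernels

    def get_kernel(self, name: str, key: KernelKey = FALLBACK) -> Callable:
        # Prefer an exact key match, then fall back to the keyless kernel.
        for candidate in (key, FALLBACK):
            fn = self._kernels.get((name, candidate))
            if fn is not None:
                return fn
        raise KeyError(f"no kernel registered for {name} with {key}")


def scale_out(x: List[float], *, out: List[float]) -> List[float]:
    # Toy out-variant kernel: out = 2 * x.
    for i, v in enumerate(x):
        out[i] = 2 * v
    return out


registry = DynamicKernelRegistry()
float_contiguous = KernelKey((("Float", (0, 1)), ("Float", (0, 1))))
registry.register_kernel("custom::scale.out", scale_out, float_contiguous)
result = registry.get_kernel("custom::scale.out", float_contiguous)([1, 2], out=[0, 0])
registry.deregister_kernel("custom::scale.out", float_contiguous)
```

Keeping the key frozen (hashable) lets it be part of the dictionary key directly, so registering the same op for a second dtype/dim order is just another entry.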
Task Breakdown
- Add pybindings for `OperatorRegistry` and `Kernel`. The `Kernel` pybind should expose an `operator()` (`__call__`) in Python that takes a list of arguments; inside the pybind implementation, the arguments are wrapped in `EValue`s and passed to the kernel.
- Connect `Kernel` with `EdgeOpOverload`. `EdgeOpOverload` is an AOT operator concept, and calling it directly triggers the ATen kernel. Add another API to `EdgeOpOverload` so that it can call the `Kernel` in ET.
- Expose a custom kernel registration API in Python. This requires using Ninja to compile the custom kernel code and then registering it; pytorch/pytorch already has tooling for this.
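The real `Kernel` binding would be written in C++ (pybind11 exposes `operator()` to Python as `__call__`). The pure-Python sketch below only illustrates the argument-wrapping step: Python arguments are matched to the schema's argument order, defaults are filled in, each value is wrapped in a tagged value, and the resulting stack is handed to a boxed kernel. `EValue`, `Tag`, `Kernel` and `boxed_add_out` are hypothetical stand-ins; the sketch assumes an out-variant schema whose `out` argument comes last and uses Python lists in place of tensors.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Callable, Dict, List, Sequence


class Tag(Enum):
    NONE = auto()
    BOOL = auto()
    INT = auto()
    DOUBLE = auto()
    TENSOR = auto()


@dataclass
class EValue:
    """Hypothetical Python stand-in for a tagged EValue."""
    tag: Tag
    payload: Any

    @staticmethod
    def wrap(value: Any) -> "EValue":
        if value is None:
            return EValue(Tag.NONE, None)
        if isinstance(value, bool):  # check bool first: bool subclasses int
            return EValue(Tag.BOOL, value)
        if isinstance(value, int):
            return EValue(Tag.INT, value)
        if isinstance(value, float):
            return EValue(Tag.DOUBLE, value)
        return EValue(Tag.TENSOR, value)  # lists stand in for tensors here


BoxedKernel = Callable[[List[EValue]], None]


class Kernel:
    """Callable wrapper: Python args -> EValue stack -> boxed kernel."""

    def __init__(self, name: str, arg_names: Sequence[str],
                 defaults: Dict[str, Any], boxed_fn: BoxedKernel) -> None:
        self.name = name
        self._arg_names = list(arg_names)
        self._defaults = dict(defaults)
        self._boxed_fn = boxed_fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if len(args) > len(self._arg_names):
            raise TypeError(f"{self.name}: too many positional arguments")
        values = dict(zip(self._arg_names, args))
        for key, value in kwargs.items():
            if key not in self._arg_names or key in values:
                raise TypeError(f"{self.name}: unexpected or duplicate argument {key!r}")
            values[key] = value
        stack: List[EValue] = []
        for arg in self._arg_names:  # schema order, defaults filled in
            if arg in values:
                stack.append(EValue.wrap(values[arg]))
            elif arg in self._defaults:
                stack.append(EValue.wrap(self._defaults[arg]))
            else:
                raise TypeError(f"{self.name}: missing argument {arg!r}")
        self._boxed_fn(stack)
        return stack[-1].payload  # out variants return their `out` argument


def boxed_add_out(stack: List[EValue]) -> None:
    # Toy kernel using the add.out argument order: self, other, alpha, out.
    self_, other, alpha, out = (ev.payload for ev in stack)
    for i in range(len(out)):
        out[i] = self_[i] + alpha * other[i]


add_out = Kernel("aten::add.out", ["self", "other", "alpha", "out"], {"alpha": 1}, boxed_add_out)
c = [0, 0]
add_out([1, 2], [3, 4], out=c)  # c is now [4, 6]
```

Since the out argument is wrapped by reference, the kernel writes into the caller's buffer and the wrapper can return it, matching the `op(a, b, out=c)` call style in the API example.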
Alternatives
No response
Additional context
No response
RFC (Optional)
No response