Skip to content

Fix exir.load/save to handle named data store map #9485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exir/serde/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,4 @@ class LoweredBackendModule:
original_module: export_schema.ExportedProgram
original_state_dict: str
original_constants: str
named_data_store: Optional[bytes] = None
11 changes: 11 additions & 0 deletions exir/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch
import torch.export.exported_program as ep
from executorch.exir import delegate
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
from executorch.exir.backend.compile_spec_schema import (
CompileSpec as delegate_CompileSpec,
)
Expand Down Expand Up @@ -268,6 +269,7 @@ def serialize_bytes(b: bytes) -> str:
assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)

serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)
named_data_store = json.dumps(export_serialize._dataclass_to_dict(lowered_module.named_data_store_output),cls=export_serialize.EnumEncoder) if lowered_module.named_data_store_output else None

serialized_lowered_module = SerdeLoweredBackendModule(
original_module=serialized_artifact.exported_program,
Expand All @@ -276,6 +278,7 @@ def serialize_bytes(b: bytes) -> str:
processed_bytes=serialized_processed_bytes,
compile_specs=serialized_compile_spec,
backend_id=lowered_module.backend_id,
named_data_store=named_data_store,
)

json_lowered_module = json.dumps(
Expand Down Expand Up @@ -556,11 +559,19 @@ def deserialize_lowered_module(
None,
)

if serialized_lowered_module.named_data_store is None:
named_data_store = None
else:
named_data_store = export_serialize._dict_to_dataclass(NamedDataStoreOutput, json.loads(serialized_lowered_module.named_data_store))
for buffer in named_data_store.buffers:
buffer.buffer = base64.b64decode(buffer.buffer.encode("ascii"))

lowered_module = ExirLoweredBackendModule(
original_module,
backend_id,
processed_bytes,
compile_specs,
named_data_store
)
self.module.register_module(serialized_lowered_module_arg.name, lowered_module)
return self.graph.get_attr(serialized_lowered_module_arg.name)
Expand Down
1 change: 1 addition & 0 deletions exir/tests/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ python_unittest(
"//executorch/exir/backend/test:backend_with_compiler_demo",
"//executorch/exir/backend/test:op_partitioner_demo",
"//executorch/exir/serde:serialize",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
],
)

Expand Down
35 changes: 35 additions & 0 deletions exir/tests/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,27 @@
# pyre-strict

import io
import tempfile
import unittest
from typing import Tuple

import executorch.exir as exir

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackFloatingPointPartitioner,
)
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import CompileSpec, to_backend
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
)

from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from executorch.exir.program._program import (
EdgeProgramManager,
to_edge_transform_and_lower,
)
from executorch.exir.serde.serialize import deserialize, serialize
from torch import nn
from torch.export import export
Expand Down Expand Up @@ -202,6 +210,33 @@ def forward(self, a, x, b):
edge_new = deserialize(serialize(edge.exported_program()))
self.check_ep(edge.exported_program(), edge_new, inputs)

def test_delegate_xnnpack(self) -> None:
class SimpleConv1DModel(nn.Module):
def __init__(self):
super(SimpleConv1DModel, self).__init__()
self.conv1 = nn.Conv1d(
in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1
)

def forward(self, x):
x = self.conv1(x)
return x

x = torch.randn(64, 1, 100)
model = SimpleConv1DModel()
ep = torch.export.export(model, (x,))
edge_orig = to_edge_transform_and_lower(
ep, partitioner=[XnnpackFloatingPointPartitioner()]
)

with tempfile.NamedTemporaryFile() as f:
exir.save(edge_orig.exported_program(), f)
edge_deserialized = EdgeProgramManager(exir.load(f))
self.assertTrue(
edge_orig.to_executorch().buffer
== edge_deserialized.to_executorch().buffer
)

def test_meta_stack_trace_module_hierarchy(self) -> None:
class Model(nn.Module):
def __init__(self):
Expand Down
Loading