Skip to content

Commit 6ca39f8

Browse files
authored
Fix exir.load/save to handle named data store map
Differential Revision: D71601647 Pull Request resolved: #9485
1 parent 06ec67c commit 6ca39f8

File tree

4 files changed

+48
-0
lines changed

4 files changed

+48
-0
lines changed

exir/serde/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,3 +402,4 @@ class LoweredBackendModule:
402402
original_module: export_schema.ExportedProgram
403403
original_state_dict: str
404404
original_constants: str
405+
named_data_store: Optional[bytes] = None

exir/serde/serialize.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch
2323
import torch.export.exported_program as ep
2424
from executorch.exir import delegate
25+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
2526
from executorch.exir.backend.compile_spec_schema import (
2627
CompileSpec as delegate_CompileSpec,
2728
)
@@ -268,6 +269,7 @@ def serialize_bytes(b: bytes) -> str:
268269
assert isinstance(serialized_artifact.exported_program, schema.ExportedProgram)
269270

270271
serialized_processed_bytes = serialize_bytes(lowered_module.processed_bytes)
272+
named_data_store = json.dumps(export_serialize._dataclass_to_dict(lowered_module.named_data_store_output),cls=export_serialize.EnumEncoder) if lowered_module.named_data_store_output else None
271273

272274
serialized_lowered_module = SerdeLoweredBackendModule(
273275
original_module=serialized_artifact.exported_program,
@@ -276,6 +278,7 @@ def serialize_bytes(b: bytes) -> str:
276278
processed_bytes=serialized_processed_bytes,
277279
compile_specs=serialized_compile_spec,
278280
backend_id=lowered_module.backend_id,
281+
named_data_store=named_data_store,
279282
)
280283

281284
json_lowered_module = json.dumps(
@@ -556,11 +559,19 @@ def deserialize_lowered_module(
556559
None,
557560
)
558561

562+
if serialized_lowered_module.named_data_store is None:
563+
named_data_store = None
564+
else:
565+
named_data_store = export_serialize._dict_to_dataclass(NamedDataStoreOutput, json.loads(serialized_lowered_module.named_data_store))
566+
for buffer in named_data_store.buffers:
567+
buffer.buffer = base64.b64decode(buffer.buffer.encode("ascii"))
568+
559569
lowered_module = ExirLoweredBackendModule(
560570
original_module,
561571
backend_id,
562572
processed_bytes,
563573
compile_specs,
574+
named_data_store
564575
)
565576
self.module.register_module(serialized_lowered_module_arg.name, lowered_module)
566577
return self.graph.get_attr(serialized_lowered_module_arg.name)

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ python_unittest(
9898
"//executorch/exir/backend/test:backend_with_compiler_demo",
9999
"//executorch/exir/backend/test:op_partitioner_demo",
100100
"//executorch/exir/serde:serialize",
101+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
101102
],
102103
)
103104

exir/tests/test_serde.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,27 @@
77
# pyre-strict
88

99
import io
10+
import tempfile
1011
import unittest
1112
from typing import Tuple
1213

1314
import executorch.exir as exir
1415

1516
import torch
17+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
18+
XnnpackFloatingPointPartitioner,
19+
)
1620
from executorch.exir import to_edge
1721
from executorch.exir.backend.backend_api import CompileSpec, to_backend
1822
from executorch.exir.backend.test.backend_with_compiler_demo import (
1923
BackendWithCompilerDemo,
2024
)
2125

2226
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
27+
from executorch.exir.program._program import (
28+
EdgeProgramManager,
29+
to_edge_transform_and_lower,
30+
)
2331
from executorch.exir.serde.serialize import deserialize, serialize
2432
from torch import nn
2533
from torch.export import export
@@ -202,6 +210,33 @@ def forward(self, a, x, b):
202210
edge_new = deserialize(serialize(edge.exported_program()))
203211
self.check_ep(edge.exported_program(), edge_new, inputs)
204212

213+
def test_delegate_xnnpack(self) -> None:
214+
class SimpleConv1DModel(nn.Module):
215+
def __init__(self):
216+
super(SimpleConv1DModel, self).__init__()
217+
self.conv1 = nn.Conv1d(
218+
in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1
219+
)
220+
221+
def forward(self, x):
222+
x = self.conv1(x)
223+
return x
224+
225+
x = torch.randn(64, 1, 100)
226+
model = SimpleConv1DModel()
227+
ep = torch.export.export(model, (x,))
228+
edge_orig = to_edge_transform_and_lower(
229+
ep, partitioner=[XnnpackFloatingPointPartitioner()]
230+
)
231+
232+
with tempfile.NamedTemporaryFile() as f:
233+
exir.save(edge_orig.exported_program(), f)
234+
edge_deserialized = EdgeProgramManager(exir.load(f))
235+
self.assertTrue(
236+
edge_orig.to_executorch().buffer
237+
== edge_deserialized.to_executorch().buffer
238+
)
239+
205240
def test_meta_stack_trace_module_hierarchy(self) -> None:
206241
class Model(nn.Module):
207242
def __init__(self):

0 commit comments

Comments
 (0)