Skip to content

Commit bcc515b

Browse files
tarun292facebook-github-bot
authored andcommitted
Fix exir.load/save to handle named data store map (#9485)
Summary: After named data store map was added recently to edge dialect the support for serializing/deserializing named data store map was missing in exir.load/save. This diff adds that and also adds a test to check for this. Reviewed By: pssrawat Differential Revision: D71601647
1 parent 0dd7e4e commit bcc515b

File tree

4 files changed

+47
-0
lines changed

4 files changed

+47
-0
lines changed

exir/serde/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,3 +402,4 @@ class LoweredBackendModule:
402402
original_module: export_schema.ExportedProgram
403403
original_state_dict: str
404404
original_constants: str
405+
named_data_store: bytes

exir/serde/serialize.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch
2323
import torch.export.exported_program as ep
2424
from executorch.exir import delegate
25+
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
2526
from executorch.exir.backend.compile_spec_schema import (
2627
CompileSpec as delegate_CompileSpec,
2728
)
@@ -276,6 +277,7 @@ def serialize_bytes(b: bytes) -> str:
276277
processed_bytes=serialized_processed_bytes,
277278
compile_specs=serialized_compile_spec,
278279
backend_id=lowered_module.backend_id,
280+
named_data_store=json.dumps(export_serialize._dataclass_to_dict(lowered_module.named_data_store_output),cls=export_serialize.EnumEncoder),
279281
)
280282

281283
json_lowered_module = json.dumps(
@@ -556,11 +558,19 @@ def deserialize_lowered_module(
556558
None,
557559
)
558560

561+
if serialized_lowered_module.named_data_store == "":
562+
named_data_store = None
563+
else:
564+
named_data_store = export_serialize._dict_to_dataclass(NamedDataStoreOutput, json.loads(serialized_lowered_module.named_data_store))
565+
for buffer in named_data_store.buffers:
566+
buffer.buffer = base64.b64decode(buffer.buffer.encode("ascii"))
567+
559568
lowered_module = ExirLoweredBackendModule(
560569
original_module,
561570
backend_id,
562571
processed_bytes,
563572
compile_specs,
573+
named_data_store
564574
)
565575
self.module.register_module(serialized_lowered_module_arg.name, lowered_module)
566576
return self.graph.get_attr(serialized_lowered_module_arg.name)

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ python_unittest(
9898
"//executorch/exir/backend/test:backend_with_compiler_demo",
9999
"//executorch/exir/backend/test:op_partitioner_demo",
100100
"//executorch/exir/serde:serialize",
101+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
101102
],
102103
)
103104

exir/tests/test_serde.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,27 @@
77
# pyre-strict
88

99
import io
10+
import tempfile
1011
import unittest
1112
from typing import Tuple
1213

1314
import executorch.exir as exir
1415

1516
import torch
17+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
18+
XnnpackFloatingPointPartitioner,
19+
)
1620
from executorch.exir import to_edge
1721
from executorch.exir.backend.backend_api import CompileSpec, to_backend
1822
from executorch.exir.backend.test.backend_with_compiler_demo import (
1923
BackendWithCompilerDemo,
2024
)
2125

2226
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
27+
from executorch.exir.program._program import (
28+
EdgeProgramManager,
29+
to_edge_transform_and_lower,
30+
)
2331
from executorch.exir.serde.serialize import deserialize, serialize
2432
from torch import nn
2533
from torch.export import export
@@ -202,6 +210,33 @@ def forward(self, a, x, b):
202210
edge_new = deserialize(serialize(edge.exported_program()))
203211
self.check_ep(edge.exported_program(), edge_new, inputs)
204212

213+
def test_delegate_xnnpack(self) -> None:
214+
class SimpleConv1DModel(nn.Module):
215+
def __init__(self):
216+
super(SimpleConv1DModel, self).__init__()
217+
self.conv1 = nn.Conv1d(
218+
in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1
219+
)
220+
221+
def forward(self, x):
222+
x = self.conv1(x)
223+
return x
224+
225+
x = torch.randn(64, 1, 100)
226+
model = SimpleConv1DModel()
227+
ep = torch.export.export(model, (x,))
228+
edge_orig = to_edge_transform_and_lower(
229+
ep, partitioner=[XnnpackFloatingPointPartitioner()]
230+
)
231+
232+
with tempfile.NamedTemporaryFile() as f:
233+
exir.save(edge_orig.exported_program(), f)
234+
edge_deserialized = EdgeProgramManager(exir.load(f))
235+
self.assertTrue(
236+
edge_orig.to_executorch().buffer
237+
== edge_deserialized.to_executorch().buffer
238+
)
239+
205240
def test_meta_stack_trace_module_hierarchy(self) -> None:
206241
class Model(nn.Module):
207242
def __init__(self):

0 commit comments

Comments
 (0)