Commit 14d6609

Serialize NamedDataStoreOutput into PTD.
Update PTD serialization to account for blobs from the NamedDataStoreOutput. Something we can do in the future is to consolidate tensors (that go through the emitter) and blobs (that come from the NamedDataStore).

Differential Revision: [D70939807](https://our.internmc.facebook.com/intern/diff/D70939807/)
ghstack-source-id: 270931489
Pull Request resolved: #9125
1 parent: bf72b5c
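As an overview of the change, here is a minimal sketch using toy dicts (not the ExecuTorch API) of how the set of PTD files is now derived: the union of filenames referenced by the emitter's external constant map and by NamedDataStoreOutput.external_data, with each file receiving both its external tensors and its named blobs (see the diff of exir/_serialize/_serialize.py below).

# Hypothetical stand-ins shaped like the maps used in this diff:
# {filename: {fqn: buffer_index}} and {filename: {key: buffer_index}}.
external_constant_map = {"model.ptd": {"linear.weight": 0}}
external_data = {"model.ptd": {"tokenizer_blob": 0}, "aux.ptd": {"lut": 1}}

# Union of all files that need a PTD, mirroring `all_external_files` below.
all_external_files = set(external_constant_map) | set(external_data)

for filename in sorted(all_external_files):
    fqn_to_index = external_constant_map.get(filename, {})
    key_to_index = external_data.get(filename, {})
    print(filename, "tensors:", fqn_to_index, "named blobs:", key_to_index)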

4 files changed: +174, -54 lines

exir/_serialize/_serialize.py

Lines changed: 43 additions & 27 deletions
@@ -6,13 +6,14 @@
 
 # pyre-strict
 
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Set, Tuple
 
 from executorch.exir._serialize import _serialize_pte_binary
 
 from executorch.exir._serialize._cord import Cord
 from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
 from executorch.exir._serialize.data_serializer import (
+    DataEntry,
     DataPayload,
     DataSerializer,
     TensorEntry,
@@ -74,39 +75,54 @@ def serialize_for_executorch(
                         tensor.extra_tensor_info.fully_qualified_name
                     ] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
 
+    if len(fqn_to_tensor_layout) == 0 and (
+        named_data is None or len(named_data.external_data) == 0
+    ):
+        return pte, ptd_files
+
+    all_external_files: Set[str] = set()
+    if named_data is not None and len(named_data.external_data) > 0:
+        assert (
+            len(named_data.buffers) > 0
+        ), "External data exists, but there are no buffers provided."
+        all_external_files = set(named_data.external_data.keys())
+
     if len(fqn_to_tensor_layout) > 0:
         # emitter_output.external_constant_map contains the mapping from
         # {file: {fqn: index into external_constant_buffer}}
         # Contains the locations of the tensor buffers, and must be non-empty
         # if there are external tensors to serialize.
-        assert emitter_output.external_constant_map is not None
-        for (
-            filename,
-            fqn_to_index,
-        ) in (
-            # pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`.
-            emitter_output.external_constant_map.items()
-        ):
-            # Create a TensorEntry for each external tensor.
-            fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
-            for fqn, index in fqn_to_index.items():
-                assert fqn in fqn_to_tensor_layout
-                fqn_to_tensor_entry[fqn] = TensorEntry(
-                    buffer_index=index,
-                    layout=fqn_to_tensor_layout[fqn],
-                )
-
-            ptd_files[filename] = data_serializer.serialize(
-                DataPayload(
-                    buffers=emitter_output.external_constant_buffer,
-                    fqn_to_tensor=fqn_to_tensor_entry,
-                )
+        assert (
+            emitter_output.external_constant_map is not None
+        ), "External exists, but there are no buffers provided."
+        all_external_files = all_external_files | set(
+            emitter_output.external_constant_map.keys()
+        )
+
+    for filename in all_external_files:
+        fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
+        fqn_to_index = emitter_output.external_constant_map.get(filename, {})
+        # Create a TensorEntry for each external tensor.
+        for fqn, index in fqn_to_index.items():
+            assert fqn in fqn_to_tensor_layout
+            fqn_to_tensor_entry[fqn] = TensorEntry(
+                buffer_index=index,
+                layout=fqn_to_tensor_layout[fqn],
             )
 
-    if named_data is None or len(named_data.external_data) == 0:
-        return pte, ptd_files
+        # Extract external data.
+        key_to_data: Dict[str, DataEntry] = {}
+        key_to_buffer_index = named_data.external_data.get(filename, {})
+        for key, index in key_to_buffer_index.items():
+            key_to_data[key] = DataEntry(index, named_data.buffers[index].alignment)
 
-    if len(named_data.buffers) == 0:
-        raise RuntimeError("External data exists, but there are no buffers provided.")
+        # Serialize into PTD file.
+        ptd_files[filename] = data_serializer.serialize(
+            DataPayload(
+                buffers=emitter_output.external_constant_buffer,
+                fqn_to_tensor=fqn_to_tensor_entry,
+                key_to_data=key_to_data,
+            )
+        )
 
     return pte, ptd_files
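To show where the per-blob alignment comes from, here is a toy sketch of the new extraction step; FakeBufferEntry is a hypothetical stand-in for the entries in NamedDataStoreOutput.buffers and only models the `alignment` field the loop above reads.

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class FakeBufferEntry:  # hypothetical stand-in, not the ExecuTorch class
    data: bytes
    alignment: int

buffers: List[FakeBufferEntry] = [
    FakeBufferEntry(b"\x00" * 32, alignment=16),
    FakeBufferEntry(b"\x01" * 8, alignment=4),
]
# {filename: {key: index into buffers}}, as in named_data.external_data.
external_data = {"model.ptd": {"blob_a": 0, "blob_b": 1}}

filename = "model.ptd"
# Mirror of the new per-file loop: record each blob's buffer index and alignment.
key_to_alignment: Dict[str, int] = {
    key: buffers[index].alignment
    for key, index in external_data.get(filename, {}).items()
}
print(key_to_alignment)  # {'blob_a': 16, 'blob_b': 4}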

exir/_serialize/data_serializer.py

Lines changed: 17 additions & 0 deletions
@@ -38,6 +38,21 @@ class TensorEntry:
     layout: TensorLayout
 
 
+@dataclass
+class DataEntry:
+    """Represents a single blob in `DataPayload`, specifying its location
+    and metadata.
+
+    Attributes:
+        buffer_index: The index inside `DataPayload.buffers` that this
+            DataEntryEntry refers to.
+        alignment: The alignment of the data.
+    """
+
+    buffer_index: int
+    alignment: int
+
+
 @dataclass
 class DataPayload:
     """Contains the data and metadata required for serialization.
@@ -49,10 +64,12 @@ class DataPayload:
     Attributes:
         buffers: a sequence of tensor buffers.
        fqn_to_tensor: a map from fully qualified names to serializable tensors.
+        key_to_data: a map from unique keys to serializable opaque data.
    """
 
     buffers: Sequence[bytes]
     fqn_to_tensor: Dict[str, TensorEntry]
+    key_to_data: Dict[str, DataEntry]
 
 
 class DataSerializer(ABC):
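A brief illustration of the new payload shape, assuming a build of ExecuTorch that includes this commit; the blob names and byte values below are made up.

from executorch.exir._serialize.data_serializer import DataEntry, DataPayload

# Two opaque blobs stored in `buffers`; each DataEntry points at its buffer
# index and records the alignment the serializer must honor for it.
payload = DataPayload(
    buffers=[b"\x00" * 64, b"weights-blob"],
    fqn_to_tensor={},  # no external tensors in this toy example
    key_to_data={
        "tokenizer": DataEntry(buffer_index=0, alignment=16),
        "delegate_blob": DataEntry(buffer_index=1, alignment=4),
    },
)
# Any DataSerializer implementation can then serialize both maps from this payload.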

extension/flat_tensor/serialize/serialize.py

Lines changed: 64 additions & 10 deletions
@@ -7,6 +7,7 @@
 # pyre-strict
 
 import json
+import math
 import os
 import tempfile
 from dataclasses import dataclass
@@ -19,6 +20,7 @@
 from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
 from executorch.exir._serialize._program import _insert_flatbuffer_header
 from executorch.exir._serialize.data_serializer import (
+    DataEntry,
     DataPayload,
     DataSerializer,
     TensorEntry,
@@ -29,6 +31,7 @@
 from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
     DataSegment,
     FlatTensor,
+    NamedData,
     TensorMetadata,
 )
 
@@ -202,6 +205,24 @@ def to_bytes(self) -> bytes:
         return data
 
 
+@dataclass
+class AlignedData:
+    """
+    Holds data that should be aligned, for serialization.
+
+    Attributes:
+        data: The data to serialize, as a cord.
+        alignment: The alignment required for the data.
+    """
+
+    data: Cord
+    alignment: int
+
+    def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
+        self.data = data
+        self.alignment = alignment or 1
+
+
 def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
     """Returns the extended header of the flat_tensor data, if present and valid."""
     try:
@@ -216,7 +237,7 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
 def _extract_tensors(
     fqn_to_tensor: Dict[str, TensorEntry],
     buffers: Sequence[bytes],
-    segments: List[Cord],
+    segments: List[AlignedData],
     tensor_alignment: int,
 ) -> List[TensorMetadata]:
     """Places tensors into a single segment, aligned to tensor_alignment within
@@ -265,10 +286,43 @@ def _extract_tensors(
                 offset=offset,
             )
         )
-    segments.append(tensor_data)
+    segments.append(AlignedData(tensor_data))
     return tensors
 
 
+def _extract_named_data(
+    key_to_data: Dict[str, DataEntry],
+    buffers: Sequence[bytes],
+    segments: List[AlignedData],
+) -> List[NamedData]:
+    """Places named data into segments and record the alignment for each.
+
+    Args:
+        key_to_data: A map from keys to opaque data entries.
+        buffers: A sequence of buffers holding opaque blob data.
+        segments: A list of segments to append data to. Modified in-place.
+
+    Returns:
+        A list of NamedData describing the offsets to the opaque blob data.
+    """
+
+    # Map from buffer_idx to segment_idx.
+    segment_index_map: Dict[int, int] = {}
+
+    named_data: List[NamedData] = []
+    for key, data_entry in key_to_data.items():
+        buffer_idx = data_entry.buffer_index
+        segment_index = segment_index_map.get(buffer_idx, None)
+        if segment_index is None:
+            segment_index = len(segments)
+            segment_index_map[buffer_idx] = segment_index
+            segments.append(
+                AlignedData(Cord(buffers[buffer_idx]), data_entry.alignment)
+            )
+        named_data.append(NamedData(key=key, segment_index=segment_index))
+    return named_data
+
+
 class FlatTensorSerializer(DataSerializer):
     """A concrete implementation of the DataSerializer interface that
     serializes and deserializes data to/from the FlatTensor format.
@@ -289,13 +343,14 @@ def serialize(
     ) -> Cord:
         """Serializes a list of tensors and named data into a blob."""
 
-        segments: List[Cord] = []
+        segments: List[AlignedData] = []
         tensors = _extract_tensors(
            data.fqn_to_tensor,
            data.buffers,
            segments,
            self.config.tensor_alignment,
        )
+        named_data = _extract_named_data(data.key_to_data, data.buffers, segments)
 
         data_segments: List[DataSegment] = []
         segment_data = Cord()
@@ -305,19 +360,18 @@
                 if data_segments
                 else 0
             )
+            alignment = math.lcm(self.config.segment_alignment, segment.alignment)
             data_segments.append(
                 DataSegment(
-                    offset=aligned_size(prev_end, self.config.segment_alignment),
-                    size=len(segment),
+                    offset=aligned_size(prev_end, alignment),
+                    size=len(segment.data),
                 )
             )
             # Pad segment_data to segment alignment.
-            segment_pad_length = padding_required(
-                len(segment_data), self.config.segment_alignment
-            )
+            segment_pad_length = padding_required(len(segment_data), alignment)
             if segment_pad_length > 0:
                 segment_data.append(b"\x00" * segment_pad_length)
-            segment_data.append(segment)
+            segment_data.append(segment.data)
 
         # Create FlatTensor, which describes of the contents of the file and
         # points to all the data segments. It will be serialized to flatbuffer.
@@ -326,7 +380,7 @@
             tensor_alignment=self.config.tensor_alignment,
             tensors=tensors,
             segments=data_segments,
-            named_data=[],
+            named_data=named_data,
         )
 
         flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
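The updated serialize() loop aligns each segment to the least common multiple of the file-level segment alignment and the per-blob alignment carried by its DataEntry, so both constraints hold at once. Below is a small, self-contained sketch of that offset arithmetic; the aligned_size and padding_required helpers are local stand-ins with assumed round-up semantics, not the repo's implementations, and the numeric values are arbitrary.

import math

def aligned_size(size: int, alignment: int) -> int:
    # Round `size` up to the next multiple of `alignment` (assumed semantics).
    return ((size + alignment - 1) // alignment) * alignment

def padding_required(offset: int, alignment: int) -> int:
    # Zero bytes needed to reach the next aligned offset (assumed semantics).
    return aligned_size(offset, alignment) - offset

segment_alignment = 128  # e.g. the serializer config's segment_alignment
blob_alignment = 96      # alignment recorded in a DataEntry

alignment = math.lcm(segment_alignment, blob_alignment)  # 384
prev_end = 200           # end of the previous segment's bytes
offset = aligned_size(prev_end, alignment)               # 384
pad = padding_required(prev_end, alignment)              # 184 zero bytes of padding
print(alignment, offset, pad)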
