Skip to content

Commit ed8f793

Browse files
authored
Serialize NamedDataStoreOutput into PTD.
Differential Revision: D70939807 Pull Request resolved: #9125
1 parent 9bcbdbd commit ed8f793

File tree

4 files changed

+185
-58
lines changed

4 files changed

+185
-58
lines changed

exir/_serialize/_serialize.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
# pyre-strict
88

9-
from typing import Dict, Optional, Tuple
9+
from typing import Dict, Optional, Set, Tuple
1010

1111
from executorch.exir._serialize import _serialize_pte_binary
1212

1313
from executorch.exir._serialize._cord import Cord
1414
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput
1515
from executorch.exir._serialize.data_serializer import (
16+
DataEntry,
1617
DataPayload,
1718
DataSerializer,
1819
TensorEntry,
@@ -74,39 +75,59 @@ def serialize_for_executorch(
7475
tensor.extra_tensor_info.fully_qualified_name
7576
] = TensorLayout(tensor.scalar_type, tensor.sizes, tensor.dim_order)
7677

78+
if len(fqn_to_tensor_layout) == 0 and (
79+
named_data is None or len(named_data.external_data) == 0
80+
):
81+
return pte, ptd_files
82+
83+
# Consolidate tensors and opaque data with the same external tag so they
84+
# can be saved to the same PTD.
85+
all_external_tags: Set[str] = set()
86+
if named_data is not None and len(named_data.external_data) > 0:
87+
assert (
88+
len(named_data.buffers) > 0
89+
), "External data exists, but there are no buffers provided."
90+
all_external_tags = set(named_data.external_data.keys())
91+
7792
if len(fqn_to_tensor_layout) > 0:
7893
# emitter_output.external_constant_map contains the mapping from
7994
# {file: {fqn: index into external_constant_buffer}}
8095
# Contains the locations of the tensor buffers, and must be non-empty
8196
# if there are external tensors to serialize.
82-
assert emitter_output.external_constant_map is not None
83-
for (
84-
filename,
85-
fqn_to_index,
86-
) in (
87-
# pyre-ignore Undefined attribute [16]: Optional type has no attribute `items`.
88-
emitter_output.external_constant_map.items()
89-
):
90-
# Create a TensorEntry for each external tensor.
91-
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
92-
for fqn, index in fqn_to_index.items():
93-
assert fqn in fqn_to_tensor_layout
94-
fqn_to_tensor_entry[fqn] = TensorEntry(
95-
buffer_index=index,
96-
layout=fqn_to_tensor_layout[fqn],
97-
)
98-
99-
ptd_files[filename] = data_serializer.serialize(
100-
DataPayload(
101-
buffers=emitter_output.external_constant_buffer,
102-
fqn_to_tensor=fqn_to_tensor_entry,
103-
)
97+
assert (
98+
emitter_output.external_constant_map is not None
99+
), "External exists, but there are no buffers provided."
100+
all_external_tags = all_external_tags | set(
101+
emitter_output.external_constant_map.keys()
102+
)
103+
104+
for tag in all_external_tags:
105+
fqn_to_tensor_entry: Dict[str, TensorEntry] = {}
106+
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
107+
fqn_to_index = emitter_output.external_constant_map.get(tag, {})
108+
# Create a TensorEntry for each external tensor.
109+
for fqn, index in fqn_to_index.items():
110+
assert fqn in fqn_to_tensor_layout
111+
fqn_to_tensor_entry[fqn] = TensorEntry(
112+
buffer_index=index,
113+
layout=fqn_to_tensor_layout[fqn],
104114
)
105115

106-
if named_data is None or len(named_data.external_data) == 0:
107-
return pte, ptd_files
116+
# Extract external data.
117+
key_to_data: Dict[str, DataEntry] = {}
118+
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `get`.
119+
key_to_buffer_index = named_data.external_data.get(tag, {})
120+
for key, index in key_to_buffer_index.items():
121+
# pyre-ignore[16]: Undefined attribute: `Optional` has no attribute `buffers`.
122+
key_to_data[key] = DataEntry(index, named_data.buffers[index].alignment)
108123

109-
if len(named_data.buffers) == 0:
110-
raise RuntimeError("External data exists, but there are no buffers provided.")
124+
# Serialize into PTD file.
125+
ptd_files[tag] = data_serializer.serialize(
126+
DataPayload(
127+
buffers=emitter_output.external_constant_buffer,
128+
fqn_to_tensor=fqn_to_tensor_entry,
129+
key_to_data=key_to_data,
130+
)
131+
)
111132

112133
return pte, ptd_files

exir/_serialize/data_serializer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,21 @@ class TensorEntry:
3838
layout: TensorLayout
3939

4040

41+
@dataclass
42+
class DataEntry:
43+
"""Represents a single blob in `DataPayload`, specifying its location
44+
and metadata.
45+
46+
Attributes:
47+
buffer_index: The index inside `DataPayload.buffers` that this
48+
DataEntry refers to.
49+
alignment: The alignment of the data.
50+
"""
51+
52+
buffer_index: int
53+
alignment: int
54+
55+
4156
@dataclass
4257
class DataPayload:
4358
"""Contains the data and metadata required for serialization.
@@ -49,10 +64,12 @@ class DataPayload:
4964
Attributes:
5065
buffers: a sequence of tensor buffers.
5166
fqn_to_tensor: a map from fully qualified names to serializable tensors.
67+
key_to_data: a map from unique keys to serializable opaque data.
5268
"""
5369

5470
buffers: Sequence[bytes]
5571
fqn_to_tensor: Dict[str, TensorEntry]
72+
key_to_data: Dict[str, DataEntry]
5673

5774

5875
class DataSerializer(ABC):

extension/flat_tensor/serialize/serialize.py

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import json
10+
import math
1011
import os
1112
import tempfile
1213
from dataclasses import dataclass
@@ -19,6 +20,7 @@
1920
from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
2021
from executorch.exir._serialize._program import _insert_flatbuffer_header
2122
from executorch.exir._serialize.data_serializer import (
23+
DataEntry,
2224
DataPayload,
2325
DataSerializer,
2426
TensorEntry,
@@ -29,6 +31,7 @@
2931
from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
3032
DataSegment,
3133
FlatTensor,
34+
NamedData,
3235
TensorMetadata,
3336
)
3437

@@ -202,6 +205,24 @@ def to_bytes(self) -> bytes:
202205
return data
203206

204207

208+
@dataclass
209+
class AlignedData:
210+
"""
211+
Holds data that should be aligned, for serialization.
212+
213+
Attributes:
214+
data: The data to serialize, as a cord.
215+
alignment: The alignment required for the data.
216+
"""
217+
218+
data: Cord
219+
alignment: int
220+
221+
def __init__(self, data: Cord, alignment: Optional[int] = None) -> None:
222+
self.data = data
223+
self.alignment = alignment or 1
224+
225+
205226
def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
206227
"""Returns the extended header of the flat_tensor data, if present and valid."""
207228
try:
@@ -216,7 +237,7 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
216237
def _extract_tensors(
217238
fqn_to_tensor: Dict[str, TensorEntry],
218239
buffers: Sequence[bytes],
219-
segments: List[Cord],
240+
segments: List[AlignedData],
220241
tensor_alignment: int,
221242
) -> List[TensorMetadata]:
222243
"""Places tensors into a single segment, aligned to tensor_alignment within
@@ -265,10 +286,43 @@ def _extract_tensors(
265286
offset=offset,
266287
)
267288
)
268-
segments.append(tensor_data)
289+
segments.append(AlignedData(tensor_data))
269290
return tensors
270291

271292

293+
def _extract_named_data(
294+
key_to_data: Dict[str, DataEntry],
295+
buffers: Sequence[bytes],
296+
segments: List[AlignedData],
297+
) -> List[NamedData]:
298+
"""Places named data into segments and record the alignment for each.
299+
300+
Args:
301+
key_to_data: A map from keys to opaque data entries.
302+
buffers: A sequence of buffers holding opaque blob data.
303+
segments: A list of segments to append data to. Modified in-place.
304+
305+
Returns:
306+
A list of NamedData describing the offsets to the opaque blob data.
307+
"""
308+
309+
# Map from buffer_idx to segment_idx.
310+
segment_index_map: Dict[int, int] = {}
311+
312+
named_data: List[NamedData] = []
313+
for key, data_entry in key_to_data.items():
314+
buffer_idx = data_entry.buffer_index
315+
segment_index = segment_index_map.get(buffer_idx, None)
316+
if segment_index is None:
317+
segment_index = len(segments)
318+
segment_index_map[buffer_idx] = segment_index
319+
segments.append(
320+
AlignedData(Cord(buffers[buffer_idx]), data_entry.alignment)
321+
)
322+
named_data.append(NamedData(key=key, segment_index=segment_index))
323+
return named_data
324+
325+
272326
class FlatTensorSerializer(DataSerializer):
273327
"""A concrete implementation of the DataSerializer interface that
274328
serializes and deserializes data to/from the FlatTensor format.
@@ -289,35 +343,37 @@ def serialize(
289343
) -> Cord:
290344
"""Serializes a list of tensors and named data into a blob."""
291345

292-
segments: List[Cord] = []
346+
segments: List[AlignedData] = []
293347
tensors = _extract_tensors(
294348
data.fqn_to_tensor,
295349
data.buffers,
296350
segments,
297351
self.config.tensor_alignment,
298352
)
353+
named_data = _extract_named_data(data.key_to_data, data.buffers, segments)
299354

300355
data_segments: List[DataSegment] = []
301-
segment_data = Cord()
356+
aggregated_segment_data = Cord()
302357
for segment in segments:
303358
prev_end = (
304359
(data_segments[-1].offset + data_segments[-1].size)
305360
if data_segments
306361
else 0
307362
)
363+
alignment = math.lcm(self.config.segment_alignment, segment.alignment)
308364
data_segments.append(
309365
DataSegment(
310-
offset=aligned_size(prev_end, self.config.segment_alignment),
311-
size=len(segment),
366+
offset=aligned_size(prev_end, alignment),
367+
size=len(segment.data),
312368
)
313369
)
314-
# Pad segment_data to segment alignment.
370+
# Pad aggregated_segment_data to segment alignment.
315371
segment_pad_length = padding_required(
316-
len(segment_data), self.config.segment_alignment
372+
len(aggregated_segment_data), alignment
317373
)
318374
if segment_pad_length > 0:
319-
segment_data.append(b"\x00" * segment_pad_length)
320-
segment_data.append(segment)
375+
aggregated_segment_data.append(b"\x00" * segment_pad_length)
376+
aggregated_segment_data.append(segment.data)
321377

322378
# Create FlatTensor, which describes of the contents of the file and
323379
# points to all the data segments. It will be serialized to flatbuffer.
@@ -326,7 +382,7 @@ def serialize(
326382
tensor_alignment=self.config.tensor_alignment,
327383
tensors=tensors,
328384
segments=data_segments,
329-
named_data=[],
385+
named_data=named_data,
330386
)
331387

332388
flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
@@ -351,7 +407,7 @@ def serialize(
351407
flatbuffer_offset=padded_header_length,
352408
flatbuffer_size=len(flatbuffer_payload),
353409
segment_base_offset=segment_base_offset,
354-
segment_data_size=len(segment_data),
410+
segment_data_size=len(aggregated_segment_data),
355411
).to_bytes()
356412

357413
# Pad header and payload to segment alignment.
@@ -371,15 +427,15 @@ def serialize(
371427
assert eh.flatbuffer_size == original_flatbuffer_payload_size
372428
assert eh.segment_base_offset == segment_base_offset
373429
assert eh.flatbuffer_offset == padded_header_length
374-
assert eh.segment_data_size == len(segment_data)
430+
assert eh.segment_data_size == len(aggregated_segment_data)
375431

376432
del header_data
377433
del flatbuffer_payload
378434

379435
# Place everything into one segment.
380436
payload = Cord()
381437
payload.append(injected_flatbuffer_data)
382-
payload.append(segment_data)
438+
payload.append(aggregated_segment_data)
383439

384440
return payload
385441

0 commit comments

Comments
 (0)