Refactor serialize.py #9572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed · wants to merge 2 commits
17 changes: 15 additions & 2 deletions extension/flat_tensor/serialize/flat_tensor.fbs
@@ -35,8 +35,8 @@ table TensorMetadata {
// To retrieve a given tensor:
// 1. segment_base_offset: from the file header.
// 2. segment_offset: segments[segment_index].offset
// 3. tensor_offset: segments[segment_offset].tensor_metadata[j].offset
// Find the relevant index j by matching on tensor fqn.
// 3. tensor_offset: the offset within the segment. If there is only one item
// in the segment, offset=0.
offset: uint64;
}

@@ -55,6 +55,15 @@ table DataSegment {
size: uint64;
}

// Attributes a name to data referenced by FlatTensor.segments.
table NamedData {
// The unique id of the data blob.
key: string;

// Index of the segment in FlatTensor.segments.
segment_index: uint32;
}

// FlatTensor is a flatbuffer-based format for storing and loading tensors.
table FlatTensor {
// Schema version.
@@ -70,6 +79,10 @@ table FlatTensor {
// List of data segments that follow the FlatTensor data in this file, sorted by
// offset. Elements in this schema can refer to these segments by index.
segments: [DataSegment];

// List of blobs keyed by a unique name. Note that multiple 'NamedData'
// entries could point to the same segment index.
named_data: [NamedData];
}

root_type FlatTensor;
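
The updated schema comment above describes a three-step lookup for a tensor's absolute position in the file. Below is a minimal sketch of that arithmetic, assuming a parsed FlatTensor table and a segment_base_offset read from the file header; the function and variable names are illustrative, not part of the actual runtime API.

def tensor_absolute_offset(flat_tensor, segment_base_offset: int, fqn: str) -> int:
    # Find the tensor's metadata by fully qualified name.
    meta = next(t for t in flat_tensor.tensors if t.fully_qualified_name == fqn)
    # 1. segment_base_offset: from the file header (passed in here).
    # 2. segment_offset: segments[segment_index].offset.
    segment_offset = flat_tensor.segments[meta.segment_index].offset
    # 3. tensor_offset: meta.offset, relative to the start of its segment
    #    (0 when the tensor is the only item in the segment).
    return segment_base_offset + segment_offset + meta.offset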
7 changes: 7 additions & 0 deletions extension/flat_tensor/serialize/flat_tensor_schema.py
@@ -31,9 +31,16 @@ class DataSegment:
size: int


@dataclass
class NamedData:
key: str
segment_index: int


@dataclass
class FlatTensor:
version: int
tensor_alignment: int
tensors: List[TensorMetadata]
segments: List[DataSegment]
named_data: List[NamedData]
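
A minimal sketch of how the new dataclasses compose, using made-up keys and sizes; as the schema comment notes, multiple NamedData entries may reference the same segment index.

example = FlatTensor(
    version=0,
    tensor_alignment=16,
    tensors=[],
    segments=[DataSegment(offset=0, size=128)],
    named_data=[
        NamedData(key="encoder_weights", segment_index=0),
        NamedData(key="decoder_weights", segment_index=0),  # shares segment 0
    ],
)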
159 changes: 102 additions & 57 deletions extension/flat_tensor/serialize/serialize.py
@@ -10,29 +10,33 @@
import os
import tempfile
from dataclasses import dataclass
from typing import ClassVar, Dict, List, Literal, Optional
from typing import ClassVar, Dict, List, Literal, Optional, Sequence

import pkg_resources
from executorch.exir._serialize._cord import Cord
from executorch.exir._serialize._dataclass import _DataclassEncoder, _json_to_dataclass

from executorch.exir._serialize._flatbuffer import _flatc_compile, _flatc_decompile
from executorch.exir._serialize._program import _insert_flatbuffer_header
from executorch.exir._serialize.data_serializer import DataPayload, DataSerializer
from executorch.exir._serialize.data_serializer import (
DataPayload,
DataSerializer,
TensorEntry,
)

from executorch.exir._serialize.padding import aligned_size, pad_to, padding_required

# Byte order of numbers written to flat tensor headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"

from executorch.extension.flat_tensor.serialize.flat_tensor_schema import (
DataSegment,
FlatTensor,
TensorMetadata,
)

# Byte order of numbers written to flat tensor headers. Always little-endian
# regardless of the host system, since all commonly-used modern CPUs are little
# endian.
_HEADER_BYTEORDER: Literal["little"] = "little"


def _serialize_to_flatbuffer(flat_tensor: FlatTensor) -> Cord:
"""Serializes a FlatTensor to a flatbuffer and returns the serialized data."""
@@ -209,6 +213,62 @@ def _get_extended_header(flat_tensor_data: bytes) -> Optional[FlatTensorHeader]:
return None


def _extract_tensors(
fqn_to_tensor: Dict[str, TensorEntry],
buffers: Sequence[bytes],
segments: List[Cord],
tensor_alignment: int,
) -> List[TensorMetadata]:
"""Places tensors into a single segment, aligned to tensor_alignment within
the segment.

Args:
fqn_to_tensor: A map from fully qualified names to tensor entries.
buffers: A sequence of tensor buffers.
segments: A list of segments to append the tensor data to. Modified in-place.
tensor_alignment: The alignment of the tensor data.

Returns:
A list of TensorMetadata, which describes the tensors in the segment.
"""
tensor_data: Cord = Cord()
tensors: List[TensorMetadata] = []
# Map of buffer index -> offset of that buffer's data within tensor_data.
saved_offsets: Dict[int, int] = {}
for fqn, tensor_entry in fqn_to_tensor.items():
assert tensor_entry.layout is not None
# Check index into the tensor buffers is valid.
assert tensor_entry.buffer_index < len(
buffers
), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(buffers)}."

# Check if the tensor has already been appended to the flat_tensor_data.
offset = saved_offsets.get(tensor_entry.buffer_index, -1)
if offset == -1:
if len(tensor_data) > 0:
# Add padding to round off the previous tensor offset.
pad_length = padding_required(len(tensor_data), tensor_alignment)
tensor_data.append(b"\x00" * pad_length)
# Add to saved offsets.
offset = len(tensor_data)
saved_offsets[tensor_entry.buffer_index] = offset
# Append to flat_tensor_data at the offset.
tensor_data.append(buffers[tensor_entry.buffer_index])

tensors.append(
TensorMetadata(
fully_qualified_name=fqn,
scalar_type=tensor_entry.layout.scalar_type,
sizes=tensor_entry.layout.sizes,
dim_order=tensor_entry.layout.dim_order,
segment_index=len(segments),
offset=offset,
)
)
segments.append(tensor_data)
return tensors


class FlatTensorSerializer(DataSerializer):
"""A concrete implementation of the DataSerializer interface that
serializes and deserializes data to/from the FlatTensor format.
@@ -227,61 +287,46 @@ def serialize(
self,
data: DataPayload,
) -> Cord:
"""Serializes a list of tensor metadata and tensors into a blob."""

flat_tensor_metadata: List[TensorMetadata] = []
flat_tensor_data: Cord = Cord()

# {idx, offset}
saved_offsets: Dict[int, int] = {}

for fqn, tensor_entry in data.fqn_to_tensor.items():
assert tensor_entry.layout is not None
# Check index into the tensor buffers is valid.
assert tensor_entry.buffer_index < len(
data.buffers
), f"Invalid index {tensor_entry.buffer_index} is greater than tensor buffer size {len(data.buffers)}."

# Check if the tensor has already been appended to the flat_tensor_data.
offset = saved_offsets.get(tensor_entry.buffer_index, -1)
if offset == -1:
if len(flat_tensor_data) > 0:
# Add padding to round off the previous tensor offset.
pad_length = padding_required(
len(flat_tensor_data), self.config.tensor_alignment
)
flat_tensor_data.append(b"\x00" * pad_length)
# Add to saved offsets.
offset = len(flat_tensor_data)
saved_offsets[tensor_entry.buffer_index] = offset
# Append to flat_tensor_data at the offset.
flat_tensor_data.append(data.buffers[tensor_entry.buffer_index])

flat_tensor_metadata.append(
TensorMetadata(
fully_qualified_name=fqn,
scalar_type=tensor_entry.layout.scalar_type,
sizes=tensor_entry.layout.sizes,
dim_order=tensor_entry.layout.dim_order,
segment_index=0,
offset=offset,
"""Serializes a list of tensors and named data into a blob."""

segments: List[Cord] = []
tensors = _extract_tensors(
data.fqn_to_tensor,
data.buffers,
segments,
self.config.tensor_alignment,
)

data_segments: List[DataSegment] = []
segment_data = Cord()
for segment in segments:
prev_end = (
(data_segments[-1].offset + data_segments[-1].size)
if data_segments
else 0
)
data_segments.append(
DataSegment(
offset=aligned_size(prev_end, self.config.segment_alignment),
size=len(segment),
)
)

# Pad flat_tensor_data to segment alignment.
segment_pad_length = padding_required(
len(flat_tensor_data), self.config.segment_alignment
)
if segment_pad_length > 0:
flat_tensor_data.append(b"\x00" * segment_pad_length)
# Pad segment_data to segment alignment.
segment_pad_length = padding_required(
len(segment_data), self.config.segment_alignment
)
if segment_pad_length > 0:
segment_data.append(b"\x00" * segment_pad_length)
segment_data.append(segment)

# Create FlatTensor, which describes the contents of the file and
# points to all the data segments. It will be serialized to flatbuffer.
flat_tensor = FlatTensor(
version=0, # Keep in sync with c++ version number in serialize.h
tensor_alignment=self.config.tensor_alignment,
tensors=flat_tensor_metadata,
segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
tensors=tensors,
segments=data_segments,
named_data=[],
)

flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)
@@ -306,7 +351,7 @@ def serialize(
flatbuffer_offset=padded_header_length,
flatbuffer_size=len(flatbuffer_payload),
segment_base_offset=segment_base_offset,
segment_data_size=len(flat_tensor_data),
segment_data_size=len(segment_data),
).to_bytes()

# Pad header and payload to segment alignment.
@@ -326,15 +371,15 @@ def serialize(
assert eh.flatbuffer_size == original_flatbuffer_payload_size
assert eh.segment_base_offset == segment_base_offset
assert eh.flatbuffer_offset == padded_header_length
assert eh.segment_data_size == len(flat_tensor_data)
assert eh.segment_data_size == len(segment_data)

del header_data
del flatbuffer_payload

# Place everything into one segment.
payload = Cord()
payload.append(injected_flatbuffer_data)
payload.append(flat_tensor_data)
payload.append(segment_data)

return payload
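
The refactored serialize() above places each segment at the next config.segment_alignment boundary past the end of the previous one. Here is a self-contained sketch of that offset arithmetic; _aligned_size and _padding_required are local stand-ins assumed to behave like the aligned_size/padding_required helpers imported from executorch.exir._serialize.padding, and the 4096-byte alignment and segment sizes are made-up values.

def _aligned_size(n: int, alignment: int) -> int:
    # Round n up to the next multiple of alignment.
    return (n + alignment - 1) // alignment * alignment

def _padding_required(n: int, alignment: int) -> int:
    # Bytes of zero padding needed to reach the next alignment boundary.
    return _aligned_size(n, alignment) - n

segment_alignment = 4096
segment_sizes = [100, 8000, 1]  # hypothetical len(segment) values
offsets, prev_end = [], 0
for size in segment_sizes:
    offset = _aligned_size(prev_end, segment_alignment)  # DataSegment.offset
    offsets.append(offset)
    prev_end = offset + size
print(offsets)  # [0, 4096, 12288]

In the diff, the same arithmetic drives both the DataSegment offsets recorded in the flatbuffer and the zero padding appended to segment_data before each segment.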
