Skip to content

Commit 320d555

Browse files
pytorchbotlucylq
andauthored
Add NamedData to flat_tensor schema (#9571)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #9123 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/57/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/57/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/57/orig @diff-train-skip-merge Co-authored-by: lucylq <[email protected]>
1 parent 012f120 commit 320d555

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

extension/flat_tensor/serialize/flat_tensor.fbs

+15-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ table TensorMetadata {
3535
// To retrieve a given tensor:
3636
// 1. segment_base_offset: from the file header.
3737
// 2. segment_offset: segments[segment_index].offset
38-
// 3. tensor_offset: segments[segment_offset].tensor_metadata[j].offset
39-
// Find the relevant index j by matching on tensor fqn.
38+
// 3. tensor_offset: the offset within the segment. If there is only one item
39+
// in the segment, offset=0.
4040
offset: uint64;
4141
}
4242

@@ -55,6 +55,15 @@ table DataSegment {
5555
size: uint64;
5656
}
5757

58+
// Attributes a name to data referenced by FlatTensor.segments.
59+
table NamedData {
60+
// The unique id of the data blob.
61+
key: string;
62+
63+
// Index of the segment in FlatTensor.segments.
64+
segment_index: uint32;
65+
}
66+
5867
// FlatTensor is a flatbuffer-based format for storing and loading tensors.
5968
table FlatTensor {
6069
// Schema version.
@@ -70,6 +79,10 @@ table FlatTensor {
7079
// List of data segments that follow the FlatTensor data in this file, sorted by
7180
// offset. Elements in this schema can refer to these segments by index.
7281
segments: [DataSegment];
82+
83+
// List of blobs keyed by a unique name. Note that multiple 'NamedData'
84+
// entries could point to the same segment index.
85+
named_data: [NamedData];
7386
}
7487

7588
root_type FlatTensor;

extension/flat_tensor/serialize/flat_tensor_schema.py

+7
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,16 @@ class DataSegment:
3131
size: int
3232

3333

34+
@dataclass
35+
class NamedData:
36+
key: str
37+
segment_index: int
38+
39+
3440
@dataclass
3541
class FlatTensor:
3642
version: int
3743
tensor_alignment: int
3844
tensors: List[TensorMetadata]
3945
segments: List[DataSegment]
46+
named_data: List[NamedData]

extension/flat_tensor/serialize/serialize.py

+1
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def serialize(
282282
tensor_alignment=self.config.tensor_alignment,
283283
tensors=flat_tensor_metadata,
284284
segments=[DataSegment(offset=0, size=len(flat_tensor_data))],
285+
named_data=[],
285286
)
286287

287288
flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor)

0 commit comments

Comments
 (0)