Skip to content

Commit 1e6ed22

Browse files
JacobSzwejbkakirklandsign
authored andcommitted
Expand Program Interface
Differential Revision: D60977264 Pull Request resolved: pytorch#4680
1 parent f419294 commit 1e6ed22

File tree

4 files changed

+207
-1
lines changed

4 files changed

+207
-1
lines changed

runtime/executor/program.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,5 +410,90 @@ Result<FreeableBuffer> Program::LoadSegment(
410410
segment_base_offset_ + segment->offset(), segment->size(), segment_info);
411411
}
412412

413+
Error Program::load_mutable_subsegment_into(
414+
size_t mutable_data_segments_index,
415+
size_t offset_index,
416+
size_t size,
417+
void* buffer) const {
418+
EXECUTORCH_SCOPE_PROF("Program::load_subsegment_into");
419+
// Check that the program has segments.
420+
if (loader_ == nullptr || segment_base_offset_ == 0) {
421+
ET_LOG(Error, "No segments in program");
422+
return Error::NotFound;
423+
}
424+
425+
// Check that the program has mutable data segments.
426+
if (internal_program_->mutable_data_segments() == nullptr) {
427+
ET_LOG(Error, "No mutable data segments in program");
428+
return Error::NotFound;
429+
}
430+
if (mutable_data_segments_index >=
431+
internal_program_->mutable_data_segments()->size()) {
432+
ET_LOG(
433+
Error,
434+
"mutable_data_segments_index %zu out of range >= %" PRIu64,
435+
mutable_data_segments_index,
436+
(uint64_t)internal_program_->mutable_data_segments()->size());
437+
return Error::NotFound;
438+
}
439+
440+
// Grab the mutable data segment info.
441+
const auto& segment_offsets = internal_program_->mutable_data_segments()->Get(
442+
mutable_data_segments_index);
443+
444+
// Check that the offset is valid.
445+
if (segment_offsets->offsets() == nullptr) {
446+
ET_LOG(Error, "No offsets in mutable data segment");
447+
return Error::NotFound;
448+
}
449+
if (offset_index >= segment_offsets->offsets()->size()) {
450+
ET_LOG(
451+
Error,
452+
"offset index %zu out of range >= %" PRIu64,
453+
offset_index,
454+
(uint64_t)segment_offsets->offsets()->size());
455+
return Error::NotFound;
456+
}
457+
458+
// Grab the offset. Note: This offset is relative to the start of the segment,
459+
// so we will need to adjust when calling the loader.
460+
size_t offset = segment_offsets->offsets()->Get(offset_index);
461+
462+
// Grab the segment index
463+
size_t num_segments = internal_program_->segments()->size();
464+
if (segment_offsets->segment_index() >= num_segments) {
465+
ET_LOG(
466+
Error,
467+
"Segment index %u out of range (>= %zu)",
468+
segment_offsets->segment_index(),
469+
num_segments);
470+
return Error::NotFound;
471+
}
472+
473+
// Grab the segment
474+
auto segment =
475+
internal_program_->segments()->Get(segment_offsets->segment_index());
476+
477+
// Check size
478+
if (offset + size > segment->size()) {
479+
ET_LOG(
480+
Error,
481+
"offset %zu + size %zu out of range > %" PRIu64,
482+
offset,
483+
size,
484+
segment->size());
485+
return Error::InvalidArgument;
486+
}
487+
488+
DataLoader::SegmentInfo info = DataLoader::SegmentInfo(
489+
DataLoader::SegmentInfo::Type::Mutable,
490+
segment_offsets->segment_index(),
491+
nullptr);
492+
493+
// Load the data
494+
return loader_->load_into(
495+
segment_base_offset_ + segment->offset() + offset, size, info, buffer);
496+
}
497+
413498
} // namespace executor
414499
} // namespace torch

runtime/executor/program.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,30 @@ class Program final {
223223
__ET_NODISCARD Result<FreeableBuffer> LoadSegment(
224224
const DataLoader::SegmentInfo& segment_info) const;
225225

226+
/**
227+
* Loads a portion of a mutable segment into the provided buffer.
228+
*
229+
* @param[in] mutable_data_segments_index The index into the
230+
* mutable_data_segments_array.
231+
* @param[in] offset_index The index into the segment's offsets array.
232+
* @param[in] size The number of bytes to load.
233+
* @param[in] buffer The buffer to load data into. Must point to at least
234+
* `size` bytes of memory.
235+
*
236+
* @returns An error code on if the load was successful.
237+
* @retval Error::Ok The load was successful.
238+
* @retval Error::NotFound The program does not contain any segments or the
239+
* indices are out of range.
240+
* @returns Other errors depending on the implementation of
241+
* DataLoader: The Program.segment table is inconsistent, or the
242+
* data cannot be accessed.
243+
*/
244+
__ET_NODISCARD Error load_mutable_subsegment_into(
245+
size_t mutable_data_segments_index,
246+
size_t offset_index,
247+
size_t size,
248+
void* buffer) const;
249+
226250
private:
227251
Program(
228252
DataLoader* loader,

runtime/executor/test/program_test.cpp

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class ProgramTest : public ::testing::Test {
6262

6363
add_loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
6464

65-
// Load the serialized ModuleAdd data.
65+
// Load the serialized ModuleMultiEntry data.
6666
path = std::getenv("ET_MODULE_MULTI_ENTRY_PATH");
6767
Result<FileDataLoader> multi_loader = FileDataLoader::from(path);
6868
ASSERT_EQ(multi_loader.error(), Error::Ok);
@@ -98,6 +98,16 @@ class ProgramTestFriend final {
9898
return program->LoadSegment(segment_info);
9999
}
100100

101+
__ET_NODISCARD static Error load_mutable_subsegment_into(
102+
const Program* program,
103+
size_t mutable_data_segments_index,
104+
size_t offset_index,
105+
size_t size,
106+
void* buffer) {
107+
return program->load_mutable_subsegment_into(
108+
mutable_data_segments_index, offset_index, size, buffer);
109+
}
110+
101111
const static executorch_flatbuffer::Program* GetInternalProgram(
102112
const Program* program) {
103113
return program->internal_program_;
@@ -444,3 +454,89 @@ TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
444454
// The constant buffer should exist.
445455
EXPECT_GE(flatbuffer_program->constant_buffer()->size(), 1);
446456
}
457+
458+
TEST_F(ProgramTest, LoadFromMutableSegment) {
459+
// Load the serialized ModuleSimpleTrain data.
460+
auto path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
461+
Result<FileDataLoader> training_loader = FileDataLoader::from(path);
462+
ASSERT_EQ(training_loader.error(), Error::Ok);
463+
464+
// This file should always be compatible.
465+
Result<FreeableBuffer> training_header = training_loader->load(
466+
/*offset=*/0,
467+
Program::kMinHeadBytes,
468+
DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::Program));
469+
ASSERT_EQ(training_header.error(), Error::Ok);
470+
EXPECT_EQ(
471+
Program::check_header(training_header->data(), training_header->size()),
472+
Program::HeaderStatus::CompatibleVersion);
473+
474+
Result<Program> program = Program::load(&training_loader.get());
475+
ASSERT_EQ(program.error(), Error::Ok);
476+
477+
// dummy buffers to load into
478+
uint8_t buffer[1] = {0};
479+
uint8_t buffer2[1] = {0};
480+
481+
// Load some mutable segment data
482+
Error err = ProgramTestFriend::load_mutable_subsegment_into(
483+
&program.get(), 0, 1, 1, buffer);
484+
EXPECT_EQ(err, Error::Ok);
485+
486+
// Check that the data loaded correctly, and then mutate it
487+
EXPECT_EQ(buffer[0], 232); // 232 comes from inspecting the file itself. The
488+
// file is seeded so this value should be stable.
489+
buffer[0] = 0;
490+
491+
// Load the same mutable segment data from file into a different buffer.
492+
err = ProgramTestFriend::load_mutable_subsegment_into(
493+
&program.get(),
494+
0, // mutable_data_segments_index
495+
1, // offset_index
496+
1, // size
497+
buffer2);
498+
EXPECT_EQ(err, Error::Ok);
499+
500+
// Check that new data loaded from the file does not reflect the change to
501+
// buffer.
502+
EXPECT_EQ(buffer2[0], 232);
503+
504+
const executorch_flatbuffer::Program* flatbuffer_program =
505+
ProgramTestFriend::GetInternalProgram(&program.get());
506+
507+
// Expect 1 segment. 1 mutable segment and no constant segment.
508+
EXPECT_EQ(flatbuffer_program->segments()->size(), 1);
509+
510+
// Expect a mutable data segment.
511+
EXPECT_EQ(flatbuffer_program->mutable_data_segments()->size(), 1);
512+
513+
// Expect the 0 index to be reserved and the offsets for weight and bias of
514+
// linear to be indices 1 and 2.
515+
EXPECT_EQ(
516+
flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->size(),
517+
3);
518+
EXPECT_EQ(
519+
flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(0),
520+
0);
521+
EXPECT_EQ(
522+
flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(1),
523+
0);
524+
EXPECT_EQ(
525+
flatbuffer_program->mutable_data_segments()->Get(0)->offsets()->Get(2),
526+
36);
527+
528+
// Loading beyond file should fail
529+
err = ProgramTestFriend::load_mutable_subsegment_into(
530+
&program.get(), 0, 1, 500, buffer);
531+
EXPECT_NE(err, Error::Ok);
532+
533+
// Loading beyond offsets should fail
534+
err = ProgramTestFriend::load_mutable_subsegment_into(
535+
&program.get(), 0, 500, 1, buffer);
536+
EXPECT_NE(err, Error::Ok);
537+
538+
// Loading beyond segments should fail
539+
err = ProgramTestFriend::load_mutable_subsegment_into(
540+
&program.get(), 500, 1, 1, buffer);
541+
EXPECT_NE(err, Error::Ok);
542+
}

runtime/executor/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def define_common_targets(is_fbcode = False):
107107
"ET_MODULE_LINEAR_CONSTANT_BUFFER_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear-no-constant-segment.pte])",
108108
"ET_MODULE_LINEAR_CONSTANT_SEGMENT_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleLinear.pte])",
109109
"ET_MODULE_MULTI_ENTRY_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleMultipleEntry.pte])",
110+
"ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
110111
}
111112

112113
runtime.cxx_test(

0 commit comments

Comments
 (0)