@@ -62,7 +62,7 @@ class ProgramTest : public ::testing::Test {
62
62
63
63
add_loader_ = std::make_unique<FileDataLoader>(std::move (loader.get ()));
64
64
65
- // Load the serialized ModuleAdd data.
65
+ // Load the serialized ModuleMultiEntry data.
66
66
path = std::getenv (" ET_MODULE_MULTI_ENTRY_PATH" );
67
67
Result<FileDataLoader> multi_loader = FileDataLoader::from (path);
68
68
ASSERT_EQ (multi_loader.error (), Error::Ok);
@@ -98,6 +98,16 @@ class ProgramTestFriend final {
98
98
return program->LoadSegment (segment_info);
99
99
}
100
100
101
+ __ET_NODISCARD static Error load_mutable_subsegment_into (
102
+ const Program* program,
103
+ size_t mutable_data_segments_index,
104
+ size_t offset_index,
105
+ size_t size,
106
+ void * buffer) {
107
+ return program->load_mutable_subsegment_into (
108
+ mutable_data_segments_index, offset_index, size, buffer);
109
+ }
110
+
101
111
const static executorch_flatbuffer::Program* GetInternalProgram (
102
112
const Program* program) {
103
113
return program->internal_program_ ;
@@ -444,3 +454,89 @@ TEST_F(ProgramTest, LoadConstantSegmentWithNoConstantSegment) {
444
454
// The constant buffer should exist.
445
455
EXPECT_GE (flatbuffer_program->constant_buffer ()->size (), 1 );
446
456
}
457
+
458
+ TEST_F (ProgramTest, LoadFromMutableSegment) {
459
+ // Load the serialized ModuleSimpleTrain data.
460
+ auto path = std::getenv (" ET_MODULE_SIMPLE_TRAIN_PATH" );
461
+ Result<FileDataLoader> training_loader = FileDataLoader::from (path);
462
+ ASSERT_EQ (training_loader.error (), Error::Ok);
463
+
464
+ // This file should always be compatible.
465
+ Result<FreeableBuffer> training_header = training_loader->load (
466
+ /* offset=*/ 0 ,
467
+ Program::kMinHeadBytes ,
468
+ DataLoader::SegmentInfo (DataLoader::SegmentInfo::Type::Program));
469
+ ASSERT_EQ (training_header.error (), Error::Ok);
470
+ EXPECT_EQ (
471
+ Program::check_header (training_header->data (), training_header->size ()),
472
+ Program::HeaderStatus::CompatibleVersion);
473
+
474
+ Result<Program> program = Program::load (&training_loader.get ());
475
+ ASSERT_EQ (program.error (), Error::Ok);
476
+
477
+ // dummy buffers to load into
478
+ uint8_t buffer[1 ] = {0 };
479
+ uint8_t buffer2[1 ] = {0 };
480
+
481
+ // Load some mutable segment data
482
+ Error err = ProgramTestFriend::load_mutable_subsegment_into (
483
+ &program.get (), 0 , 1 , 1 , buffer);
484
+ EXPECT_EQ (err, Error::Ok);
485
+
486
+ // Check that the data loaded correctly, and then mutate it
487
+ EXPECT_EQ (buffer[0 ], 232 ); // 232 comes from inspecting the file itself. The
488
+ // file is seeded so this value should be stable.
489
+ buffer[0 ] = 0 ;
490
+
491
+ // Load the same mutable segment data from file into a different buffer.
492
+ err = ProgramTestFriend::load_mutable_subsegment_into (
493
+ &program.get (),
494
+ 0 , // mutable_data_segments_index
495
+ 1 , // offset_index
496
+ 1 , // size
497
+ buffer2);
498
+ EXPECT_EQ (err, Error::Ok);
499
+
500
+ // Check that new data loaded from the file does not reflect the change to
501
+ // buffer.
502
+ EXPECT_EQ (buffer2[0 ], 232 );
503
+
504
+ const executorch_flatbuffer::Program* flatbuffer_program =
505
+ ProgramTestFriend::GetInternalProgram (&program.get ());
506
+
507
+ // Expect 1 segment. 1 mutable segment and no constant segment.
508
+ EXPECT_EQ (flatbuffer_program->segments ()->size (), 1 );
509
+
510
+ // Expect a mutable data segment.
511
+ EXPECT_EQ (flatbuffer_program->mutable_data_segments ()->size (), 1 );
512
+
513
+ // Expect the 0 index to be reserved and the offsets for weight and bias of
514
+ // linear to be indices 1 and 2.
515
+ EXPECT_EQ (
516
+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->size (),
517
+ 3 );
518
+ EXPECT_EQ (
519
+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (0 ),
520
+ 0 );
521
+ EXPECT_EQ (
522
+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (1 ),
523
+ 0 );
524
+ EXPECT_EQ (
525
+ flatbuffer_program->mutable_data_segments ()->Get (0 )->offsets ()->Get (2 ),
526
+ 36 );
527
+
528
+ // Loading beyond file should fail
529
+ err = ProgramTestFriend::load_mutable_subsegment_into (
530
+ &program.get (), 0 , 1 , 500 , buffer);
531
+ EXPECT_NE (err, Error::Ok);
532
+
533
+ // Loading beyond offsets should fail
534
+ err = ProgramTestFriend::load_mutable_subsegment_into (
535
+ &program.get (), 0 , 500 , 1 , buffer);
536
+ EXPECT_NE (err, Error::Ok);
537
+
538
+ // Loading beyond segments should fail
539
+ err = ProgramTestFriend::load_mutable_subsegment_into (
540
+ &program.get (), 500 , 1 , 1 , buffer);
541
+ EXPECT_NE (err, Error::Ok);
542
+ }
0 commit comments