Skip to content

Commit a073158

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Use dependency injection for runner (#10326)
Summary: X-link: pytorch-labs/tokenizers#53 Pass in runner components, move most of the instantiation logic from `load()` to a new static API `create()`. This adds testability to runner components. Next step would be moving most of the logic out into `extension/llm/runner/` so that it can be used on non-llama models. Currently the logic for getting tokenizer instance should not assume llama, which I can modify in next diff. Differential Revision: D73165546
1 parent cb3eba0 commit a073158

File tree

11 files changed

+547
-120
lines changed

11 files changed

+547
-120
lines changed

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ - (instancetype)initWithModelPath:(NSString*)modelPath
3131
self = [super init];
3232
if (self) {
3333
[ExecuTorchLog.sharedLog addSink:self];
34-
_runner = std::make_unique<example::Runner>(
34+
_runner = example::Runner::create(
3535
modelPath.UTF8String, tokenizerPath.UTF8String);
3636
}
3737
return self;

examples/models/llama/main.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,18 @@ int32_t main(int32_t argc, char** argv) {
8181
#endif
8282
// create llama runner
8383
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
84-
example::Runner runner(model_path, tokenizer_path, data_path);
84+
std::unique_ptr<example::Runner> runner =
85+
example::Runner::create(model_path, tokenizer_path, data_path);
8586

8687
if (warmup) {
8788
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
88-
runner.warmup(prompt, /*max_new_tokens=*/seq_len);
89+
runner->warmup(prompt, /*max_new_tokens=*/seq_len);
8990
}
9091
// generate
9192
executorch::extension::llm::GenerationConfig config{
9293
.seq_len = seq_len, .temperature = temperature};
9394
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
94-
runner.generate(prompt, config);
95+
runner->generate(prompt, config);
9596

9697
return 0;
9798
}

examples/models/llama/runner/runner.cpp

Lines changed: 121 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@
44
*
55
* This source code is licensed under the BSD-style license found in the
66
* LICENSE file in the root directory of this source tree.
7+
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
78
*/
89

910
// A simple llama2 runner that includes preprocessing and post processing logic.
1011
// The module takes in a string as input and emits a string as output.
1112

1213
#include <executorch/examples/models/llama/runner/runner.h>
1314

14-
#include <algorithm>
15-
#include <ctime>
16-
1715
#include <executorch/extension/llm/runner/util.h>
1816

1917
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -62,125 +60,155 @@ std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
6260
}
6361
} // namespace
6462

65-
Runner::Runner(
63+
std::unique_ptr<Runner> Runner::create(
6664
const std::string& model_path,
6765
const std::string& tokenizer_path,
68-
std::optional<const std::string> data_path)
69-
// NOTE: we observed ~2x loading performance increase on iPhone 15
70-
// and a ~5% improvement on Galaxy S22 by switching to
71-
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
72-
: tokenizer_path_(tokenizer_path),
73-
metadata_({
74-
{kEnableDynamicShape, false},
75-
{kMaxSeqLen, 128},
76-
{kMaxContextLen, 128},
77-
{kUseKVCache, true},
78-
{kUseSDPAWithKVCache, false},
79-
}) {
80-
if (data_path.has_value()) {
81-
module_ = std::make_unique<Module>(
82-
model_path, data_path.value(), Module::LoadMode::File);
83-
} else {
84-
module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
85-
}
66+
std::optional<const std::string> data_path,
67+
float temperature) {
8668
ET_LOG(
8769
Info,
8870
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
8971
model_path.c_str(),
9072
tokenizer_path.c_str());
91-
}
9273

93-
[[deprecated(
94-
"This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
95-
Runner::Runner(
96-
const std::string& model_path,
97-
const std::string& tokenizer_path,
98-
const float temperature,
99-
std::optional<const std::string> data_path)
100-
: Runner(model_path, tokenizer_path, std::move(data_path)) {
101-
temperature_ = temperature;
102-
}
103-
104-
bool Runner::is_loaded() const {
105-
return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
106-
text_prefiller_ && text_token_generator_;
107-
}
108-
109-
Error Runner::load() {
110-
if (is_loaded()) {
111-
return Error::Ok;
74+
// Create the Module
75+
std::unique_ptr<Module> module;
76+
if (data_path.has_value()) {
77+
module = std::make_unique<Module>(
78+
model_path, data_path.value(), Module::LoadMode::File);
79+
} else {
80+
module = std::make_unique<Module>(model_path, Module::LoadMode::File);
11281
}
113-
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
11482

115-
// Load tokenizer.
116-
tokenizer_ = load_tokenizer(tokenizer_path_);
117-
if (tokenizer_ == nullptr) {
83+
// Initialize metadata with default values
84+
std::unordered_map<std::string, int64_t> metadata({
85+
{kEnableDynamicShape, false},
86+
{kMaxSeqLen, 128},
87+
{kMaxContextLen, 128},
88+
{kUseKVCache, true},
89+
{kUseSDPAWithKVCache, false},
90+
});
91+
92+
// Create and load tokenizer
93+
std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
94+
load_tokenizer(tokenizer_path);
95+
96+
// Fallback to BPE tokenizer if tiktoken fails
97+
if (tokenizer == nullptr) {
11898
ET_LOG(
11999
Info,
120-
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
121-
tokenizer_path_.c_str());
122-
tokenizer_.reset();
123-
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
124-
tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
125-
auto err = tokenizer_->load(tokenizer_path_);
126-
ET_CHECK_TK_OK_OR_RETURN_ERROR(
127-
err,
128-
"Failed to load %s as a llama2.c tokenizer artifact",
129-
tokenizer_path_.c_str());
130-
return ::executorch::runtime::Error::InvalidArgument;
100+
"Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types",
101+
tokenizer_path.c_str());
102+
return nullptr;
131103
}
132104

133105
ET_LOG(Info, "Reading metadata from model");
134106

135-
metadata_[kBosId] = tokenizer_->bos_tok();
107+
// Set tokenizer-related metadata
108+
metadata[kBosId] = tokenizer->bos_tok();
136109
auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
137-
std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
138-
metadata_[kVocabSize] = tokenizer_->vocab_size();
139-
140-
const auto method_names =
141-
ET_UNWRAP(module_->method_names(), "Failed reading method names");
110+
std::unordered_set<uint64_t>{tokenizer->eos_tok()});
111+
metadata[kVocabSize] = tokenizer->vocab_size();
112+
113+
// Read metadata from the model
114+
auto method_names_result = module->method_names();
115+
if (method_names_result.error() != Error::Ok) {
116+
ET_LOG(Error, "Failed reading method names");
117+
return nullptr;
118+
}
119+
const auto method_names = method_names_result.get();
142120

143-
for (auto& pair : metadata_) {
121+
for (auto& pair : metadata) {
144122
const auto& method_name = pair.first;
145123
auto& value = pair.second;
146124

147125
if (method_names.count(method_name)) {
148-
value = ET_UNWRAP(module_->get(method_name))
149-
.toScalar()
150-
.to<decltype(metadata_)::mapped_type>();
126+
auto get_result = module->get(method_name);
127+
value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
151128
} else {
152129
ET_LOG(
153130
Info,
154-
"Methond %s not found, using the default value %" PRId64,
131+
"Method %s not found, using the default value %" PRId64,
155132
method_name.c_str(),
156133
value);
157134
}
158135
ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
159136
}
137+
138+
// Get EOS IDs if available
160139
if (method_names.count(kEosIds)) {
161140
eos_ids->clear();
162-
for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
141+
auto execute_result = module->execute(kEosIds);
142+
if (execute_result.error() != Error::Ok) {
143+
ET_LOG(Error, "Failed to execute %s", kEosIds);
144+
return nullptr;
145+
}
146+
for (const auto& eos_id : execute_result.get()) {
163147
auto value = eos_id.toScalar().to<int64_t>();
164148
eos_ids->emplace(value);
165149
ET_LOG(Info, "eos_id = %" PRId64, value);
166150
}
167151
}
168-
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
169-
text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
170-
module_.get(), metadata_.at(kUseKVCache));
171-
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
172-
text_decoder_runner_.get(),
173-
metadata_.at(kUseKVCache),
174-
metadata_.at(kEnableDynamicShape),
175-
metadata_.at(kMaxSeqLen));
176-
177-
text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
178-
tokenizer_.get(),
179-
text_decoder_runner_.get(),
180-
metadata_.at(kUseKVCache),
152+
153+
// Create text_decoder_runner. Use a shared_ptr so that it can be shared with
154+
// TextPrefiller and TextTokenGenerator
155+
auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
156+
module.get(), metadata.at(kUseKVCache));
157+
158+
// Create text_prefiller
159+
auto text_prefiller = std::make_unique<llm::TextPrefiller>(
160+
text_decoder_runner.get(),
161+
metadata.at(kUseKVCache),
162+
metadata.at(kEnableDynamicShape),
163+
metadata.at(kMaxSeqLen));
164+
165+
// Create text_token_generator with stats
166+
auto stats = std::make_unique<llm::Stats>();
167+
auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
168+
tokenizer.get(),
169+
text_decoder_runner.get(),
170+
metadata.at(kUseKVCache),
181171
std::move(eos_ids),
182-
&stats_);
172+
stats.get());
173+
174+
// Create and return the Runner instance
175+
return std::make_unique<Runner>(
176+
std::move(metadata),
177+
std::move(tokenizer),
178+
std::move(text_prefiller),
179+
std::move(text_token_generator),
180+
std::move(stats),
181+
temperature);
182+
}
183+
184+
Runner::Runner(
185+
std::unordered_map<std::string, int64_t> metadata,
186+
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
187+
std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
188+
std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
189+
text_token_generator,
190+
std::unique_ptr<::executorch::extension::llm::Stats> stats,
191+
float temperature)
192+
: tokenizer_(std::move(tokenizer)),
193+
metadata_(std::move(metadata)),
194+
text_prefiller_(std::move(text_prefiller)),
195+
text_token_generator_(std::move(text_token_generator)),
196+
stats_(std::move(stats)),
197+
temperature_(temperature) {
198+
// Note: This constructor assumes that text_prefiller and text_token_generator
199+
// already have references to the Module and TextDecoderRunner they need
200+
}
201+
202+
bool Runner::is_loaded() const {
203+
return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
204+
}
183205

206+
Error Runner::load() {
207+
if (is_loaded()) {
208+
return Error::Ok;
209+
}
210+
ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
211+
ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
184212
return Error::Ok;
185213
}
186214

@@ -201,9 +229,9 @@ Error Runner::generate(
201229
// Use ones-initialized inputs.
202230
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
203231
if (!is_loaded()) {
204-
stats_.model_load_start_ms = llm::time_in_ms();
232+
stats_->model_load_start_ms = llm::time_in_ms();
205233
ET_CHECK_OK_OR_RETURN_ERROR(load());
206-
stats_.model_load_end_ms = llm::time_in_ms();
234+
stats_->model_load_end_ms = llm::time_in_ms();
207235
}
208236

209237
if (config.warming) {
@@ -229,7 +257,7 @@ Error Runner::generate(
229257
// First token time only measures the time it takes to encode the prompt and
230258
// return a response token.
231259

232-
stats_.inference_start_ms = llm::time_in_ms();
260+
stats_->inference_start_ms = llm::time_in_ms();
233261
shouldStop_ = false;
234262

235263
::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -270,8 +298,8 @@ Error Runner::generate(
270298
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
271299
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
272300
uint64_t cur_token = prefill_res.get();
273-
stats_.first_token_ms = llm::time_in_ms();
274-
stats_.prompt_eval_end_ms = llm::time_in_ms();
301+
stats_->first_token_ms = llm::time_in_ms();
302+
stats_->prompt_eval_end_ms = llm::time_in_ms();
275303

276304
// print the first token from prefill. No prev_token so use cur_token for it.
277305
wrapped_callback(
@@ -292,7 +320,7 @@ Error Runner::generate(
292320
temperature_ == -1.0f ? config.temperature : temperature_,
293321
wrapped_callback));
294322

295-
stats_.inference_end_ms = llm::time_in_ms();
323+
stats_->inference_end_ms = llm::time_in_ms();
296324
if (!config.warming) {
297325
printf("\n");
298326
}
@@ -305,17 +333,17 @@ Error Runner::generate(
305333
RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
306334
}
307335

308-
stats_.num_prompt_tokens = num_prompt_tokens;
309-
stats_.num_generated_tokens = num_generated_tokens;
336+
stats_->num_prompt_tokens = num_prompt_tokens;
337+
stats_->num_generated_tokens = num_generated_tokens;
310338

311339
if (config.warming) {
312340
ET_LOG(Info, "Warmup run finished!");
313341
} else {
314342
// Do not print report during warmup
315-
::executorch::llm::print_report(stats_);
343+
::executorch::llm::print_report(*stats_);
316344
}
317345
if (stats_callback) {
318-
stats_callback(stats_);
346+
stats_callback(*stats_);
319347
}
320348

321349
return Error::Ok;

0 commit comments

Comments
 (0)