
Commit 85d1e5b

larryliu0820 authored and facebook-github-bot committed
Use dependency injection for runner (#10326)
Summary:
Pull Request resolved: #10326
X-link: pytorch-labs/tokenizers#53

Pass runner components in via the constructor, moving most of the instantiation logic from `load()` to a new static API, `create()`. This makes the runner components testable. The next step is to move most of this logic into `extension/llm/runner/` so it can be used with non-llama models; the logic for getting a tokenizer instance currently assumes llama, which it should not, and that will be fixed in the next diff.

Reviewed By: kirklandsign, iseeyuan

Differential Revision: D73165546
1 parent 1b063ca commit 85d1e5b
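
The shape of the change, in miniature: construction logic that used to live in the constructor and `load()` moves into a static `create()` factory, while the new constructor simply stores whatever components it is handed. Below is a minimal, self-contained sketch of that pattern, not the real ExecuTorch API; `FakeTokenizer` and `FakeModule` are illustrative stand-ins:

    #include <memory>
    #include <string>
    #include <utility>

    struct FakeTokenizer { std::string path; };
    struct FakeModule { std::string path; };

    class Runner {
     public:
      // Static factory: owns all instantiation logic and signals failure by
      // returning nullptr, mirroring the new Runner::create() in this diff.
      static std::unique_ptr<Runner> create(
          const std::string& model_path, const std::string& tokenizer_path) {
        auto module = std::make_unique<FakeModule>(FakeModule{model_path});
        auto tokenizer =
            std::make_unique<FakeTokenizer>(FakeTokenizer{tokenizer_path});
        if (tokenizer->path.empty()) {
          return nullptr;  // e.g. the tokenizer artifact failed to load
        }
        return std::make_unique<Runner>(std::move(tokenizer), std::move(module));
      }

      // Constructor only accepts injected components, so a unit test can pass
      // fakes here and never touch the filesystem. This is what makes the
      // runner components testable.
      Runner(
          std::unique_ptr<FakeTokenizer> tokenizer,
          std::unique_ptr<FakeModule> module)
          : tokenizer_(std::move(tokenizer)), module_(std::move(module)) {}

     private:
      std::unique_ptr<FakeTokenizer> tokenizer_;
      std::unique_ptr<FakeModule> module_;
    };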

13 files changed: +598, -128 lines


examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ - (instancetype)initWithModelPath:(NSString*)modelPath
   self = [super init];
   if (self) {
     [ExecuTorchLog.sharedLog addSink:self];
-    _runner = std::make_unique<example::Runner>(
+    _runner = example::Runner::create(
         modelPath.UTF8String, tokenizerPath.UTF8String);
   }
   return self;

examples/models/llama/main.cpp

Lines changed: 5 additions & 6 deletions
@@ -4,6 +4,7 @@
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */

 #include <gflags/gflags.h>
@@ -80,18 +81,16 @@ int32_t main(int32_t argc, char** argv) {
   }
 #endif
   // create llama runner
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  example::Runner runner(model_path, tokenizer_path, data_path);
+  std::unique_ptr<example::Runner> runner =
+      example::Runner::create(model_path, tokenizer_path, data_path);

   if (warmup) {
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
+    runner->warmup(prompt, /*max_new_tokens=*/seq_len);
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
       .seq_len = seq_len, .temperature = temperature};
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  runner.generate(prompt, config);
+  runner->generate(prompt, config);

   return 0;
 }
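
One caveat worth noting: `Runner::create()` returns nullptr when tokenizer loading or metadata reading fails (see runner.cpp below), and the call site above dereferences the pointer unchecked. A defensive caller might add a check; the following sketch is not part of the diff, and simply reuses the ET_CHECK_MSG macro that already appears in runner.cpp:

    std::unique_ptr<example::Runner> runner =
        example::Runner::create(model_path, tokenizer_path, data_path);
    ET_CHECK_MSG(runner != nullptr, "Failed to create llama runner");
    runner->warmup(prompt, /*max_new_tokens=*/seq_len);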

examples/models/llama/runner/runner.cpp

Lines changed: 130 additions & 95 deletions
@@ -4,16 +4,14 @@
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */

 // A simple llama2 runner that includes preprocessing and post processing logic.
 // The module takes in a string as input and emits a string as output.

 #include <executorch/examples/models/llama/runner/runner.h>

-#include <algorithm>
-#include <ctime>
-
 #include <executorch/extension/llm/runner/util.h>

 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -62,125 +60,162 @@ std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
 }
 } // namespace

-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<const std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}

-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<const std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
-
-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
-
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
   }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

-  // Load tokenizer.
-  tokenizer_ = load_tokenizer(tokenizer_path_);
-  if (tokenizer_ == nullptr) {
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+      {kUseSDPAWithKVCache, false},
+  });
+
+  // Create and load tokenizer
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
+      load_tokenizer(tokenizer_path);
+
+  // Fallback to BPE tokenizer if tiktoken fails
+  if (tokenizer == nullptr) {
     ET_LOG(
         Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    auto err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
-    return ::executorch::runtime::Error::InvalidArgument;
+        "Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types",
+        tokenizer_path.c_str());
+    return nullptr;
   }

   ET_LOG(Info, "Reading metadata from model");

-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();

-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;

     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<decltype(metadata_)::mapped_type>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner. TextPrefiller and TextTokenGenerator hold
+  // non-owning raw pointers to it.
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = std::make_unique<llm::Stats>();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats.get());
+
+  // Create and return the Runner instance
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(module),
+      std::move(text_decoder_runner),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      std::move(stats),
+      temperature);
+}
+
+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::Module> module,
+    std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
+        text_decoder_runner,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    std::unique_ptr<::executorch::extension::llm::Stats> stats,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      module_(std::move(module)),
+      text_decoder_runner_(std::move(text_decoder_runner)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      stats_(std::move(stats)),
+      temperature_(temperature) {
+  // Note: This constructor assumes that text_prefiller and text_token_generator
+  // already have references to the Module and TextDecoderRunner they need
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}

+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }

@@ -201,9 +236,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -229,7 +264,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -270,8 +305,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -292,7 +327,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -305,17 +340,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
@@ -329,8 +364,8 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   // Call generate with the warmup config
   Error err = generate(prompt, config);

-  // Reset stats after warmup
-  stats_.reset();
+  // Reset stats after warmup, not resetting the std::unique_ptr!
+  stats_->reset();
   return err;
 }
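A side effect of the new signature: `create()` returns a `std::unique_ptr<Runner>` rather than an `Error`, so the old `ET_UNWRAP` calls (which early-return an Error from the enclosing function) no longer fit there; the diff replaces them with explicit checks on each `Result`, returning nullptr on failure. A minimal self-contained sketch of that pattern, using hypothetical `Error`/`Result` stand-ins rather than the real executorch::runtime types:

    #include <memory>
    #include <string>

    // Hypothetical stand-ins for executorch::runtime::Error and Result.
    enum class Error { Ok, Internal };

    template <typename T>
    struct Result {
      Error err;
      T value;
      Error error() const { return err; }
      const T& get() const { return value; }
    };

    struct Metadata { std::string name; };

    Result<Metadata> read_metadata() {
      return {Error::Ok, Metadata{"vocab_size"}};  // imagine this can fail
    }

    // A factory returning a pointer cannot use an ET_UNWRAP-style macro that
    // early-returns an Error, so it inspects the Result by hand and maps
    // failure to nullptr, as Runner::create() does above.
    std::unique_ptr<Metadata> create_metadata() {
      auto res = read_metadata();
      if (res.error() != Error::Ok) {
        return nullptr;
      }
      return std::make_unique<Metadata>(res.get());
    }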