@@ -4,16 +4,14 @@
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */

 // A simple llama2 runner that includes preprocessing and post processing logic.
 // The module takes in a string as input and emits a string as output.

 #include <executorch/examples/models/llama/runner/runner.h>

-#include <algorithm>
-#include <ctime>
-
 #include <executorch/extension/llm/runner/util.h>

 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -62,125 +60,162 @@ std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
 }
 } // namespace

-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<const std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}

-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<const std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
-
-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
-
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
   }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

-  // Load tokenizer.
-  tokenizer_ = load_tokenizer(tokenizer_path_);
-  if (tokenizer_ == nullptr) {
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+      {kUseSDPAWithKVCache, false},
+  });
+
+  // Create and load tokenizer
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
+      load_tokenizer(tokenizer_path);
+
+  // Fail if the tokenizer could not be loaded in any supported format
+  if (tokenizer == nullptr) {
     ET_LOG(
         Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    auto err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
-    return ::executorch::runtime::Error::InvalidArgument;
+        "Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types",
+        tokenizer_path.c_str());
+    return nullptr;
   }

   ET_LOG(Info, "Reading metadata from model");

-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();

-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;

     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<decltype(metadata_)::mapped_type>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner. TextPrefiller and TextTokenGenerator hold raw
+  // pointers to it, so it must outlive them; the Runner ends up owning all
+  // three.
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = std::make_unique<llm::Stats>();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats.get());
+
+  // Create and return the Runner instance
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(module),
+      std::move(text_decoder_runner),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      std::move(stats),
+      temperature);
+}
+
+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::Module> module,
+    std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
+        text_decoder_runner,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    std::unique_ptr<::executorch::extension::llm::Stats> stats,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      module_(std::move(module)),
+      text_decoder_runner_(std::move(text_decoder_runner)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      stats_(std::move(stats)),
+      temperature_(temperature) {
+  // Note: this constructor assumes that text_prefiller and text_token_generator
+  // already hold the Module and TextDecoderRunner pointers they need.
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}

+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }
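The hunk above replaces the fallible constructor with a static factory that reports failure by returning nullptr. A minimal caller sketch follows; the `example` namespace, the `llm::GenerationConfig` type name, and the file paths are assumptions inferred from identifiers visible in this diff, not confirmed by it.

#include <executorch/examples/models/llama/runner/runner.h>

int main() {
  // Passing temperature = -1.0f defers the sampling temperature to the
  // per-call config, matching the
  // `temperature_ == -1.0f ? config.temperature : temperature_` hunk below.
  std::unique_ptr<example::Runner> runner = example::Runner::create(
      "/path/to/llama2.pte",       // hypothetical model path
      "/path/to/tokenizer.model",  // hypothetical tokenizer path
      /*data_path=*/std::nullopt,
      /*temperature=*/-1.0f);
  if (runner == nullptr) {
    return 1; // model or tokenizer could not be set up
  }

  // load() is now lightweight: create() already built every component, so
  // this only loads the prefiller and the token generator.
  if (runner->load() != executorch::runtime::Error::Ok) {
    return 1;
  }

  // generate(prompt, config) matches the call visible in warmup() below;
  // the config type name and field are assumptions.
  executorch::extension::llm::GenerationConfig config;
  config.temperature = 0.8f;
  return runner->generate("Hello, world", config) ==
          executorch::runtime::Error::Ok
      ? 0
      : 1;
}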
@@ -201,9 +236,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -229,7 +264,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -270,8 +305,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -292,7 +327,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -305,17 +340,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
@@ -329,8 +364,8 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   // Call generate with the warmup config
   Error err = generate(prompt, config);

-  // Reset stats after warmup
-  stats_.reset();
+  // Reset stats after warmup, not resetting the std::unique_ptr!
+  stats_->reset();
   return err;
 }
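The reworded comment in this last hunk guards a real footgun introduced by making `stats_` a `std::unique_ptr<llm::Stats>`: `reset()` now has two very different meanings depending on whether it is called on the smart pointer or through it. A two-line illustration:

  stats_.reset();   // std::unique_ptr::reset(): destroys the Stats object,
                    // leaving stats_ == nullptr, so later stats_-> uses crash
  stats_->reset();  // llm::Stats::reset(): clears the counters in place,
                    // which is what warmup() intends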