Skip to content

Commit 5a942f2

Browse files
denerscggerganov
authored andcommitted
whisper : token-level timestamps with DTW (ggml-org#1485)
* whisper.cpp: impl dtw algo * WIP: producing and placing DTW timestamps on tokens * Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false. * Fix mistake causing incorrect alignment of dtw timestamps * implement N_TOP_MOST and CUSTOM alignment heads setting * whisper: fix typo on alignment heads enum * Fix issues related to changes in whisper.cpp * Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function * decoder: save cross QKs only if requested * Calling median filter with ggml_map_custom1 * Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads * Copying cross QKs from decoder backend correctly * dtw: cleanup * Fix incorrect n_frames passed to dtw when near end of audio * Fix aheads_masks_init for backend != CPU * whisper : minor style * main : add dtw (wip) * whisper: fix invalid memory access in aheads_masks_init * main : add dtw (cont) * whisper : minor --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent e517ab6 commit 5a942f2

File tree

3 files changed

+652
-21
lines changed

3 files changed

+652
-21
lines changed

examples/main/main.cpp

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,17 @@ void replace_all(std::string & s, const std::string & search, const std::string
2626

2727
// command-line parameters
2828
struct whisper_params {
29-
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
30-
int32_t n_processors = 1;
31-
int32_t offset_t_ms = 0;
32-
int32_t offset_n = 0;
33-
int32_t duration_ms = 0;
34-
int32_t progress_step = 5;
35-
int32_t max_context = -1;
36-
int32_t max_len = 0;
37-
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
38-
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
39-
int32_t audio_ctx = 0;
29+
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
30+
int32_t n_processors = 1;
31+
int32_t offset_t_ms = 0;
32+
int32_t offset_n = 0;
33+
int32_t duration_ms = 0;
34+
int32_t progress_step = 5;
35+
int32_t max_context = -1;
36+
int32_t max_len = 0;
37+
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
38+
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
39+
int32_t audio_ctx = 0;
4040

4141
float word_thold = 0.01f;
4242
float entropy_thold = 2.40f;
@@ -76,6 +76,8 @@ struct whisper_params {
7676

7777
std::string openvino_encode_device = "CPU";
7878

79+
std::string dtw = "";
80+
7981
std::vector<std::string> fname_inp = {};
8082
std::vector<std::string> fname_out = {};
8183
};
@@ -149,6 +151,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
149151
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
150152
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
151153
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
154+
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
152155
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
153156
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
154157
else {
@@ -208,6 +211,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
208211
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
209212
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
210213
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
214+
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
211215
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
212216
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
213217
fprintf(stderr, "\n");
@@ -649,7 +653,8 @@ bool output_json(
649653
times_o(token.t0, token.t1, false);
650654
}
651655
value_i("id", token.id, false);
652-
value_f("p", token.p, true);
656+
value_f("p", token.p, false);
657+
value_f("t_dtw", token.t_dtw, true);
653658
end_obj(j == (n - 1));
654659
}
655660
end_arr(!params.diarize && !params.tinydiarize);
@@ -889,6 +894,28 @@ int main(int argc, char ** argv) {
889894
struct whisper_context_params cparams = whisper_context_default_params();
890895
cparams.use_gpu = params.use_gpu;
891896

897+
if (!params.dtw.empty()) {
898+
cparams.dtw_token_timestamps = true;
899+
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
900+
901+
if (params.dtw == "tiny") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
902+
if (params.dtw == "tiny.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
903+
if (params.dtw == "base") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
904+
if (params.dtw == "base.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
905+
if (params.dtw == "small") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
906+
if (params.dtw == "small.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
907+
if (params.dtw == "medium") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
908+
if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
909+
if (params.dtw == "large.v1") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
910+
if (params.dtw == "large.v2") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
911+
if (params.dtw == "large.v3") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
912+
913+
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
914+
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
915+
return 3;
916+
}
917+
}
918+
892919
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
893920

894921
if (ctx == nullptr) {

0 commit comments

Comments
 (0)