Full implementation of multi-dictionary support #238

Merged: 47 commits, Mar 29, 2025

Commits (changes from all commits)
da622be
Add multi-dictionary preprocessing and training
yqzhishen Jul 15, 2024
a151ecf
Fix lang_map.json copy
yqzhishen Jul 15, 2024
c2b4c5f
Merge branch 'refs/heads/main' into multi-dict
yqzhishen Jul 15, 2024
b5a876b
Add language embed (inject to txt_embed) for acoustic models
yqzhishen Jul 16, 2024
d282e28
Save language sequence in variance preprocessing
yqzhishen Jul 18, 2024
70676cc
Display merged phoneme groups properly in distribution plots
yqzhishen Jul 18, 2024
62c093e
Add multi-dictionary inference
yqzhishen Jul 19, 2024
96b9a60
Save original phoneme texts for duration plots
yqzhishen Jul 19, 2024
dbe3840
Fix duration plots displaying bug
yqzhishen Jul 19, 2024
b5f20a5
Explicit `languages` argument passing
yqzhishen Jul 19, 2024
a7dbb93
Add language embed (inject to txt_embed) for variance models
yqzhishen Jul 20, 2024
2a17561
Fix argument passing
yqzhishen Jul 20, 2024
8b215db
Add log for lang_map.json copy
yqzhishen Jul 20, 2024
6f80697
Add language embedding scale
yqzhishen Jul 20, 2024
655e9ba
Add language embedding type
yqzhishen Jul 21, 2024
c6b96cf
Preprocessing: only apply lang embed on cross-lingual phonemes
yqzhishen Jul 22, 2024
8377728
Inference: only apply lang embed on cross-lingual phonemes
yqzhishen Jul 24, 2024
3d0a9ba
Revert "Add language embedding type"
yqzhishen Jul 24, 2024
932c4f4
Revert lang_embed_scale
yqzhishen Jul 27, 2024
a0ec7e3
Adapt ONNX exporters for multi-language models
yqzhishen Jul 27, 2024
4a4b2b0
Refactor configuration schemas for datasets
yqzhishen Jul 27, 2024
678e3e6
Add check of existence for merged phonemes
yqzhishen Jul 28, 2024
d0d7b73
Fix spk_id assignment
yqzhishen Jul 28, 2024
f3a969c
Fix languages.json filename
yqzhishen Jul 28, 2024
bf44910
Fix `languages` key in dsconfig.yaml
yqzhishen Jul 28, 2024
fb5f589
Set `use_lang_id` to false if there are no cross-lingual phonemes
yqzhishen Jul 31, 2024
333d9ef
Support defining extra phonemes
yqzhishen Aug 2, 2024
d3cd5cd
Refactor configs
yqzhishen Aug 2, 2024
f729db8
Prefer file copies in work_dir when loading dictionaries
yqzhishen Aug 3, 2024
453cb0f
Fix cannot locate dictionary
yqzhishen Aug 4, 2024
663db52
Fix unexpected loading error when dictionary changes
yqzhishen Aug 17, 2024
8de1f72
Merge branch 'main' into multi-dict
yqzhishen Nov 15, 2024
6c7bb08
Update toplevel.py (#219)
AnAndroNerd Nov 16, 2024
da79ef2
Fix unexpected config passing
yqzhishen Jan 4, 2025
5d56329
Update lynxnet backbone (#228)
yxlllc Jan 16, 2025
3f8bc85
Improve fastspeech2 encoder using Rotary Position Embedding (RoPE) in…
yxlllc Feb 10, 2025
575d0ab
support mini-nsf-hifigan vocoder
yxlllc Aug 17, 2024
51da9ec
discard negative pad
yxlllc Aug 18, 2024
960bf90
fix MHA inference using low torch version
yxlllc Feb 14, 2025
84b32ed
Fix missing phoneme list sorting
yqzhishen Feb 16, 2025
7741b55
Fix single-language dictionary parsing language tag
yqzhishen Feb 17, 2025
58edd2f
Add `pitch_controllable` flag to vocoder exporter
yqzhishen Mar 22, 2025
38335bf
support noise injection
yxlllc Mar 22, 2025
4a56fce
Allow merging global phonemes and language-specific phonemes
yqzhishen Mar 28, 2025
21a0f6b
Check for conflicts between short names and global tags
yqzhishen Mar 28, 2025
7b58b46
Finish documentation for multi-dictionary
yqzhishen Mar 28, 2025
7fc5686
Merge branch 'main' into multi-dict
yqzhishen Mar 29, 2025
171 changes: 105 additions & 66 deletions basics/base_binarizer.py
@@ -13,9 +13,8 @@
from utils.hparams import hparams
from utils.indexed_datasets import IndexedDatasetBuilder
from utils.multiprocess_utils import chunked_multiprocess_run
from utils.phoneme_utils import build_phoneme_list, locate_dictionary
from utils.phoneme_utils import load_phoneme_dictionary
from utils.plot import distribution_to_figure
from utils.text_encoder import TokenTextEncoder


class BinarizationError(Exception):
@@ -44,73 +43,88 @@ class BaseBinarizer:
the phoneme set.
"""

def __init__(self, data_dir=None, data_attrs=None):
if data_dir is None:
data_dir = hparams['raw_data_dir']
if not isinstance(data_dir, list):
data_dir = [data_dir]

self.raw_data_dirs = [pathlib.Path(d) for d in data_dir]
def __init__(self, datasets=None, data_attrs=None):
if datasets is None:
datasets = hparams['datasets']
self.datasets = datasets
self.raw_data_dirs = [pathlib.Path(ds['raw_data_dir']) for ds in self.datasets]
self.binary_data_dir = pathlib.Path(hparams['binary_data_dir'])
self.data_attrs = [] if data_attrs is None else data_attrs

self.binarization_args = hparams['binarization_args']
self.augmentation_args = hparams.get('augmentation_args', {})
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

self.spk_map = None
self.spk_ids = hparams['spk_ids']
self.speakers = hparams['speakers']
self.spk_map = {}
self.spk_ids = None
self.build_spk_map()

self.lang_map = {}
self.dictionaries = hparams['dictionaries']
self.build_lang_map()

self.items = {}
self.item_names: list = None
self._train_item_names: list = None
self._valid_item_names: list = None

self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
self.phoneme_dictionary = load_phoneme_dictionary()
self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']

def build_spk_map(self):
assert isinstance(self.speakers, list), 'Speakers must be a list'
assert len(self.speakers) == len(self.raw_data_dirs), \
'Number of raw data dirs must equal number of speaker names!'
if len(self.spk_ids) == 0:
self.spk_ids = list(range(len(self.raw_data_dirs)))
else:
assert len(self.spk_ids) == len(self.raw_data_dirs), \
'Length of explicitly given spk_ids must equal the number of raw datasets.'
assert max(self.spk_ids) < hparams['num_spk'], \
f'Index in spk_id sequence {self.spk_ids} is out of range. All values should be smaller than num_spk.'

self.spk_map = {}
for spk_name, spk_id in zip(self.speakers, self.spk_ids):
spk_ids = [ds.get('spk_id') for ds in self.datasets]
assigned_spk_ids = {spk_id for spk_id in spk_ids if spk_id is not None}
idx = 0
for i in range(len(spk_ids)):
if spk_ids[i] is not None:
continue
while idx in assigned_spk_ids:
idx += 1
spk_ids[i] = idx
assigned_spk_ids.add(idx)
assert max(spk_ids) < hparams['num_spk'], \
f'Index in spk_id sequence {spk_ids} is out of range. All values should be smaller than num_spk.'

for spk_id, dataset in zip(spk_ids, self.datasets):
spk_name = dataset['speaker']
if spk_name in self.spk_map and self.spk_map[spk_name] != spk_id:
raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
f'with different speaker IDs: {self.spk_map[spk_name]} and {spk_id}.')
self.spk_map[spk_name] = spk_id
self.spk_ids = spk_ids

print("| spk_map: ", self.spk_map)

def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
def build_lang_map(self):
assert len(self.dictionaries.keys()) <= hparams['num_lang'], \
'Number of languages must not be greater than num_lang!'
for dataset in self.datasets:
assert dataset['language'] in self.dictionaries, f'Unrecognized language name: {dataset["language"]}'

for lang_id, lang_name in enumerate(sorted(self.dictionaries.keys()), start=1):
self.lang_map[lang_name] = lang_id

print("| lang_map: ", self.lang_map)

def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang) -> dict:
raise NotImplementedError()

def split_train_valid_set(self, item_names):
def split_train_valid_set(self, prefixes: list):
"""
Split the dataset into training set and validation set.
:return: train_item_names, valid_item_names
"""
prefixes = {str(pr): 1 for pr in hparams['test_prefixes']}
prefixes = {str(pr): 1 for pr in prefixes}
valid_item_names = {}
# Add prefixes that specify a speaker index and exactly match an item name to the test set
for prefix in deepcopy(prefixes):
if prefix in item_names:
if prefix in self.item_names:
valid_item_names[prefix] = 1
prefixes.pop(prefix)
# Add prefixes that exactly match an item name (without speaker id) to the test set
for prefix in deepcopy(prefixes):
matched = False
for name in item_names:
for name in self.item_names:
if name.split(':')[-1] == prefix:
valid_item_names[name] = 1
matched = True
@@ -119,15 +133,15 @@ def split_train_valid_set(self, item_names):
# Add names with one of the remaining prefixes to test set
for prefix in deepcopy(prefixes):
matched = False
for name in item_names:
for name in self.item_names:
if name.startswith(prefix):
valid_item_names[name] = 1
matched = True
if matched:
prefixes.pop(prefix)
for prefix in deepcopy(prefixes):
matched = False
for name in item_names:
for name in self.item_names:
if name.split(':')[-1].startswith(prefix):
valid_item_names[name] = 1
matched = True
@@ -143,7 +157,7 @@ def split_train_valid_set(self, item_names):

valid_item_names = list(valid_item_names.keys())
assert len(valid_item_names) > 0, 'Validation set is empty!'
train_item_names = [x for x in item_names if x not in set(valid_item_names)]
train_item_names = [x for x in self.item_names if x not in set(valid_item_names)]
assert len(train_item_names) > 0, 'Training set is empty!'

return train_item_names, valid_item_names
@@ -167,21 +181,34 @@ def meta_data_iterator(self, prefix):

def process(self):
# load each dataset
for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs):
self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id)
test_prefixes = []
for ds_id, dataset in enumerate(self.datasets):
items = self.load_meta_data(
pathlib.Path(dataset['raw_data_dir']),
ds_id=ds_id, spk=dataset['speaker'], lang=dataset['language']
)
self.items.update(items)
test_prefixes.extend(
f'{ds_id}:{prefix}'
for prefix in dataset.get('test_prefixes', [])
)
self.item_names = sorted(list(self.items.keys()))
self._train_item_names, self._valid_item_names = self.split_train_valid_set(self.item_names)
self._train_item_names, self._valid_item_names = self.split_train_valid_set(test_prefixes)

if self.binarization_args['shuffle']:
random.shuffle(self.item_names)

self.binary_data_dir.mkdir(parents=True, exist_ok=True)

# Copy spk_map and dictionary to binary data dir
# Copy spk_map, lang_map and dictionary to binary data dir
spk_map_fn = self.binary_data_dir / 'spk_map.json'
with open(spk_map_fn, 'w', encoding='utf-8') as f:
json.dump(self.spk_map, f)
shutil.copy(locate_dictionary(), self.binary_data_dir / 'dictionary.txt')
json.dump(self.spk_map, f, ensure_ascii=False)
lang_map_fn = self.binary_data_dir / 'lang_map.json'
with open(lang_map_fn, 'w', encoding='utf-8') as f:
json.dump(self.lang_map, f, ensure_ascii=False)
for lang, dict_path in hparams['dictionaries'].items():
shutil.copy(dict_path, self.binary_data_dir / f'dictionary-{lang}.txt')
self.check_coverage()

# Process valid set and train set
@@ -197,40 +224,47 @@ def process(self):

def check_coverage(self):
# Group by phonemes in the dictionary.
ph_required = set(build_phoneme_list())
phoneme_map = {}
for ph in ph_required:
phoneme_map[ph] = 0
ph_occurred = []
ph_idx_required = set(range(1, len(self.phoneme_dictionary)))
ph_idx_occurred = set()
ph_idx_count_map = {
idx: 0
for idx in ph_idx_required
}

# Load and count those phones that appear in the actual data
for item_name in self.items:
ph_occurred += self.items[item_name]['ph_seq']
if len(ph_occurred) == 0:
raise BinarizationError(f'Empty tokens in {item_name}.')
for ph in ph_occurred:
if ph not in ph_required:
continue
phoneme_map[ph] += 1
ph_occurred = set(ph_occurred)
ph_idx_occurred.update(self.items[item_name]['ph_seq'])
for idx in self.items[item_name]['ph_seq']:
ph_idx_count_map[idx] += 1
ph_count_map = {
self.phoneme_dictionary.decode_one(idx, scalar=False): count
for idx, count in ph_idx_count_map.items()
}

def display_phoneme(phoneme):
if isinstance(phoneme, tuple):
return f'({", ".join(phoneme)})'
return phoneme

print('===== Phoneme Distribution Summary =====')
for i, key in enumerate(sorted(phoneme_map.keys())):
if i == len(ph_required) - 1:
keys = sorted(ph_count_map.keys(), key=lambda v: v[0] if isinstance(v, tuple) else v)
for i, key in enumerate(keys):
if i == len(ph_count_map) - 1:
end = '\n'
elif i % 10 == 9:
end = ',\n'
else:
end = ', '
print(f'\'{key}\': {phoneme_map[key]}', end=end)
key_disp = display_phoneme(key)
print(f'{key_disp}: {ph_count_map[key]}', end=end)

# Draw graph.
x = sorted(phoneme_map.keys())
values = [phoneme_map[k] for k in x]
xs = [display_phoneme(k) for k in keys]
ys = [ph_count_map[k] for k in keys]
plt = distribution_to_figure(
title='Phoneme Distribution Summary',
x_label='Phoneme', y_label='Number of occurrences',
items=x, values=values
items=xs, values=ys, rotate=len(self.dictionaries) > 1
)
filename = self.binary_data_dir / 'phoneme_distribution.jpg'
plt.savefig(fname=filename,
@@ -239,19 +273,21 @@ def check_coverage(self):
print(f'| save summary to \'{filename}\'')

# Check unrecognizable or missing phonemes
if ph_occurred != ph_required:
unrecognizable_phones = ph_occurred.difference(ph_required)
missing_phones = ph_required.difference(ph_occurred)
raise BinarizationError('transcriptions and dictionary mismatch.\n'
f' (+) {sorted(unrecognizable_phones)}\n'
f' (-) {sorted(missing_phones)}')
if ph_idx_occurred != ph_idx_required:
missing_phones = sorted({
self.phoneme_dictionary.decode_one(idx, scalar=False)
for idx in ph_idx_required.difference(ph_idx_occurred)
}, key=lambda v: v[0] if isinstance(v, tuple) else v)
raise BinarizationError(
f'The following phonemes are not covered in transcriptions: {missing_phones}'
)

def process_dataset(self, prefix, num_workers=0, apply_augmentation=False):
args = []
builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs)
total_sec = {k: 0.0 for k in self.spk_map}
total_raw_sec = {k: 0.0 for k in self.spk_map}
extra_info = {'names': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
extra_info = {'names': {}, 'ph_texts': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
max_no = -1

for item_name, meta_data in self.meta_data_iterator(prefix):
@@ -271,6 +307,7 @@ def postprocess(_item):
extra_info[k] = {}
extra_info[k][item_no] = v.shape[0]
extra_info['names'][item_no] = _item['name'].split(':', 1)[-1]
extra_info['ph_texts'][item_no] = _item['ph_text']
extra_info['spk_ids'][item_no] = _item['spk_id']
extra_info['spk_names'][item_no] = _item['spk_name']
extra_info['lengths'][item_no] = _item['length']
@@ -287,6 +324,7 @@ def postprocess(_item):
extra_info[k] = {}
extra_info[k][aug_item_no] = v.shape[0]
extra_info['names'][aug_item_no] = aug_item['name'].split(':', 1)[-1]
extra_info['ph_texts'][aug_item_no] = aug_item['ph_text']
extra_info['spk_ids'][aug_item_no] = aug_item['spk_id']
extra_info['spk_names'][aug_item_no] = aug_item['spk_name']
extra_info['lengths'][aug_item_no] = aug_item['length']
@@ -315,6 +353,7 @@ def postprocess(_item):
builder.finalize()
if prefix == "train":
extra_info.pop("names")
extra_info.pop('ph_texts')
extra_info.pop("spk_names")
with open(self.binary_data_dir / f"{prefix}.meta", "wb") as f:
# noinspection PyTypeChecker
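The rewritten binarizer reads a `datasets` list and a `dictionaries` mapping from the hyperparameters instead of the old `raw_data_dir`/`speakers`/`spk_ids` trio. Below is a minimal sketch of what such a configuration might look like, written as a plain Python dict (the project loads it from YAML in practice; every speaker name, language tag and path here is a made-up placeholder, not something defined by this PR):

```python
# Hypothetical hparams fragment mirroring what BaseBinarizer.__init__,
# build_spk_map and build_lang_map read. All concrete names and paths are placeholders.
hparams_sketch = {
    'datasets': [
        {
            'raw_data_dir': 'data/alice_mandarin',
            'speaker': 'alice',
            'language': 'zh',             # must be a key of 'dictionaries'
            'test_prefixes': ['seg001'],  # namespaced as '<ds_id>:seg001' in process()
            # 'spk_id' is optional; datasets without it get the smallest unused id
        },
        {
            'raw_data_dir': 'data/bob_japanese',
            'speaker': 'bob',
            'language': 'ja',
        },
    ],
    'dictionaries': {
        'zh': 'dictionaries/zh_dict.txt',
        'ja': 'dictionaries/ja_dict.txt',
    },
    'num_spk': 2,   # must be greater than every assigned spk_id
    'num_lang': 2,  # must be at least len(dictionaries)
}

# Reproducing build_lang_map(): language ids start at 1, sorted by language name.
lang_map = {name: i for i, name in enumerate(sorted(hparams_sketch['dictionaries']), start=1)}
print(lang_map)  # {'ja': 1, 'zh': 2}

# build_spk_map() would yield {'alice': 0, 'bob': 1} here. Datasets that share a
# speaker name should carry the same explicit 'spk_id' to pass the conflict check.
```

Note that the per-dataset `test_prefixes` replace the former global `hparams['test_prefixes']`, which is why `split_train_valid_set` now takes the prefixes as an explicit argument.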
27 changes: 27 additions & 0 deletions basics/base_exporter.py
@@ -1,4 +1,6 @@
import json
import pathlib
import shutil
from pathlib import Path
from typing import Union

@@ -31,6 +33,18 @@ def build_spk_map(self) -> dict:
else:
return {}

# noinspection PyMethodMayBeStatic
def build_lang_map(self) -> dict:
lang_map_fn = pathlib.Path(hparams['work_dir']) / 'lang_map.json'
if lang_map_fn.exists():
with open(lang_map_fn, 'r', encoding='utf8') as f:
lang_map = json.load(f)
assert isinstance(lang_map, dict) and len(lang_map) > 0, 'Invalid or empty language map!'
assert len(lang_map) == len(set(lang_map.values())), 'Duplicate language id in language map!'
return lang_map
else:
return {}

def build_model(self) -> nn.Module:
"""
Creates an instance of nn.Module and load its state dict on the target device.
@@ -44,6 +58,19 @@ def export_model(self, path: Path):
"""
raise NotImplementedError()

# noinspection PyMethodMayBeStatic
def export_dictionaries(self, path: Path):
dicts = hparams.get('dictionaries')
if dicts is not None:
for lang in dicts.keys():
fn = f'dictionary-{lang}.txt'
shutil.copy(pathlib.Path(hparams['work_dir']) / fn, path)
print(f'| export dictionary => {path / fn}')
else:
fn = 'dictionary.txt'
shutil.copy(pathlib.Path(hparams['work_dir']) / fn, path)
print(f'| export dictionary => {path / fn}')

def export_attachments(self, path: Path):
"""
Exports related files and configs (e.g. the dictionary) to the target directory.
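`build_lang_map()` and `export_dictionaries()` above both assume the checkpoint's `work_dir` already contains the per-language artifacts written during preprocessing. A small, hypothetical sanity check along those lines (the checkpoint path and language tags are placeholders, not values from this PR):

```python
import json
import pathlib

work_dir = pathlib.Path('checkpoints/my_multilingual_model')  # placeholder path

lang_map_fn = work_dir / 'lang_map.json'
if lang_map_fn.exists():
    with open(lang_map_fn, 'r', encoding='utf8') as f:
        lang_map = json.load(f)  # e.g. {'ja': 1, 'zh': 2}
    # export_dictionaries() copies one 'dictionary-<lang>.txt' per configured language
    expected = [work_dir / f'dictionary-{lang}.txt' for lang in lang_map]
else:
    lang_map = {}  # single-dictionary model
    expected = [work_dir / 'dictionary.txt']

missing = [p.name for p in expected if not p.exists()]
print('lang_map:', lang_map, '| missing dictionaries:', missing)
```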
7 changes: 6 additions & 1 deletion basics/base_svs_infer.py
@@ -29,6 +29,7 @@ def __init__(self, device=None):
self.device = device
self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
self.spk_map = {}
self.lang_map = {}
self.model: torch.nn.Module = None

def build_model(self, ckpt_steps=None) -> torch.nn.Module:
@@ -50,7 +51,11 @@ def load_speaker_mix(self, param_src: dict, summary_dst: dict,
spk_mix_map = param_src.get(param_key) # { spk_name: value } or { spk_name: "value value value ..." }
dynamic = False
if spk_mix_map is None:
# Get the first speaker
assert len(self.spk_map) == 1, (
"This is a multi-speaker model. "
"Please specify a speaker or speaker mix by --spk option."
)
# Get the only speaker
for name in self.spk_map.keys():
spk_mix_map = {name: 1.0}
break
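The added assertion changes the no-`--spk` fallback of `load_speaker_mix`: a multi-speaker model now refuses to run without an explicit speaker or mix, while a single-speaker model falls back to its sole entry with full weight. A standalone sketch of that behavior (the speaker map below is a placeholder):

```python
from typing import Optional


def resolve_default_spk_mix(spk_map: dict, spk_mix_map: Optional[dict]) -> dict:
    """Mimics the fallback branch of load_speaker_mix() shown above."""
    if spk_mix_map is None:
        assert len(spk_map) == 1, (
            'This is a multi-speaker model. '
            'Please specify a speaker or speaker mix by --spk option.'
        )
        only_speaker = next(iter(spk_map))  # get the only speaker
        spk_mix_map = {only_speaker: 1.0}
    return spk_mix_map


print(resolve_default_spk_mix({'alice': 0}, None))  # {'alice': 1.0}
```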