Commit 387242c

Reapply "Full implementation of multi-dictionary support (openvpi#238)"
This reverts commit 9957ba1.
1 parent: 9957ba1 · commit: 387242c

30 files changed: +1035 −539 lines

basics/base_binarizer.py

Lines changed: 105 additions & 66 deletions
@@ -13,9 +13,8 @@
 from utils.hparams import hparams
 from utils.indexed_datasets import IndexedDatasetBuilder
 from utils.multiprocess_utils import chunked_multiprocess_run
-from utils.phoneme_utils import build_phoneme_list, locate_dictionary
+from utils.phoneme_utils import load_phoneme_dictionary
 from utils.plot import distribution_to_figure
-from utils.text_encoder import TokenTextEncoder


 class BinarizationError(Exception):
@@ -44,73 +43,88 @@ class BaseBinarizer:
     the phoneme set.
     """

-    def __init__(self, data_dir=None, data_attrs=None):
-        if data_dir is None:
-            data_dir = hparams['raw_data_dir']
-        if not isinstance(data_dir, list):
-            data_dir = [data_dir]
-
-        self.raw_data_dirs = [pathlib.Path(d) for d in data_dir]
+    def __init__(self, datasets=None, data_attrs=None):
+        if datasets is None:
+            datasets = hparams['datasets']
+        self.datasets = datasets
+        self.raw_data_dirs = [pathlib.Path(ds['raw_data_dir']) for ds in self.datasets]
         self.binary_data_dir = pathlib.Path(hparams['binary_data_dir'])
         self.data_attrs = [] if data_attrs is None else data_attrs

         self.binarization_args = hparams['binarization_args']
         self.augmentation_args = hparams.get('augmentation_args', {})
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-        self.spk_map = None
-        self.spk_ids = hparams['spk_ids']
-        self.speakers = hparams['speakers']
+        self.spk_map = {}
+        self.spk_ids = None
         self.build_spk_map()

+        self.lang_map = {}
+        self.dictionaries = hparams['dictionaries']
+        self.build_lang_map()
+
         self.items = {}
         self.item_names: list = None
         self._train_item_names: list = None
         self._valid_item_names: list = None

-        self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
+        self.phoneme_dictionary = load_phoneme_dictionary()
         self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']

     def build_spk_map(self):
-        assert isinstance(self.speakers, list), 'Speakers must be a list'
-        assert len(self.speakers) == len(self.raw_data_dirs), \
-            'Number of raw data dirs must equal number of speaker names!'
-        if len(self.spk_ids) == 0:
-            self.spk_ids = list(range(len(self.raw_data_dirs)))
-        else:
-            assert len(self.spk_ids) == len(self.raw_data_dirs), \
-                'Length of explicitly given spk_ids must equal the number of raw datasets.'
-        assert max(self.spk_ids) < hparams['num_spk'], \
-            f'Index in spk_id sequence {self.spk_ids} is out of range. All values should be smaller than num_spk.'
-
-        self.spk_map = {}
-        for spk_name, spk_id in zip(self.speakers, self.spk_ids):
+        spk_ids = [ds.get('spk_id') for ds in self.datasets]
+        assigned_spk_ids = {spk_id for spk_id in spk_ids if spk_id is not None}
+        idx = 0
+        for i in range(len(spk_ids)):
+            if spk_ids[i] is not None:
+                continue
+            while idx in assigned_spk_ids:
+                idx += 1
+            spk_ids[i] = idx
+            assigned_spk_ids.add(idx)
+        assert max(spk_ids) < hparams['num_spk'], \
+            f'Index in spk_id sequence {spk_ids} is out of range. All values should be smaller than num_spk.'
+
+        for spk_id, dataset in zip(spk_ids, self.datasets):
+            spk_name = dataset['speaker']
             if spk_name in self.spk_map and self.spk_map[spk_name] != spk_id:
                 raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
                                  f'with different speaker IDs: {self.spk_map[spk_name]} and {spk_id}.')
             self.spk_map[spk_name] = spk_id
+        self.spk_ids = spk_ids

         print("| spk_map: ", self.spk_map)

-    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
+    def build_lang_map(self):
+        assert len(self.dictionaries.keys()) <= hparams['num_lang'], \
+            'Number of languages must not be greater than num_lang!'
+        for dataset in self.datasets:
+            assert dataset['language'] in self.dictionaries, f'Unrecognized language name: {dataset["language"]}'
+
+        for lang_id, lang_name in enumerate(sorted(self.dictionaries.keys()), start=1):
+            self.lang_map[lang_name] = lang_id
+
+        print("| lang_map: ", self.lang_map)
+
+    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang) -> dict:
         raise NotImplementedError()

-    def split_train_valid_set(self, item_names):
+    def split_train_valid_set(self, prefixes: list):
         """
         Split the dataset into training set and validation set.
         :return: train_item_names, valid_item_names
         """
-        prefixes = {str(pr): 1 for pr in hparams['test_prefixes']}
+        prefixes = {str(pr): 1 for pr in prefixes}
         valid_item_names = {}
         # Add prefixes that specified speaker index and matches exactly item name to test set
         for prefix in deepcopy(prefixes):
-            if prefix in item_names:
+            if prefix in self.item_names:
                 valid_item_names[prefix] = 1
                 prefixes.pop(prefix)
         # Add prefixes that exactly matches item name without speaker id to test set
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.split(':')[-1] == prefix:
                     valid_item_names[name] = 1
                     matched = True
@@ -119,15 +133,15 @@ def split_train_valid_set(self, item_names):
         # Add names with one of the remaining prefixes to test set
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.startswith(prefix):
                     valid_item_names[name] = 1
                     matched = True
             if matched:
                 prefixes.pop(prefix)
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.split(':')[-1].startswith(prefix):
                     valid_item_names[name] = 1
                     matched = True
@@ -143,7 +157,7 @@ def split_train_valid_set(self, item_names):

         valid_item_names = list(valid_item_names.keys())
         assert len(valid_item_names) > 0, 'Validation set is empty!'
-        train_item_names = [x for x in item_names if x not in set(valid_item_names)]
+        train_item_names = [x for x in self.item_names if x not in set(valid_item_names)]
         assert len(train_item_names) > 0, 'Training set is empty!'

         return train_item_names, valid_item_names
@@ -167,21 +181,34 @@ def meta_data_iterator(self, prefix):

     def process(self):
         # load each dataset
-        for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs):
-            self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id)
+        test_prefixes = []
+        for ds_id, dataset in enumerate(self.datasets):
+            items = self.load_meta_data(
+                pathlib.Path(dataset['raw_data_dir']),
+                ds_id=ds_id, spk=dataset['speaker'], lang=dataset['language']
+            )
+            self.items.update(items)
+            test_prefixes.extend(
+                f'{ds_id}:{prefix}'
+                for prefix in dataset.get('test_prefixes', [])
+            )
         self.item_names = sorted(list(self.items.keys()))
-        self._train_item_names, self._valid_item_names = self.split_train_valid_set(self.item_names)
+        self._train_item_names, self._valid_item_names = self.split_train_valid_set(test_prefixes)

         if self.binarization_args['shuffle']:
             random.shuffle(self.item_names)

         self.binary_data_dir.mkdir(parents=True, exist_ok=True)

-        # Copy spk_map and dictionary to binary data dir
+        # Copy spk_map, lang_map and dictionary to binary data dir
         spk_map_fn = self.binary_data_dir / 'spk_map.json'
         with open(spk_map_fn, 'w', encoding='utf-8') as f:
-            json.dump(self.spk_map, f)
-        shutil.copy(locate_dictionary(), self.binary_data_dir / 'dictionary.txt')
+            json.dump(self.spk_map, f, ensure_ascii=False)
+        lang_map_fn = self.binary_data_dir / 'lang_map.json'
+        with open(lang_map_fn, 'w', encoding='utf-8') as f:
+            json.dump(self.lang_map, f, ensure_ascii=False)
+        for lang, dict_path in hparams['dictionaries'].items():
+            shutil.copy(dict_path, self.binary_data_dir / f'dictionary-{lang}.txt')
         self.check_coverage()

         # Process valid set and train set
@@ -197,40 +224,47 @@ def process(self):

     def check_coverage(self):
         # Group by phonemes in the dictionary.
-        ph_required = set(build_phoneme_list())
-        phoneme_map = {}
-        for ph in ph_required:
-            phoneme_map[ph] = 0
-        ph_occurred = []
+        ph_idx_required = set(range(1, len(self.phoneme_dictionary)))
+        ph_idx_occurred = set()
+        ph_idx_count_map = {
+            idx: 0
+            for idx in ph_idx_required
+        }

         # Load and count those phones that appear in the actual data
         for item_name in self.items:
-            ph_occurred += self.items[item_name]['ph_seq']
-            if len(ph_occurred) == 0:
-                raise BinarizationError(f'Empty tokens in {item_name}.')
-        for ph in ph_occurred:
-            if ph not in ph_required:
-                continue
-            phoneme_map[ph] += 1
-        ph_occurred = set(ph_occurred)
+            ph_idx_occurred.update(self.items[item_name]['ph_seq'])
+            for idx in self.items[item_name]['ph_seq']:
+                ph_idx_count_map[idx] += 1
+        ph_count_map = {
+            self.phoneme_dictionary.decode_one(idx, scalar=False): count
+            for idx, count in ph_idx_count_map.items()
+        }
+
+        def display_phoneme(phoneme):
+            if isinstance(phoneme, tuple):
+                return f'({", ".join(phoneme)})'
+            return phoneme

         print('===== Phoneme Distribution Summary =====')
-        for i, key in enumerate(sorted(phoneme_map.keys())):
-            if i == len(ph_required) - 1:
+        keys = sorted(ph_count_map.keys(), key=lambda v: v[0] if isinstance(v, tuple) else v)
+        for i, key in enumerate(keys):
+            if i == len(ph_count_map) - 1:
                 end = '\n'
             elif i % 10 == 9:
                 end = ',\n'
             else:
                 end = ', '
-            print(f'\'{key}\': {phoneme_map[key]}', end=end)
+            key_disp = display_phoneme(key)
+            print(f'{key_disp}: {ph_count_map[key]}', end=end)

         # Draw graph.
-        x = sorted(phoneme_map.keys())
-        values = [phoneme_map[k] for k in x]
+        xs = [display_phoneme(k) for k in keys]
+        ys = [ph_count_map[k] for k in keys]
         plt = distribution_to_figure(
             title='Phoneme Distribution Summary',
             x_label='Phoneme', y_label='Number of occurrences',
-            items=x, values=values
+            items=xs, values=ys, rotate=len(self.dictionaries) > 1
         )
         filename = self.binary_data_dir / 'phoneme_distribution.jpg'
         plt.savefig(fname=filename,
@@ -239,19 +273,21 @@ def check_coverage(self):
         print(f'| save summary to \'{filename}\'')

         # Check unrecognizable or missing phonemes
-        if ph_occurred != ph_required:
-            unrecognizable_phones = ph_occurred.difference(ph_required)
-            missing_phones = ph_required.difference(ph_occurred)
-            raise BinarizationError('transcriptions and dictionary mismatch.\n'
-                                    f' (+) {sorted(unrecognizable_phones)}\n'
-                                    f' (-) {sorted(missing_phones)}')
+        if ph_idx_occurred != ph_idx_required:
+            missing_phones = sorted({
+                self.phoneme_dictionary.decode_one(idx, scalar=False)
+                for idx in ph_idx_required.difference(ph_idx_occurred)
+            }, key=lambda v: v[0] if isinstance(v, tuple) else v)
+            raise BinarizationError(
+                f'The following phonemes are not covered in transcriptions: {missing_phones}'
+            )

     def process_dataset(self, prefix, num_workers=0, apply_augmentation=False):
         args = []
         builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs)
         total_sec = {k: 0.0 for k in self.spk_map}
         total_raw_sec = {k: 0.0 for k in self.spk_map}
-        extra_info = {'names': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
+        extra_info = {'names': {}, 'ph_texts': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
         max_no = -1

         for item_name, meta_data in self.meta_data_iterator(prefix):
@@ -271,6 +307,7 @@ def postprocess(_item):
                         extra_info[k] = {}
                     extra_info[k][item_no] = v.shape[0]
             extra_info['names'][item_no] = _item['name'].split(':', 1)[-1]
+            extra_info['ph_texts'][item_no] = _item['ph_text']
             extra_info['spk_ids'][item_no] = _item['spk_id']
             extra_info['spk_names'][item_no] = _item['spk_name']
             extra_info['lengths'][item_no] = _item['length']
@@ -287,6 +324,7 @@ def postprocess(_item):
                             extra_info[k] = {}
                         extra_info[k][aug_item_no] = v.shape[0]
                 extra_info['names'][aug_item_no] = aug_item['name'].split(':', 1)[-1]
+                extra_info['ph_texts'][aug_item_no] = aug_item['ph_text']
                 extra_info['spk_ids'][aug_item_no] = aug_item['spk_id']
                 extra_info['spk_names'][aug_item_no] = aug_item['spk_name']
                 extra_info['lengths'][aug_item_no] = aug_item['length']
@@ -315,6 +353,7 @@ def postprocess(_item):
         builder.finalize()
         if prefix == "train":
             extra_info.pop("names")
+            extra_info.pop('ph_texts')
             extra_info.pop("spk_names")
         with open(self.binary_data_dir / f"{prefix}.meta", "wb") as f:
             # noinspection PyTypeChecker
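
Note: the keys this binarizer now reads (`datasets`, `dictionaries`, `num_spk`, `num_lang`) imply a per-dataset configuration layout. The sketch below is illustrative only — the dataset entries, speaker names and dictionary paths are invented — but it mirrors the speaker-ID auto-assignment in build_spk_map and the language-ID enumeration in build_lang_map above.

# Hypothetical hparams fragment; key names follow this diff, values are invented.
hparams = {
    'num_spk': 2,
    'num_lang': 2,
    'dictionaries': {
        'zh': 'dictionaries/opencpop-extension.txt',  # illustrative paths
        'ja': 'dictionaries/japanese-dict.txt',
    },
    'datasets': [
        {'raw_data_dir': 'data/alice_zh', 'speaker': 'alice', 'language': 'zh',
         'spk_id': 0, 'test_prefixes': ['sample01']},
        {'raw_data_dir': 'data/alice_ja', 'speaker': 'alice', 'language': 'ja', 'spk_id': 0},
        {'raw_data_dir': 'data/bob_zh', 'speaker': 'bob', 'language': 'zh'},
    ],
}

# Speaker IDs: explicit spk_id values are kept; the remaining datasets receive
# the smallest unused indices (same algorithm as build_spk_map above).
spk_ids = [ds.get('spk_id') for ds in hparams['datasets']]
assigned = {i for i in spk_ids if i is not None}
idx = 0
for pos, value in enumerate(spk_ids):
    if value is not None:
        continue
    while idx in assigned:
        idx += 1
    spk_ids[pos] = idx
    assigned.add(idx)
print(spk_ids)  # -> [0, 0, 1]: both 'alice' datasets share one ID, 'bob' gets the next free one

# Language IDs are enumerated from the sorted dictionary names, starting at 1
# (index 0 is presumably reserved), exactly as in build_lang_map above.
lang_map = {name: i for i, name in enumerate(sorted(hparams['dictionaries']), start=1)}
print(lang_map)  # -> {'ja': 1, 'zh': 2}

With this layout, two datasets in different languages can share one speaker identity, which the speaker-ID consistency check in build_spk_map explicitly allows.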

basics/base_exporter.py

Lines changed: 27 additions & 0 deletions
@@ -1,4 +1,6 @@
 import json
+import pathlib
+import shutil
 from pathlib import Path
 from typing import Union

@@ -31,6 +33,18 @@ def build_spk_map(self) -> dict:
         else:
             return {}

+    # noinspection PyMethodMayBeStatic
+    def build_lang_map(self) -> dict:
+        lang_map_fn = pathlib.Path(hparams['work_dir']) / 'lang_map.json'
+        if lang_map_fn.exists():
+            with open(lang_map_fn, 'r', encoding='utf8') as f:
+                lang_map = json.load(f)
+            assert isinstance(lang_map, dict) and len(lang_map) > 0, 'Invalid or empty language map!'
+            assert len(lang_map) == len(set(lang_map.values())), 'Duplicate language id in language map!'
+            return lang_map
+        else:
+            return {}
+
     def build_model(self) -> nn.Module:
         """
         Creates an instance of nn.Module and load its state dict on the target device.
@@ -44,6 +58,19 @@ def export_model(self, path: Path):
         """
         raise NotImplementedError()

+    # noinspection PyMethodMayBeStatic
+    def export_dictionaries(self, path: Path):
+        dicts = hparams.get('dictionaries')
+        if dicts is not None:
+            for lang in dicts.keys():
+                fn = f'dictionary-{lang}.txt'
+                shutil.copy(pathlib.Path(hparams['work_dir']) / fn, path)
+                print(f'| export dictionary => {path / fn}')
+        else:
+            fn = 'dictionary.txt'
+            shutil.copy(pathlib.Path(hparams['work_dir']) / fn, path)
+            print(f'| export dictionary => {path / fn}')
+
     def export_attachments(self, path: Path):
         """
         Exports related files and configs (e.g. the dictionary) to the target directory.
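
A rough usage sketch of the two new hooks (the subclass name and overridden methods below are hypothetical; concrete exporters live elsewhere in the repository): build_lang_map() reads lang_map.json from the checkpoint's work_dir, and export_dictionaries() copies one dictionary-<lang>.txt per language, falling back to a single dictionary.txt when the config has no 'dictionaries' entry.

# Hypothetical subclass; only meant to show where the new hooks would be called.
import pathlib

from basics.base_exporter import BaseExporter


class MyExporter(BaseExporter):
    def build_model(self):
        ...  # omitted: build the module and load its state dict

    def export_model(self, path: pathlib.Path):
        ...  # omitted: write the exported graph

    def export_attachments(self, path: pathlib.Path):
        lang_map = self.build_lang_map()   # {} for single-language checkpoints
        print('| languages:', sorted(lang_map, key=lang_map.get))
        self.export_dictionaries(path)     # dictionary-<lang>.txt per language, or dictionary.txt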

basics/base_svs_infer.py

Lines changed: 6 additions & 1 deletion
@@ -29,6 +29,7 @@ def __init__(self, device=None):
         self.device = device
         self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']
         self.spk_map = {}
+        self.lang_map = {}
         self.model: torch.nn.Module = None

     def build_model(self, ckpt_steps=None) -> torch.nn.Module:
@@ -50,7 +51,11 @@ def load_speaker_mix(self, param_src: dict, summary_dst: dict,
         spk_mix_map = param_src.get(param_key)  # { spk_name: value } or { spk_name: "value value value ..." }
         dynamic = False
         if spk_mix_map is None:
-            # Get the first speaker
+            assert len(self.spk_map) == 1, (
+                "This is a multi-speaker model. "
+                "Please specify a speaker or speaker mix by --spk option."
+            )
+            # Get the only speaker
             for name in self.spk_map.keys():
                 spk_mix_map = {name: 1.0}
                 break
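
The behavioral change above: when no speaker mix is given, the code used to fall back silently to the first speaker; it now only falls back when the model has exactly one speaker and otherwise demands an explicit --spk. A standalone sketch of that rule (hypothetical helper, not the actual inference entry point):

from typing import Optional


def resolve_spk_mix(spk_map: dict, spk_mix_map: Optional[dict]) -> dict:
    """Mirrors the fallback rule in load_speaker_mix after this change."""
    if spk_mix_map is None:
        # Multi-speaker models must be given an explicit speaker or mix.
        assert len(spk_map) == 1, (
            "This is a multi-speaker model. "
            "Please specify a speaker or speaker mix by --spk option."
        )
        spk_mix_map = {next(iter(spk_map)): 1.0}
    return spk_mix_map


print(resolve_spk_mix({'alice': 0}, None))        # -> {'alice': 1.0}
# resolve_spk_mix({'alice': 0, 'bob': 1}, None)   # would raise AssertionError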
