 from utils.hparams import hparams
 from utils.indexed_datasets import IndexedDatasetBuilder
 from utils.multiprocess_utils import chunked_multiprocess_run
-from utils.phoneme_utils import build_phoneme_list, locate_dictionary
+from utils.phoneme_utils import load_phoneme_dictionary
 from utils.plot import distribution_to_figure
-from utils.text_encoder import TokenTextEncoder


 class BinarizationError(Exception):
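For orientation: the flat TokenTextEncoder built from build_phoneme_list() is replaced by a phoneme dictionary object. A rough sketch of how that object is used later in this file — the actual load_phoneme_dictionary API lives in utils/phoneme_utils.py and may differ; the methods below are only inferred from the calls made further down in this diff:

    # Sketch only; inferred from the len(...) and decode_one(...) calls below.
    phoneme_dictionary = load_phoneme_dictionary()
    vocab_size = len(phoneme_dictionary)                   # index 0 is skipped by the coverage check below
    name = phoneme_dictionary.decode_one(3, scalar=False)  # a phoneme name, or a tuple of merged names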
@@ -44,73 +43,88 @@ class BaseBinarizer:
     the phoneme set.
     """

-    def __init__(self, data_dir=None, data_attrs=None):
-        if data_dir is None:
-            data_dir = hparams['raw_data_dir']
-        if not isinstance(data_dir, list):
-            data_dir = [data_dir]
-
-        self.raw_data_dirs = [pathlib.Path(d) for d in data_dir]
+    def __init__(self, datasets=None, data_attrs=None):
+        if datasets is None:
+            datasets = hparams['datasets']
+        self.datasets = datasets
+        self.raw_data_dirs = [pathlib.Path(ds['raw_data_dir']) for ds in self.datasets]
         self.binary_data_dir = pathlib.Path(hparams['binary_data_dir'])
         self.data_attrs = [] if data_attrs is None else data_attrs

         self.binarization_args = hparams['binarization_args']
         self.augmentation_args = hparams.get('augmentation_args', {})
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-        self.spk_map = None
-        self.spk_ids = hparams['spk_ids']
-        self.speakers = hparams['speakers']
+        self.spk_map = {}
+        self.spk_ids = None
         self.build_spk_map()

+        self.lang_map = {}
+        self.dictionaries = hparams['dictionaries']
+        self.build_lang_map()
+
         self.items = {}
         self.item_names: list = None
         self._train_item_names: list = None
         self._valid_item_names: list = None

-        self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list())
+        self.phoneme_dictionary = load_phoneme_dictionary()
         self.timestep = hparams['hop_size'] / hparams['audio_sample_rate']

     def build_spk_map(self):
-        assert isinstance(self.speakers, list), 'Speakers must be a list'
-        assert len(self.speakers) == len(self.raw_data_dirs), \
-            'Number of raw data dirs must equal number of speaker names!'
-        if len(self.spk_ids) == 0:
-            self.spk_ids = list(range(len(self.raw_data_dirs)))
-        else:
-            assert len(self.spk_ids) == len(self.raw_data_dirs), \
-                'Length of explicitly given spk_ids must equal the number of raw datasets.'
-            assert max(self.spk_ids) < hparams['num_spk'], \
-                f'Index in spk_id sequence {self.spk_ids} is out of range. All values should be smaller than num_spk.'
-
-        self.spk_map = {}
-        for spk_name, spk_id in zip(self.speakers, self.spk_ids):
+        spk_ids = [ds.get('spk_id') for ds in self.datasets]
+        assigned_spk_ids = {spk_id for spk_id in spk_ids if spk_id is not None}
+        idx = 0
+        for i in range(len(spk_ids)):
+            if spk_ids[i] is not None:
+                continue
+            while idx in assigned_spk_ids:
+                idx += 1
+            spk_ids[i] = idx
+            assigned_spk_ids.add(idx)
+        assert max(spk_ids) < hparams['num_spk'], \
+            f'Index in spk_id sequence {spk_ids} is out of range. All values should be smaller than num_spk.'
+
+        for spk_id, dataset in zip(spk_ids, self.datasets):
+            spk_name = dataset['speaker']
             if spk_name in self.spk_map and self.spk_map[spk_name] != spk_id:
                 raise ValueError(f'Invalid speaker ID assignment. Name \'{spk_name}\' is assigned '
                                  f'with different speaker IDs: {self.spk_map[spk_name]} and {spk_id}.')
             self.spk_map[spk_name] = spk_id
+        self.spk_ids = spk_ids

         print("| spk_map: ", self.spk_map)

-    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id):
+    def build_lang_map(self):
+        assert len(self.dictionaries.keys()) <= hparams['num_lang'], \
+            'Number of languages must not be greater than num_lang!'
+        for dataset in self.datasets:
+            assert dataset['language'] in self.dictionaries, f'Unrecognized language name: {dataset["language"]}'
+
+        for lang_id, lang_name in enumerate(sorted(self.dictionaries.keys()), start=1):
+            self.lang_map[lang_name] = lang_id
+
+        print("| lang_map: ", self.lang_map)
+
+    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang) -> dict:
         raise NotImplementedError()

-    def split_train_valid_set(self, item_names):
+    def split_train_valid_set(self, prefixes: list):
         """
         Split the dataset into training set and validation set.
         :return: train_item_names, valid_item_names
         """
-        prefixes = {str(pr): 1 for pr in hparams['test_prefixes']}
+        prefixes = {str(pr): 1 for pr in prefixes}
         valid_item_names = {}
         # Add prefixes that specified speaker index and matches exactly item name to test set
         for prefix in deepcopy(prefixes):
-            if prefix in item_names:
+            if prefix in self.item_names:
                 valid_item_names[prefix] = 1
                 prefixes.pop(prefix)
         # Add prefixes that exactly matches item name without speaker id to test set
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.split(':')[-1] == prefix:
                     valid_item_names[name] = 1
                     matched = True
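This refactor assumes a per-dataset configuration that replaces the old parallel raw_data_dir / speakers / spk_ids lists. A hedged sketch of what hparams['datasets'] and hparams['dictionaries'] might look like, based only on the keys read in this diff (raw_data_dir, speaker, spk_id, language, test_prefixes); all values are invented:

    # Illustrative only; field values are made up.
    hparams['datasets'] = [
        {
            'raw_data_dir': 'data/opencpop_raw',   # hypothetical path
            'speaker': 'opencpop',
            'spk_id': 0,                           # optional; missing IDs are auto-assigned
            'language': 'zh',
            'test_prefixes': ['2044_001', '2086'],  # hypothetical item-name prefixes
        },
        {
            'raw_data_dir': 'data/another_singer_raw',
            'speaker': 'another_singer',
            # no 'spk_id': build_spk_map() assigns the next free ID (here 1)
            'language': 'en',
            'test_prefixes': ['song_a'],
        },
    ]
    hparams['dictionaries'] = {
        'zh': 'dictionaries/opencpop-extension.txt',  # hypothetical paths
        'en': 'dictionaries/english.txt',
    }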
@@ -119,15 +133,15 @@ def split_train_valid_set(self, item_names):
         # Add names with one of the remaining prefixes to test set
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.startswith(prefix):
                     valid_item_names[name] = 1
                     matched = True
             if matched:
                 prefixes.pop(prefix)
         for prefix in deepcopy(prefixes):
             matched = False
-            for name in item_names:
+            for name in self.item_names:
                 if name.split(':')[-1].startswith(prefix):
                     valid_item_names[name] = 1
                     matched = True
@@ -143,7 +157,7 @@ def split_train_valid_set(self, item_names):

         valid_item_names = list(valid_item_names.keys())
         assert len(valid_item_names) > 0, 'Validation set is empty!'
-        train_item_names = [x for x in item_names if x not in set(valid_item_names)]
+        train_item_names = [x for x in self.item_names if x not in set(valid_item_names)]
         assert len(train_item_names) > 0, 'Training set is empty!'

         return train_item_names, valid_item_names
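To make the matching rules above concrete, a hypothetical walk-through (item names and prefixes are invented; item names are assumed to have the form '<ds_id>:<name>', and prefixes arrive already namespaced by process() further below):

    # Hypothetical example of how test prefixes select validation items.
    item_names = ['0:2044_001', '0:2044_002', '0:2086_005', '1:song_a_01']
    test_prefixes = ['0:2044_001',  # exact item name -> picked by the first pass
                     '0:2086',      # dataset-indexed prefix -> picks '0:2086_005' via startswith
                     '1:song_a']    # likewise -> picks '1:song_a_01'
    # resulting split: valid = ['0:2044_001', '0:2086_005', '1:song_a_01'], train = ['0:2044_002']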
@@ -167,21 +181,34 @@ def meta_data_iterator(self, prefix):

     def process(self):
         # load each dataset
-        for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs):
-            self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id)
+        test_prefixes = []
+        for ds_id, dataset in enumerate(self.datasets):
+            items = self.load_meta_data(
+                pathlib.Path(dataset['raw_data_dir']),
+                ds_id=ds_id, spk=dataset['speaker'], lang=dataset['language']
+            )
+            self.items.update(items)
+            test_prefixes.extend(
+                f'{ds_id}:{prefix}'
+                for prefix in dataset.get('test_prefixes', [])
+            )
         self.item_names = sorted(list(self.items.keys()))
-        self._train_item_names, self._valid_item_names = self.split_train_valid_set(self.item_names)
+        self._train_item_names, self._valid_item_names = self.split_train_valid_set(test_prefixes)

         if self.binarization_args['shuffle']:
             random.shuffle(self.item_names)

         self.binary_data_dir.mkdir(parents=True, exist_ok=True)

-        # Copy spk_map and dictionary to binary data dir
+        # Copy spk_map, lang_map and dictionary to binary data dir
         spk_map_fn = self.binary_data_dir / 'spk_map.json'
         with open(spk_map_fn, 'w', encoding='utf-8') as f:
-            json.dump(self.spk_map, f)
-        shutil.copy(locate_dictionary(), self.binary_data_dir / 'dictionary.txt')
+            json.dump(self.spk_map, f, ensure_ascii=False)
+        lang_map_fn = self.binary_data_dir / 'lang_map.json'
+        with open(lang_map_fn, 'w', encoding='utf-8') as f:
+            json.dump(self.lang_map, f, ensure_ascii=False)
+        for lang, dict_path in hparams['dictionaries'].items():
+            shutil.copy(dict_path, self.binary_data_dir / f'dictionary-{lang}.txt')
         self.check_coverage()

         # Process valid set and train set
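For the two-dataset example sketched after the first hunk, the binary data directory written here would look roughly like this (file names follow the code above; contents are invented):

    # binary_data_dir/
    #   spk_map.json        -> {"opencpop": 0, "another_singer": 1}
    #   lang_map.json       -> {"en": 1, "zh": 2}   # language IDs start at 1, sorted by name
    #   dictionary-zh.txt   # copy of hparams['dictionaries']['zh']
    #   dictionary-en.txt   # copy of hparams['dictionaries']['en']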
@@ -197,40 +224,47 @@ def process(self):

     def check_coverage(self):
         # Group by phonemes in the dictionary.
-        ph_required = set(build_phoneme_list())
-        phoneme_map = {}
-        for ph in ph_required:
-            phoneme_map[ph] = 0
-        ph_occurred = []
+        ph_idx_required = set(range(1, len(self.phoneme_dictionary)))
+        ph_idx_occurred = set()
+        ph_idx_count_map = {
+            idx: 0
+            for idx in ph_idx_required
+        }

         # Load and count those phones that appear in the actual data
         for item_name in self.items:
-            ph_occurred += self.items[item_name]['ph_seq']
-            if len(ph_occurred) == 0:
-                raise BinarizationError(f'Empty tokens in {item_name}.')
-        for ph in ph_occurred:
-            if ph not in ph_required:
-                continue
-            phoneme_map[ph] += 1
-        ph_occurred = set(ph_occurred)
+            ph_idx_occurred.update(self.items[item_name]['ph_seq'])
+            for idx in self.items[item_name]['ph_seq']:
+                ph_idx_count_map[idx] += 1
+        ph_count_map = {
+            self.phoneme_dictionary.decode_one(idx, scalar=False): count
+            for idx, count in ph_idx_count_map.items()
+        }
+
+        def display_phoneme(phoneme):
+            if isinstance(phoneme, tuple):
+                return f'({", ".join(phoneme)})'
+            return phoneme

         print('===== Phoneme Distribution Summary =====')
-        for i, key in enumerate(sorted(phoneme_map.keys())):
-            if i == len(ph_required) - 1:
+        keys = sorted(ph_count_map.keys(), key=lambda v: v[0] if isinstance(v, tuple) else v)
+        for i, key in enumerate(keys):
+            if i == len(ph_count_map) - 1:
                 end = '\n'
             elif i % 10 == 9:
                 end = ',\n'
             else:
                 end = ', '
-            print(f'\'{key}\': {phoneme_map[key]}', end=end)
+            key_disp = display_phoneme(key)
+            print(f'{key_disp}: {ph_count_map[key]}', end=end)

         # Draw graph.
-        x = sorted(phoneme_map.keys())
-        values = [phoneme_map[k] for k in x]
+        xs = [display_phoneme(k) for k in keys]
+        ys = [ph_count_map[k] for k in keys]
         plt = distribution_to_figure(
             title='Phoneme Distribution Summary',
             x_label='Phoneme', y_label='Number of occurrences',
-            items=x, values=values
+            items=xs, values=ys, rotate=len(self.dictionaries) > 1
         )
         filename = self.binary_data_dir / 'phoneme_distribution.jpg'
         plt.savefig(fname=filename,
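Since a dictionary entry can apparently decode to a tuple of merged names rather than a single string (inferred from the isinstance(phoneme, tuple) check above), the summary renders entries roughly as follows (values are made up):

    # Hypothetical inputs to the display_phoneme helper defined above.
    display_phoneme('a')           # -> 'a'         single-name entry
    display_phoneme(('ae', 'ah'))  # -> '(ae, ah)'  merged entry shown as a parenthesized list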
@@ -239,19 +273,21 @@ def check_coverage(self):
         print(f'| save summary to \'{filename}\'')

         # Check unrecognizable or missing phonemes
-        if ph_occurred != ph_required:
-            unrecognizable_phones = ph_occurred.difference(ph_required)
-            missing_phones = ph_required.difference(ph_occurred)
-            raise BinarizationError('transcriptions and dictionary mismatch.\n'
-                                    f' (+) {sorted(unrecognizable_phones)}\n'
-                                    f' (-) {sorted(missing_phones)}')
+        if ph_idx_occurred != ph_idx_required:
+            missing_phones = sorted({
+                self.phoneme_dictionary.decode_one(idx, scalar=False)
+                for idx in ph_idx_required.difference(ph_idx_occurred)
+            }, key=lambda v: v[0] if isinstance(v, tuple) else v)
+            raise BinarizationError(
+                f'The following phonemes are not covered in transcriptions: {missing_phones}'
+            )

     def process_dataset(self, prefix, num_workers=0, apply_augmentation=False):
         args = []
         builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs)
         total_sec = {k: 0.0 for k in self.spk_map}
         total_raw_sec = {k: 0.0 for k in self.spk_map}
-        extra_info = {'names': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
+        extra_info = {'names': {}, 'ph_texts': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}}
         max_no = -1

         for item_name, meta_data in self.meta_data_iterator(prefix):
@@ -271,6 +307,7 @@ def postprocess(_item):
                     extra_info[k] = {}
                 extra_info[k][item_no] = v.shape[0]
             extra_info['names'][item_no] = _item['name'].split(':', 1)[-1]
+            extra_info['ph_texts'][item_no] = _item['ph_text']
             extra_info['spk_ids'][item_no] = _item['spk_id']
             extra_info['spk_names'][item_no] = _item['spk_name']
             extra_info['lengths'][item_no] = _item['length']
@@ -287,6 +324,7 @@ def postprocess(_item):
                         extra_info[k] = {}
                     extra_info[k][aug_item_no] = v.shape[0]
                 extra_info['names'][aug_item_no] = aug_item['name'].split(':', 1)[-1]
+                extra_info['ph_texts'][aug_item_no] = aug_item['ph_text']
                 extra_info['spk_ids'][aug_item_no] = aug_item['spk_id']
                 extra_info['spk_names'][aug_item_no] = aug_item['spk_name']
                 extra_info['lengths'][aug_item_no] = aug_item['length']
@@ -315,6 +353,7 @@ def postprocess(_item):
         builder.finalize()
         if prefix == "train":
             extra_info.pop("names")
+            extra_info.pop('ph_texts')
             extra_info.pop("spk_names")
         with open(self.binary_data_dir / f"{prefix}.meta", "wb") as f:
             # noinspection PyTypeChecker