REBASED: Add support for the speaker encoder training using torch spectrograms (#1348)

* Add support for the speaker encoder training using torch spectrograms

* Remove useless function in speaker encoder dataset class
This commit is contained in:
Edresson Casanova 2022-03-10 10:54:51 -03:00 committed by GitHub
parent 07d96f7991
commit f381e29b91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 23 deletions

View File

@ -46,6 +46,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
sample_from_storage_p=c.storage["sample_from_storage_p"],
verbose=verbose,
augmentation_config=c.audio_augmentation,
use_torch_spec=c.model_params.get("use_torch_spec", False),
)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None

View File

@ -20,6 +20,7 @@ class SpeakerEncoderDataset(Dataset):
skip_speakers=False,
verbose=False,
augmentation_config=None,
use_torch_spec=None,
):
"""
Args:
@ -37,6 +38,7 @@ class SpeakerEncoderDataset(Dataset):
self.skip_speakers = skip_speakers
self.ap = ap
self.verbose = verbose
self.use_torch_spec = use_torch_spec
self.__parse_items()
storage_max_size = storage_size * num_speakers_in_batch
self.storage = Storage(
@ -72,22 +74,6 @@ class SpeakerEncoderDataset(Dataset):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
return audio
def load_data(self, idx):
text, wav_file, speaker_name = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
mel = self.ap.melspectrogram(wav).astype("float32")
# sample seq_len
assert text.size > 0, self.items[idx]["audio_file"]
assert wav.size > 0, self.items[idx]["audio_file"]
sample = {
"mel": mel,
"item_idx": self.items[idx]["audio_file"],
"speaker_name": speaker_name,
}
return sample
def __parse_items(self):
self.speaker_to_utters = {}
for i in self.items:
@ -241,8 +227,12 @@ class SpeakerEncoderDataset(Dataset):
self.gaussian_augmentation_config["max_amplitude"],
size=len(wav),
)
mel = self.ap.melspectrogram(wav)
feats_.append(torch.FloatTensor(mel))
if not self.use_torch_spec:
mel = self.ap.melspectrogram(wav)
feats_.append(torch.FloatTensor(mel))
else:
feats_.append(torch.FloatTensor(wav))
labels.append(torch.LongTensor(labels_))
feats.extend(feats_)

View File

@ -334,21 +334,21 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
return items
def vctk_old(root_path, meta_files=None, wavs_path="wav48"):
def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
test_speakers = meta_files
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
if isinstance(test_speakers, list): # if is list ignore this speakers ids
if speaker_id in test_speakers:
# ignore speakers
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
return items