diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 5828411c..b7424698 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -46,6 +46,7 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False sample_from_storage_p=c.storage["sample_from_storage_p"], verbose=verbose, augmentation_config=c.audio_augmentation, + use_torch_spec=c.model_params.get("use_torch_spec", False), ) # sampler = DistributedSampler(dataset) if num_gpus > 1 else None diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py index 28a23e2f..07fa9246 100644 --- a/TTS/speaker_encoder/dataset.py +++ b/TTS/speaker_encoder/dataset.py @@ -20,6 +20,7 @@ class SpeakerEncoderDataset(Dataset): skip_speakers=False, verbose=False, augmentation_config=None, + use_torch_spec=None, ): """ Args: @@ -37,6 +38,7 @@ class SpeakerEncoderDataset(Dataset): self.skip_speakers = skip_speakers self.ap = ap self.verbose = verbose + self.use_torch_spec = use_torch_spec self.__parse_items() storage_max_size = storage_size * num_speakers_in_batch self.storage = Storage( @@ -72,22 +74,6 @@ class SpeakerEncoderDataset(Dataset): audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) return audio - def load_data(self, idx): - text, wav_file, speaker_name = self.items[idx] - wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) - mel = self.ap.melspectrogram(wav).astype("float32") - # sample seq_len - - assert text.size > 0, self.items[idx]["audio_file"] - assert wav.size > 0, self.items[idx]["audio_file"] - - sample = { - "mel": mel, - "item_idx": self.items[idx]["audio_file"], - "speaker_name": speaker_name, - } - return sample - def __parse_items(self): self.speaker_to_utters = {} for i in self.items: @@ -241,8 +227,12 @@ class SpeakerEncoderDataset(Dataset): self.gaussian_augmentation_config["max_amplitude"], size=len(wav), ) - mel = self.ap.melspectrogram(wav) - feats_.append(torch.FloatTensor(mel)) + + if not self.use_torch_spec: + mel = self.ap.melspectrogram(wav) + feats_.append(torch.FloatTensor(mel)) + else: + feats_.append(torch.FloatTensor(wav)) labels.append(torch.LongTensor(labels_)) feats.extend(feats_) diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index fa8d79bc..ac3080c3 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -334,21 +334,21 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic return items -def vctk_old(root_path, meta_files=None, wavs_path="wav48"): +def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" - test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for meta_file in meta_files: _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) file_id = txt_file.split(".")[0] - if isinstance(test_speakers, list): # if is list ignore this speakers ids - if speaker_id in test_speakers: + # ignore speakers + if isinstance(ignored_speakers, list): + if speaker_id in ignored_speakers: continue with open(meta_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") - items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id}) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id}) return items