Fix audio_config handling

Eren Gölge 2022-04-22 12:50:10 +02:00 committed by Eren Gölge
parent b3fb0e19e8
commit a05c82f9ef
3 changed files with 14 additions and 16 deletions


@@ -102,7 +102,7 @@ class FastPitchE2eConfig(BaseTTSConfig):
         Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
     """

-    model: str = "fast_pitch_e2e_hifigan"
+    model: str = "fast_pitch_e2e"
     base_model: str = "forward_tts_e2e"

     # model specific params
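After this rename, a freshly constructed config reports the new default model name. A minimal sketch, assuming FastPitchE2eConfig lives in a fast_pitch_e2e_config module (the import path is a guess, not taken from this commit):

    from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig  # hypothetical path

    config = FastPitchE2eConfig()
    print(config.model)       # "fast_pitch_e2e" (was "fast_pitch_e2e_hifigan")
    print(config.base_model)  # "forward_tts_e2e"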


@@ -9,7 +9,7 @@ import tqdm
 from torch.utils.data import Dataset

 from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
-from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec
+from TTS.utils.audio.numpy_transforms import compute_f0, load_wav, wav_to_mel, wav_to_spec

 # to prevent too many open files error as suggested here
 # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@@ -647,14 +647,14 @@ class F0Dataset:
     def __init__(
         self,
         samples: Union[List[List], List[Dict]],
-        ap: "AudioProcessor",
+        audio_config: "AudioConfig",
         verbose=False,
         cache_path: str = None,
         precompute_num_workers=0,
         normalize_f0=True,
     ):
         self.samples = samples
-        self.audio_config = ap
+        self.audio_config = audio_config
         self.verbose = verbose
         self.cache_path = cache_path
         self.normalize_f0 = normalize_f0
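F0Dataset now receives the audio configuration object directly rather than a full AudioProcessor. A minimal construction sketch, assuming BaseAudioConfig from TTS.config.shared_configs carries the sample_rate/hop_length/pitch_fmax fields used by this class, and with a hypothetical sample entry:

    from TTS.config.shared_configs import BaseAudioConfig  # assumed to provide the fields below

    audio_config = BaseAudioConfig(sample_rate=22050, hop_length=256, pitch_fmax=640.0)
    samples = [{"audio_file": "/path/to/clip.wav"}]  # hypothetical sample dict
    f0_dataset = F0Dataset(
        samples=samples,
        audio_config=audio_config,
        cache_path="/tmp/f0_cache",
        precompute_num_workers=0,
    )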
@@ -711,9 +711,9 @@ class F0Dataset:
         return pitch_file

     @staticmethod
-    def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
-        wav = ap.load_wav(wav_file)
-        pitch = ap.compute_f0(wav)
+    def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
+        wav = load_wav(wav_file)
+        pitch = compute_f0(x=wav, pitch_fmax=audio_config.pitch_fmax, hop_length=audio_config.hop_length, sample_rate=audio_config.sample_rate)
         if pitch_file:
             np.save(pitch_file, pitch)
         return pitch
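The pitch computation now goes through the stateless numpy_transforms helpers instead of AudioProcessor methods. A standalone sketch of the same flow with illustrative values and paths (the keyword arguments mirror the call in this diff; the exact numpy_transforms signatures are an assumption):

    import numpy as np
    from TTS.utils.audio.numpy_transforms import compute_f0, load_wav

    wav = load_wav("/path/to/clip.wav")  # hypothetical path
    pitch = compute_f0(x=wav, pitch_fmax=640.0, hop_length=256, sample_rate=22050)
    np.save("/tmp/clip_pitch.npy", pitch)  # same caching step as _compute_and_save_pitch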


@@ -76,10 +76,9 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
         precompute_num_workers=0,
         normalize_f0=True,
     ):
-        self.audio_config = audio_config
         super().__init__(
             samples=samples,
-            ap=None,
+            audio_config=audio_config,
             verbose=verbose,
             cache_path=cache_path,
             precompute_num_workers=precompute_num_workers,
@@ -87,13 +86,13 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
         )

     @staticmethod
-    def _compute_and_save_pitch(config, wav_file, pitch_file=None):
+    def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
         wav, _ = load_audio(wav_file)
         f0 = compute_f0(
-            x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax
+            x=wav.numpy()[0], sample_rate=audio_config.sample_rate, hop_length=audio_config.hop_length, pitch_fmax=audio_config.pitch_fmax
         )
         # skip the last F0 value to align with the spectrogram
-        if wav.shape[1] % config.hop_length != 0:
+        if wav.shape[1] % audio_config.hop_length != 0:
             f0 = f0[:-1]
         if pitch_file:
             np.save(pitch_file, f0)
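The trailing-frame trim only fires when the waveform length is not an exact multiple of hop_length, presumably because compute_f0 yields one extra frame there relative to the spectrogram path. A tiny sketch of the condition with illustrative numbers:

    hop_length = 256
    print(22050 % hop_length)        # 34 -> non-zero, so the last F0 value would be dropped
    print((256 * 100) % hop_length)  # 0 -> F0 left untouched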
@@ -105,7 +104,7 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
         """
         pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
         if not os.path.exists(pitch_file):
-            pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
+            pitch = self._compute_and_save_pitch(audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file)
         else:
             pitch = np.load(pitch_file)
         return pitch.astype(np.float32)
@@ -117,13 +116,12 @@ class ForwardTTSE2eDataset(TTSDataset):
         compute_f0 = kwargs.pop("compute_f0", False)
         kwargs["compute_f0"] = False
+        self.audio_config = kwargs["audio_config"]
+        del kwargs["audio_config"]
         super().__init__(*args, **kwargs)
         self.compute_f0 = compute_f0
         self.pad_id = self.tokenizer.characters.pad_id
-        self.audio_config = kwargs["audio_config"]
         if self.compute_f0:
             self.f0_dataset = ForwardTTSE2eF0Dataset(
                 audio_config=self.audio_config,
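The constructor now stashes audio_config and strips it from kwargs before delegating, presumably because TTSDataset.__init__ does not accept that keyword. A self-contained sketch of this pop-before-super pattern, with dummy classes standing in for the real ones:

    class _Base:  # stands in for TTSDataset: must not receive audio_config
        def __init__(self, **kwargs):
            assert "audio_config" not in kwargs, "unexpected keyword"

    class _Derived(_Base):  # mirrors ForwardTTSE2eDataset.__init__
        def __init__(self, *args, **kwargs):
            self.audio_config = kwargs["audio_config"]
            del kwargs["audio_config"]  # same effect as kwargs.pop("audio_config")
            super().__init__(*args, **kwargs)

    d = _Derived(audio_config=object())
    print(d.audio_config is not None)  # True; the base class never saw audio_config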