Fix audio_config handling

This commit is contained in:
Eren Gölge 2022-04-22 12:50:10 +02:00 committed by Eren G??lge
parent b3fb0e19e8
commit a05c82f9ef
3 changed files with 14 additions and 16 deletions

View File

@ -102,7 +102,7 @@ class FastPitchE2eConfig(BaseTTSConfig):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
"""
model: str = "fast_pitch_e2e_hifigan"
model: str = "fast_pitch_e2e"
base_model: str = "forward_tts_e2e"
# model specific params

View File

@ -9,7 +9,7 @@ import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio.numpy_transforms import load_wav, wav_to_mel, wav_to_spec
from TTS.utils.audio.numpy_transforms import compute_f0, load_wav, wav_to_mel, wav_to_spec
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@ -647,14 +647,14 @@ class F0Dataset:
def __init__(
self,
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
audio_config: "AudioConfig",
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_f0=True,
):
self.samples = samples
self.audio_config = ap
self.audio_config = audio_config
self.verbose = verbose
self.cache_path = cache_path
self.normalize_f0 = normalize_f0
@ -711,9 +711,9 @@ class F0Dataset:
return pitch_file
@staticmethod
def _compute_and_save_pitch(ap, wav_file, pitch_file=None):
wav = ap.load_wav(wav_file)
pitch = ap.compute_f0(wav)
def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
wav = load_wav(wav_file)
pitch = compute_f0(x=wav, pitch_fmax=audio_config.pitch_fmax, hop_length=audio_config.hop_length, sample_rate=audio_config.sample_rate)
if pitch_file:
np.save(pitch_file, pitch)
return pitch

View File

@ -76,10 +76,9 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
precompute_num_workers=0,
normalize_f0=True,
):
self.audio_config = audio_config
super().__init__(
samples=samples,
ap=None,
audio_config=audio_config,
verbose=verbose,
cache_path=cache_path,
precompute_num_workers=precompute_num_workers,
@ -87,13 +86,13 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
)
@staticmethod
def _compute_and_save_pitch(config, wav_file, pitch_file=None):
def _compute_and_save_pitch(audio_config, wav_file, pitch_file=None):
wav, _ = load_audio(wav_file)
f0 = compute_f0(
x=wav.numpy()[0], sample_rate=config.sample_rate, hop_length=config.hop_length, pitch_fmax=config.pitch_fmax
x=wav.numpy()[0], sample_rate=audio_config.sample_rate, hop_length=audio_config.hop_length, pitch_fmax=audio_config.pitch_fmax
)
# skip the last F0 value to align with the spectrogram
if wav.shape[1] % config.hop_length != 0:
if wav.shape[1] % audio_config.hop_length != 0:
f0 = f0[:-1]
if pitch_file:
np.save(pitch_file, f0)
@ -105,7 +104,7 @@ class ForwardTTSE2eF0Dataset(F0Dataset):
"""
pitch_file = self.create_pitch_file_path(wav_file, self.cache_path)
if not os.path.exists(pitch_file):
pitch = self._compute_and_save_pitch(self.audio_config, wav_file, pitch_file)
pitch = self._compute_and_save_pitch(audio_config=self.audio_config, wav_file=wav_file, pitch_file=pitch_file)
else:
pitch = np.load(pitch_file)
return pitch.astype(np.float32)
@ -117,13 +116,12 @@ class ForwardTTSE2eDataset(TTSDataset):
compute_f0 = kwargs.pop("compute_f0", False)
kwargs["compute_f0"] = False
self.audio_config = kwargs["audio_config"]
del kwargs["audio_config"]
super().__init__(*args, **kwargs)
self.compute_f0 = compute_f0
self.pad_id = self.tokenizer.characters.pad_id
self.audio_config = kwargs["audio_config"]
if self.compute_f0:
self.f0_dataset = ForwardTTSE2eF0Dataset(
audio_config=self.audio_config,