mirror of https://github.com/coqui-ai/TTS.git
add energy by default to Fastspeech2 config (#2326)
* add energy by default
* added energy to base tts
* fix energy dataset
* fix styles
* fix test
This commit is contained in:
parent 478c8178b8
commit 624513018d
@@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig):
     base_model: str = "forward_tts"

     # model specific params
-    model_args: ForwardTTSArgs = ForwardTTSArgs()
+    model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True)

     # multi-speaker settings
     num_speakers: int = 0
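With this change, a default Fastspeech2Config enables both pitch and energy prediction without any extra arguments. A minimal sanity check of the new defaults (the import path is assumed from the package layout):

# Minimal check of the new defaults; the import path is an assumption.
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config

config = Fastspeech2Config()
assert config.model_args.use_pitch   # pitch prediction on by default
assert config.model_args.use_energy  # energy prediction now on by default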
@@ -189,6 +189,8 @@ class TTSDataset(Dataset):
         self._samples = new_samples
         if hasattr(self, "f0_dataset"):
             self.f0_dataset.samples = new_samples
+        if hasattr(self, "energy_dataset"):
+            self.energy_dataset.samples = new_samples
         if hasattr(self, "phoneme_dataset"):
             self.phoneme_dataset.samples = new_samples
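The added lines follow the existing pattern: when the sample list is reassigned, TTSDataset pushes the new list into each helper dataset it owns, now including the energy cache. A reduced sketch of that pattern with hypothetical stand-in classes:

# Hypothetical minimal classes illustrating the samples-setter propagation.
class FeatureCache:
    def __init__(self, samples):
        self.samples = samples

class MainDataset:
    def __init__(self, samples):
        self._samples = samples
        self.energy_dataset = FeatureCache(samples)

    @property
    def samples(self):
        return self._samples

    @samples.setter
    def samples(self, new_samples):
        self._samples = new_samples
        # keep the helper dataset pointing at the same sample list
        if hasattr(self, "energy_dataset"):
            self.energy_dataset.samples = new_samples

ds = MainDataset(["a.wav"])
ds.samples = ["a.wav", "b.wav"]
assert ds.energy_dataset.samples == ds.samples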
@@ -856,11 +858,11 @@ class EnergyDataset:

     def __getitem__(self, idx):
         item = self.samples[idx]
-        energy = self.compute_or_load(item["audio_file"])
+        energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
         if self.normalize_energy:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             energy = self.normalize(energy)
-        return {"audio_file": item["audio_file"], "energy": energy}
+        return {"audio_unique_name": item["audio_unique_name"], "energy": energy}

     def __len__(self):
         return len(self.samples)
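Keying the cache by string2filename(item["audio_unique_name"]) instead of the raw wav path makes cache entries stable even if audio files move on disk. The actual helper lives elsewhere in the package; a plausible sketch of what it does:

# Plausible sketch of a string2filename-style helper (an assumption,
# not the library's verbatim implementation).
import base64

def string2filename(string: str) -> str:
    # encode an arbitrary unique name into a filesystem-safe token
    return base64.urlsafe_b64encode(string.encode("utf-8")).decode("utf-8", "ignore")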
@@ -884,7 +886,7 @@ class EnergyDataset:

         if self.normalize_energy:
             computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
-            energy_mean, energy_std = self.compute_pitch_stats(computed_data)
+            energy_mean, energy_std = self.compute_energy_stats(computed_data)
             energy_stats = {"mean": energy_mean, "std": energy_std}
             np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)
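The rename fixes a naming slip: energy statistics were being requested under the pitch helper's name. Conceptually, the statistics pass pools every cached energy vector and takes a global mean and standard deviation for normalization. A hedged sketch, assuming the vectors are 1-D numpy arrays and zeros mark silent frames (as in the pitch code):

import numpy as np

def compute_energy_stats(energy_vecs):
    # pool non-zero energy values across all utterances, then mean/std
    nonzeros = np.concatenate([v[v != 0.0] for v in energy_vecs])
    return np.mean(nonzeros), np.std(nonzeros)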
@@ -900,7 +902,7 @@ class EnergyDataset:
     @staticmethod
     def _compute_and_save_energy(ap, wav_file, energy_file=None):
         wav = ap.load_wav(wav_file)
-        energy = calculate_energy(wav)
+        energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length)
         if energy_file:
             np.save(energy_file, energy)
         return energy
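Passing the audio processor's fft_size, hop_length, and win_length keeps the energy contour frame-aligned with the model's spectrogram features instead of relying on hard-coded defaults. One common definition of frame-level energy, as an illustration rather than the library's exact calculate_energy:

# Illustrative frame-level energy from a magnitude STFT (not the exact
# library implementation).
import librosa
import numpy as np

def frame_energy(wav, fft_size=1024, hop_length=256, win_length=1024):
    spec = np.abs(librosa.stft(wav, n_fft=fft_size,
                               hop_length=hop_length, win_length=win_length))
    # per-frame energy = L2 norm over frequency bins
    return np.sqrt(np.sum(spec ** 2, axis=0))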
@@ -931,11 +933,11 @@ class EnergyDataset:
         energy[zero_idxs] = 0.0
         return energy

-    def compute_or_load(self, wav_file):
+    def compute_or_load(self, wav_file, audio_unique_name):
         """
         compute energy and return a numpy array of energy values
         """
-        energy_file = self.create_Energy_file_path(wav_file, self.cache_path)
+        energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(energy_file):
             energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
         else:
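Besides fixing the miscapitalized create_Energy_file_path, the cache path is now derived from the sanitized unique name rather than the wav path. By analogy with the pitch cache, such a helper likely amounts to the following (an assumption, for illustration):

import os

def create_energy_file_path(file_name: str, cache_path: str) -> str:
    # one .npy cache file per utterance, keyed by its unique name
    return os.path.join(cache_path, file_name + "_energy.npy")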
@@ -943,14 +945,14 @@ class EnergyDataset:
         return energy.astype(np.float32)

     def collate_fn(self, batch):
-        audio_file = [item["audio_file"] for item in batch]
+        audio_unique_name = [item["audio_unique_name"] for item in batch]
         energys = [item["energy"] for item in batch]
         energy_lens = [len(item["energy"]) for item in batch]
         energy_lens_max = max(energy_lens)
         energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id())
         for i, energy_len in enumerate(energy_lens):
             energys_torch[i, :energy_len] = torch.LongTensor(energys[i])
-        return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens}
+        return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens}

     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
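collate_fn pads the variable-length energy vectors of a batch into one rectangular tensor, filling with the dataset's pad id. A self-contained illustration of that padding step, assuming a pad id of 0:

import torch

# two energy vectors of different lengths, as __getitem__ would yield them
energies = [torch.tensor([3, 5, 2]), torch.tensor([4, 1])]
lens = [len(e) for e in energies]

padded = torch.zeros(len(energies), max(lens), dtype=torch.long)  # pad id 0
for i, (e, n) in enumerate(zip(energies, lens)):
    padded[i, :n] = e
# padded -> tensor([[3, 5, 2], [4, 1, 0]])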
@@ -183,6 +183,7 @@ class BaseTTS(BaseTrainerModel):
         attn_mask = batch["attns"]
         waveform = batch["waveform"]
         pitch = batch["pitch"]
+        energy = batch["energy"]
         language_ids = batch["language_ids"]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
@@ -231,6 +232,7 @@ class BaseTTS(BaseTrainerModel):
             "item_idx": item_idx,
             "waveform": waveform,
             "pitch": pitch,
+            "energy": energy,
             "language_ids": language_ids,
             "audio_unique_names": batch["audio_unique_names"],
         }
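These two hunks thread the collated energy through BaseTTS.format_batch: it is unpacked from the loader batch and re-emitted in the dict handed to the model's train step, exactly as pitch already was. A toy pass-through illustrating the data flow, with plain dicts standing in for the real batches:

import torch

def format_batch(batch):
    # unpack, then re-emit under the same key (illustration only)
    return {"pitch": batch["pitch"], "energy": batch["energy"]}

loader_batch = {"pitch": torch.zeros(2, 80), "energy": torch.zeros(2, 80)}
model_batch = format_batch(loader_batch)
assert model_batch["energy"].shape == (2, 80)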
@@ -313,6 +315,8 @@ class BaseTTS(BaseTrainerModel):
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 compute_f0=config.get("compute_f0", False),
                 f0_cache_path=config.get("f0_cache_path", None),
+                compute_energy=config.get("compute_energy", False),
+                energy_cache_path=config.get("energy_cache_path", None),
                 samples=samples,
                 ap=self.ap,
                 return_wav=config.return_wav if "return_wav" in config else False,
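The dataset factory reads the new flags with config.get(...), so older configs that predate the energy fields still load with safe defaults. The same defaulting pattern, shown with a plain dict standing in for the config object:

# Defaulting pattern: a config without the new fields falls back safely.
config = {"compute_f0": True}  # no energy fields defined

compute_energy = config.get("compute_energy", False)       # -> False
energy_cache_path = config.get("energy_cache_path", None)  # -> None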
@@ -38,7 +38,7 @@ config = Fastspeech2Config(
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     compute_f0=True,
     compute_energy=True,
-    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    energy_cache_path="tests/data/ljspeech/energy_cache/",
     run_eval=True,
     test_delay_epochs=-1,
     epochs=1,
@@ -38,7 +38,7 @@ config = Fastspeech2Config(
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     compute_f0=True,
     compute_energy=True,
-    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    energy_cache_path="tests/data/ljspeech/energy_cache/",
     run_eval=True,
     test_delay_epochs=-1,
     epochs=1,
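Both test files receive the identical one-line fix: energy_cache_path previously pointed at the f0 cache directory, so pitch and energy caches would have shared one folder. In context, the relevant slice of each test config now reads (fields outside the hunk elided):

config = Fastspeech2Config(
    # ... other fields elided ...
    f0_cache_path="tests/data/ljspeech/f0_cache/",
    compute_f0=True,
    compute_energy=True,
    energy_cache_path="tests/data/ljspeech/energy_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    # ... other fields elided ...
)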