add energy by default to Fastspeech2 config (#2326)

* add energy by default

* added energy to base tts

* fix energy dataset

* fix styles

* fix test
This commit is contained in:
manmay nakhashi 2023-03-06 14:50:25 +05:30 committed by GitHub
parent 478c8178b8
commit 624513018d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 17 additions and 11 deletions

View File

@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig):
base_model: str = "forward_tts"
# model specific params
model_args: ForwardTTSArgs = ForwardTTSArgs()
model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True)
# multi-speaker settings
num_speakers: int = 0

View File

@ -189,6 +189,8 @@ class TTSDataset(Dataset):
self._samples = new_samples
if hasattr(self, "f0_dataset"):
self.f0_dataset.samples = new_samples
if hasattr(self, "energy_dataset"):
self.energy_dataset.samples = new_samples
if hasattr(self, "phoneme_dataset"):
self.phoneme_dataset.samples = new_samples
@ -856,11 +858,11 @@ class EnergyDataset:
def __getitem__(self, idx):
item = self.samples[idx]
energy = self.compute_or_load(item["audio_file"])
energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
if self.normalize_energy:
assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
energy = self.normalize(energy)
return {"audio_file": item["audio_file"], "energy": energy}
return {"audio_unique_name": item["audio_unique_name"], "energy": energy}
def __len__(self):
return len(self.samples)
@ -884,7 +886,7 @@ class EnergyDataset:
if self.normalize_energy:
computed_data = [tensor for batch in computed_data for tensor in batch] # flatten
energy_mean, energy_std = self.compute_pitch_stats(computed_data)
energy_mean, energy_std = self.compute_energy_stats(computed_data)
energy_stats = {"mean": energy_mean, "std": energy_std}
np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)
@ -900,7 +902,7 @@ class EnergyDataset:
@staticmethod
def _compute_and_save_energy(ap, wav_file, energy_file=None):
wav = ap.load_wav(wav_file)
energy = calculate_energy(wav)
energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length)
if energy_file:
np.save(energy_file, energy)
return energy
@ -931,11 +933,11 @@ class EnergyDataset:
energy[zero_idxs] = 0.0
return energy
def compute_or_load(self, wav_file):
def compute_or_load(self, wav_file, audio_unique_name):
"""
compute energy and return a numpy array of energy values
"""
energy_file = self.create_Energy_file_path(wav_file, self.cache_path)
energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
if not os.path.exists(energy_file):
energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
else:
@ -943,14 +945,14 @@ class EnergyDataset:
return energy.astype(np.float32)
def collate_fn(self, batch):
audio_file = [item["audio_file"] for item in batch]
audio_unique_name = [item["audio_unique_name"] for item in batch]
energys = [item["energy"] for item in batch]
energy_lens = [len(item["energy"]) for item in batch]
energy_lens_max = max(energy_lens)
energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id())
for i, energy_len in enumerate(energy_lens):
energys_torch[i, :energy_len] = torch.LongTensor(energys[i])
return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens}
return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens}
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level

View File

@ -183,6 +183,7 @@ class BaseTTS(BaseTrainerModel):
attn_mask = batch["attns"]
waveform = batch["waveform"]
pitch = batch["pitch"]
energy = batch["energy"]
language_ids = batch["language_ids"]
max_text_length = torch.max(text_lengths.float())
max_spec_length = torch.max(mel_lengths.float())
@ -231,6 +232,7 @@ class BaseTTS(BaseTrainerModel):
"item_idx": item_idx,
"waveform": waveform,
"pitch": pitch,
"energy": energy,
"language_ids": language_ids,
"audio_unique_names": batch["audio_unique_names"],
}
@ -313,6 +315,8 @@ class BaseTTS(BaseTrainerModel):
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
compute_energy=config.get("compute_energy", False),
energy_cache_path=config.get("energy_cache_path", None),
samples=samples,
ap=self.ap,
return_wav=config.return_wav if "return_wav" in config else False,

View File

@ -38,7 +38,7 @@ config = Fastspeech2Config(
f0_cache_path="tests/data/ljspeech/f0_cache/",
compute_f0=True,
compute_energy=True,
energy_cache_path="tests/data/ljspeech/f0_cache/",
energy_cache_path="tests/data/ljspeech/energy_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,

View File

@ -38,7 +38,7 @@ config = Fastspeech2Config(
f0_cache_path="tests/data/ljspeech/f0_cache/",
compute_f0=True,
compute_energy=True,
energy_cache_path="tests/data/ljspeech/f0_cache/",
energy_cache_path="tests/data/ljspeech/energy_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,