Fastspeech2 (#2073)

* added EnergyDataset

* add energy to Dataset

* add comupte_energy

* added energy params

* added energy to forward_tts

* added plot_avg_energy for visualisation

* Update forward_tts.py

* create file

* added fastspeech2 recipe

* add fastspeech2 config

* removed energy from fast pitch

* add energy loss to forward tts

* Update fastspeech2_config.py

* change run_name

* Update numpy_transforms.py

* fix typo

* fix typo

* fix typo

* linting issues

* use_energy default value --> False

* Update numpy_transforms.py

* linting fixes

* fix typo

* liniting_fix

* liniting_fix

* fix

* fixes

* fixes

* lint fix

* lint fixws

* added training test

* wrong import

* wrong import

* trailing whitespace

* style fix

* changed class name because of error

* class name change

* class name change

* change class name

* fixed styles
manmay nakhashi 2023-01-16 03:09:22 +05:30, committed by GitHub
parent 14d45b5347
commit bc422f2f3c
11 changed files with 866 additions and 6 deletions

View File

@ -100,6 +100,13 @@ class FastPitchConfig(BaseTTSConfig):
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
# dataset configs
compute_f0 (bool):
Compute pitch. Defaults to True.
f0_cache_path (str):
Pitch cache path. Defaults to None.
"""
model: str = "fast_pitch"

View File

@ -0,0 +1,198 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs
@dataclass
class Fastspeech2Config(BaseTTSConfig):
"""Configure `ForwardTTS` as FastPitch model.
Example:
>>> from TTS.tts.configs.fastspeech2_config import FastSpeech2Config
>>> config = FastSpeech2Config()
Args:
model (str):
Model name used for selecting the right model at initialization. Defaults to `fastspeech2`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit):
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.
data_dep_init_steps (int):
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
for the rest. Defaults to 10.
speakers_file (str):
Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to
speaker names. Defaults to `None`.
use_speaker_embedding (bool):
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
in the multi-speaker mode. Defaults to False.
use_d_vector_file (bool):
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
d_vector_file (str):
Path to the file including pre-computed speaker embeddings. Defaults to None.
d_vector_dim (int):
Dimension of the external speaker embeddings. Defaults to 0.
optimizer (str):
Name of the model optimizer. Defaults to `Adam`.
optimizer_params (dict):
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
lr_scheduler (str):
Name of the learning rate scheduler. Defaults to `Noam`.
lr_scheduler_params (dict):
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
lr (float):
Initial learning rate. Defaults to `1e-3`.
grad_clip (float):
Gradient norm clipping value. Defaults to `5.0`.
spec_loss_type (str):
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
duration_loss_type (str):
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
use_ssim_loss (bool):
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
wd (float):
Weight decay coefficient. Defaults to `1e-7`.
ssim_loss_alpha (float):
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
dur_loss_alpha (float):
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
spec_loss_alpha (float):
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
pitch_loss_alpha (float):
Weight for the pitch predictor's loss. If set 0, disables the pitch predictor. Defaults to 1.0.
energy_loss_alpha (float):
Weight for the energy predictor's loss. If set 0, disables the energy predictor. Defaults to 1.0.
binary_align_loss_alpha (float):
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
binary_loss_warmup_epochs (float):
Number of epochs to gradually increase the binary loss impact. Defaults to 150.
min_seq_len (int):
Minimum input sequence length to be used at training.
max_seq_len (int):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
# dataset configs
compute_f0 (bool):
Compute pitch. Defaults to True.
f0_cache_path (str):
Pitch cache path. Defaults to None.
compute_energy (bool):
Compute energy. Defaults to True.
energy_cache_path (str):
Energy cache path. Defaults to None.
"""
model: str = "fastspeech2"
base_model: str = "forward_tts"
# model specific params
model_args: ForwardTTSArgs = ForwardTTSArgs()
# multi-speaker settings
num_speakers: int = 0
speakers_file: str = None
use_speaker_embedding: bool = False
use_d_vector_file: bool = False
d_vector_file: str = None
d_vector_dim: int = 0
# optimizer parameters
optimizer: str = "Adam"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
lr_scheduler: str = "NoamLR"
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
lr: float = 1e-4
grad_clip: float = 5.0
# loss params
spec_loss_type: str = "mse"
duration_loss_type: str = "mse"
use_ssim_loss: bool = True
ssim_loss_alpha: float = 1.0
spec_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
pitch_loss_alpha: float = 0.1
energy_loss_alpha: float = 0.1
dur_loss_alpha: float = 0.1
binary_align_loss_alpha: float = 0.1
binary_loss_warmup_epochs: int = 150
# overrides
min_seq_len: int = 13
max_seq_len: int = 200
r: int = 1 # DO NOT CHANGE
# dataset configs
compute_f0: bool = True
f0_cache_path: str = None
# dataset configs
compute_energy: bool = True
energy_cache_path: str = None
# testing
test_sentences: List[str] = field(
default_factory=lambda: [
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963.",
]
)
def __post_init__(self):
# Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
if self.num_speakers > 0:
self.model_args.num_speakers = self.num_speakers
# speaker embedding settings
if self.use_speaker_embedding:
self.model_args.use_speaker_embedding = True
if self.speakers_file:
self.model_args.speakers_file = self.speakers_file
# d-vector settings
if self.use_d_vector_file:
self.model_args.use_d_vector_file = True
if self.d_vector_dim is not None and self.d_vector_dim > 0:
self.model_args.d_vector_dim = self.d_vector_dim
if self.d_vector_file:
self.model_args.d_vector_file = self.d_vector_file
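A minimal usage sketch of the new config (not part of the patch, assuming the import path shown in the docstring above): `compute_energy` / `energy_cache_path` drive the data loader, while the energy predictor itself is gated by `model_args.use_energy`, which defaults to False in `ForwardTTSArgs`.

# Sketch: enable both the energy features in the data loader and the energy predictor.
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config

config = Fastspeech2Config(
    compute_energy=True,               # data loader computes and caches energy
    energy_cache_path="energy_cache",  # hypothetical cache directory
)
config.model_args.use_energy = True    # predictor is off by default (see ForwardTTSArgs)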

View File

@ -217,6 +217,9 @@ class BaseTTSConfig(BaseTrainingConfig):
compute_f0 (int):
(Not in use yet).
compute_energy (int):
(Not in use yet).
compute_linear_spec (bool):
If True data loader computes and returns linear spectrograms alongside the other data.
@ -312,6 +315,7 @@ class BaseTTSConfig(BaseTrainingConfig):
min_text_len: int = 1
max_text_len: int = float("inf")
compute_f0: bool = False
compute_energy: bool = False
compute_linear_spec: bool = False
precompute_num_workers: int = 0
use_noise_augment: bool = False

View File

@ -11,6 +11,7 @@ from torch.utils.data import Dataset
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy
# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@ -50,7 +51,9 @@ class TTSDataset(Dataset):
samples: List[Dict] = None,
tokenizer: "TTSTokenizer" = None,
compute_f0: bool = False,
compute_energy: bool = False,
f0_cache_path: str = None,
energy_cache_path: str = None,
return_wav: bool = False,
batch_group_size: int = 0,
min_text_len: int = 0,
@ -84,8 +87,12 @@ class TTSDataset(Dataset):
compute_f0 (bool): compute f0 if True. Defaults to False.
compute_energy (bool): compute energy if True. Defaults to False.
f0_cache_path (str): Path to store f0 cache. Defaults to None.
energy_cache_path (str): Path to store energy cache. Defaults to None.
return_wav (bool): Return the waveform of the sample. Defaults to False.
batch_group_size (int): Range of batch randomization after sorting
@ -128,7 +135,9 @@ class TTSDataset(Dataset):
self.compute_linear_spec = compute_linear_spec
self.return_wav = return_wav
self.compute_f0 = compute_f0
self.compute_energy = compute_energy
self.f0_cache_path = f0_cache_path
self.energy_cache_path = energy_cache_path
self.min_audio_len = min_audio_len
self.max_audio_len = max_audio_len
self.min_text_len = min_text_len
@ -155,7 +164,10 @@ class TTSDataset(Dataset):
self.f0_dataset = F0Dataset(
self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
)
if compute_energy:
self.energy_dataset = EnergyDataset(
self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
)
if self.verbose:
self.print_logs()
@ -211,6 +223,12 @@ class TTSDataset(Dataset):
assert item["audio_unique_name"] == out_dict["audio_unique_name"]
return out_dict
def get_energy(self, idx):
out_dict = self.energy_dataset[idx]
item = self.samples[idx]
assert item["audio_unique_name"] == out_dict["audio_unique_name"]
return out_dict
@staticmethod
def get_attn_mask(attn_file):
return np.load(attn_file)
@ -252,12 +270,16 @@ class TTSDataset(Dataset):
f0 = None
if self.compute_f0:
f0 = self.get_f0(idx)["f0"]
energy = None
if self.compute_energy:
energy = self.get_energy(idx)["energy"]
sample = {
"raw_text": raw_text,
"token_ids": token_ids,
"wav": wav,
"pitch": f0,
"energy": energy,
"attn": attn,
"item_idx": item["audio_file"],
"speaker_name": item["speaker_name"],
@ -490,7 +512,13 @@ class TTSDataset(Dataset):
pitch = torch.FloatTensor(pitch)[:, None, :].contiguous() # B x 1 xT
else:
pitch = None
# format energy
if self.compute_energy:
energy = prepare_data(batch["energy"])
assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}"
energy = torch.FloatTensor(energy)[:, None, :].contiguous() # B x 1 xT
else:
energy = None
# format attention masks
attns = None
if batch["attn"][0] is not None:
@ -519,6 +547,7 @@ class TTSDataset(Dataset):
"waveform": wav_padded,
"raw_text": batch["raw_text"],
"pitch": pitch,
"energy": energy,
"language_ids": language_ids,
"audio_unique_names": batch["audio_unique_name"],
}
@ -777,3 +806,155 @@ class F0Dataset:
print("\n")
print(f"{indent}> F0Dataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")
class EnergyDataset:
"""Energy Dataset for computing energy from wav files on CPU.
Pre-computes energy values for all the samples at initialization if `cache_path` is set and not already present. It
also computes the mean and std of energy values if `normalize_energy` is True.
Args:
samples (Union[List[List], List[Dict]]):
List of samples. Each sample is a list or a dict.
ap (AudioProcessor):
AudioProcessor to compute energy from wav files.
cache_path (str):
Path to cache energy values. If `cache_path` is already present or None, it skips the pre-computation.
Defaults to None.
precompute_num_workers (int):
Number of workers used for pre-computing the energy values. Defaults to 0.
normalize_energy (bool):
Whether to normalize energy values by mean and std. Defaults to True.
"""
def __init__(
self,
samples: Union[List[List], List[Dict]],
ap: "AudioProcessor",
verbose=False,
cache_path: str = None,
precompute_num_workers=0,
normalize_energy=True,
):
self.samples = samples
self.ap = ap
self.verbose = verbose
self.cache_path = cache_path
self.normalize_energy = normalize_energy
self.pad_id = 0.0
self.mean = None
self.std = None
if cache_path is not None and not os.path.exists(cache_path):
os.makedirs(cache_path)
self.precompute(precompute_num_workers)
if normalize_energy:
self.load_stats(cache_path)
def __getitem__(self, idx):
item = self.samples[idx]
energy = self.compute_or_load(item["audio_file"])
if self.normalize_energy:
assert self.mean is not None and self.std is not None, " [!] Mean and STD are not available"
energy = self.normalize(energy)
return {"audio_unique_name": item["audio_unique_name"], "energy": energy}
def __len__(self):
return len(self.samples)
def precompute(self, num_workers=0):
print("[*] Pre-computing energy values...")
with tqdm.tqdm(total=len(self)) as pbar:
batch_size = num_workers if num_workers > 0 else 1
# we do not normalize at preprocessing
normalize_energy = self.normalize_energy
self.normalize_energy = False
dataloader = torch.utils.data.DataLoader(
batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
)
computed_data = []
for batch in dataloader:
energy = batch["energy"]
computed_data.append(e for e in energy)
pbar.update(batch_size)
self.normalize_energy = normalize_energy
if self.normalize_energy:
computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
energy_mean, energy_std = self.compute_energy_stats(computed_data)
energy_stats = {"mean": energy_mean, "std": energy_std}
np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)
def get_pad_id(self):
return self.pad_id
@staticmethod
def create_energy_file_path(wav_file, cache_path):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
energy_file = os.path.join(cache_path, file_name + "_energy.npy")
return energy_file
@staticmethod
def _compute_and_save_energy(ap, wav_file, energy_file=None):
wav = ap.load_wav(wav_file)
energy = calculate_energy(wav)
if energy_file:
np.save(energy_file, energy)
return energy
@staticmethod
def compute_energy_stats(energy_vecs):
nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs])
mean, std = np.mean(nonzeros), np.std(nonzeros)
return mean, std
def load_stats(self, cache_path):
stats_path = os.path.join(cache_path, "energy_stats.npy")
stats = np.load(stats_path, allow_pickle=True).item()
self.mean = stats["mean"].astype(np.float32)
self.std = stats["std"].astype(np.float32)
def normalize(self, energy):
zero_idxs = np.where(energy == 0.0)[0]
energy = energy - self.mean
energy = energy / self.std
energy[zero_idxs] = 0.0
return energy
def denormalize(self, energy):
zero_idxs = np.where(energy == 0.0)[0]
energy *= self.std
energy += self.mean
energy[zero_idxs] = 0.0
return energy
def compute_or_load(self, wav_file):
"""
compute energy and return a numpy array of energy values
"""
energy_file = self.create_energy_file_path(wav_file, self.cache_path)
if not os.path.exists(energy_file):
energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
else:
energy = np.load(energy_file)
return energy.astype(np.float32)
def collate_fn(self, batch):
audio_unique_names = [item["audio_unique_name"] for item in batch]
energys = [item["energy"] for item in batch]
energy_lens = [len(item["energy"]) for item in batch]
energy_lens_max = max(energy_lens)
energys_torch = torch.FloatTensor(len(energys), energy_lens_max).fill_(self.get_pad_id())
for i, energy_len in enumerate(energy_lens):
energys_torch[i, :energy_len] = torch.FloatTensor(energys[i])
return {"audio_unique_name": audio_unique_names, "energy": energys_torch, "energy_lens": energy_lens}
def print_logs(self, level: int = 0) -> None:
indent = "\t" * level
print("\n")
print(f"{indent}> energyDataset ")
print(f"{indent}| > Number of instances : {len(self.samples)}")

View File

@ -801,6 +801,10 @@ class ForwardTTSLoss(nn.Module):
self.pitch_loss = MSELossMasked(False)
self.pitch_loss_alpha = c.pitch_loss_alpha
if c.model_args.use_energy:
self.energy_loss = MSELossMasked(False)
self.energy_loss_alpha = c.energy_loss_alpha
if c.use_ssim_loss:
self.ssim = SSIMLoss() if c.use_ssim_loss else None
self.ssim_loss_alpha = c.ssim_loss_alpha
@ -826,6 +830,8 @@ class ForwardTTSLoss(nn.Module):
dur_target,
pitch_output,
pitch_target,
energy_output,
energy_target,
input_lens,
alignment_logprob=None,
alignment_hard=None,
@ -855,6 +861,11 @@ class ForwardTTSLoss(nn.Module):
loss = loss + self.pitch_loss_alpha * pitch_loss
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
if hasattr(self, "energy_loss") and self.energy_loss_alpha > 0:
energy_loss = self.energy_loss(energy_output.transpose(1, 2), energy_target.transpose(1, 2), input_lens)
loss = loss + self.energy_loss_alpha * energy_loss
return_dict["loss_energy"] = self.energy_loss_alpha * energy_loss
if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
loss = loss + self.aligner_loss_alpha * aligner_loss
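For context, a small sketch (not from the patch) of how the masked MSE is applied to the character-averaged energy in the branch above; the import path is an assumption, the shapes follow the pitch loss.

# Sketch: masked MSE over per-character energy, [B, 1, T_en] transposed to [B, T_en, 1].
import torch
from TTS.tts.layers.losses import MSELossMasked  # assumed location of MSELossMasked

energy_loss_fn = MSELossMasked(False)
energy_pred = torch.randn(2, 1, 7)   # predictor output, B x 1 x T_en
energy_gt = torch.randn(2, 1, 7)     # ground truth averaged over durations
input_lens = torch.tensor([7, 5])    # valid character lengths per item
loss = energy_loss_fn(energy_pred.transpose(1, 2), energy_gt.transpose(1, 2), input_lens)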

View File

@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec
@ -42,6 +42,9 @@ class ForwardTTSArgs(Coqpit):
use_pitch (bool):
Use pitch predictor to learn the pitch. Defaults to True.
use_energy (bool):
Use energy predictor to learn the energy. Defaults to False.
duration_predictor_hidden_channels (int):
Number of hidden channels in the duration predictor. Defaults to 256.
@ -63,6 +66,18 @@ class ForwardTTSArgs(Coqpit):
pitch_embedding_kernel_size (int):
Kernel size of the projection layer in the pitch predictor. Defaults to 3.
energy_predictor_hidden_channels (int):
Number of hidden channels in the energy predictor. Defaults to 256.
energy_predictor_dropout_p (float):
Dropout rate for the energy predictor. Defaults to 0.1.
energy_predictor_kernel_size (int):
Kernel size of conv layers in the energy predictor. Defaults to 3.
energy_embedding_kernel_size (int):
Kernel size of the projection layer in the energy predictor. Defaults to 3.
positional_encoding (bool):
Whether to use positional encoding. Defaults to True.
@ -114,14 +129,25 @@ class ForwardTTSArgs(Coqpit):
out_channels: int = 80
hidden_channels: int = 384
use_aligner: bool = True
# pitch params
use_pitch: bool = True
pitch_predictor_hidden_channels: int = 256
pitch_predictor_kernel_size: int = 3
pitch_predictor_dropout_p: float = 0.1
pitch_embedding_kernel_size: int = 3
# energy params
use_energy: bool = False
energy_predictor_hidden_channels: int = 256
energy_predictor_kernel_size: int = 3
energy_predictor_dropout_p: float = 0.1
energy_embedding_kernel_size: int = 3
# duration params
duration_predictor_hidden_channels: int = 256
duration_predictor_kernel_size: int = 3
duration_predictor_dropout_p: float = 0.1
positional_encoding: bool = True
poisitonal_encoding_use_scale: bool = True
length_scale: int = 1
@ -158,7 +184,7 @@ class ForwardTTS(BaseTTS):
- FastPitch
- SpeedySpeech
- FastSpeech
- FastSpeech2 (requires average speech energy predictor)
Args:
config (Coqpit): Model coqpit class.
@ -187,6 +213,7 @@ class ForwardTTS(BaseTTS):
self.max_duration = self.args.max_duration
self.use_aligner = self.args.use_aligner
self.use_pitch = self.args.use_pitch
self.use_energy = self.args.use_energy
self.binary_loss_weight = 0.0
self.length_scale = (
@ -234,6 +261,20 @@ class ForwardTTS(BaseTTS):
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
)
if self.args.use_energy:
self.energy_predictor = DurationPredictor(
self.args.hidden_channels + self.embedded_speaker_dim,
self.args.energy_predictor_hidden_channels,
self.args.energy_predictor_kernel_size,
self.args.energy_predictor_dropout_p,
)
self.energy_emb = nn.Conv1d(
1,
self.args.hidden_channels,
kernel_size=self.args.energy_embedding_kernel_size,
padding=int((self.args.energy_embedding_kernel_size - 1) / 2),
)
if self.args.use_aligner:
self.aligner = AlignmentNetwork(
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
@ -440,6 +481,42 @@ class ForwardTTS(BaseTTS):
o_pitch_emb = self.pitch_emb(o_pitch)
return o_pitch_emb, o_pitch
def _forward_energy_predictor(
self,
o_en: torch.FloatTensor,
x_mask: torch.IntTensor,
energy: torch.FloatTensor = None,
dr: torch.IntTensor = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Energy predictor forward pass.
1. Predict energy from encoder outputs.
2. In training - Compute average energy values for each input character from the ground truth energy values.
3. Embed average energy values.
Args:
o_en (torch.FloatTensor): Encoder output.
x_mask (torch.IntTensor): Input sequence mask.
energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None.
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction.
Shapes:
- o_en: :math:`(B, C, T_{en})`
- x_mask: :math:`(B, 1, T_{en})`
- energy: :math:`(B, 1, T_{de})`
- dr: :math:`(B, T_{en})`
"""
o_energy = self.energy_predictor(o_en, x_mask)
if energy is not None:
avg_energy = average_over_durations(energy, dr)
o_energy_emb = self.energy_emb(avg_energy)
return o_energy_emb, o_energy, avg_energy
o_energy_emb = self.energy_emb(o_energy)
return o_energy_emb, o_energy
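The averaging step above relies on `average_over_durations`; the following plain-NumPy restatement (not the library helper) shows what it computes per input character.

# Illustration: average frame-level energy over per-character durations.
import numpy as np

frame_energy = np.array([0.2, 0.4, 0.6, 0.8, 1.0])   # T_de = 5 frames
durations = np.array([2, 3])                          # frames per input character, T_en = 2
splits = np.split(frame_energy, np.cumsum(durations)[:-1])
avg_energy = np.array([s.mean() for s in splits])     # -> [0.3, 0.8], one value per character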
def _forward_aligner(
self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
@ -502,6 +579,7 @@ class ForwardTTS(BaseTTS):
y: torch.FloatTensor = None,
dr: torch.IntTensor = None,
pitch: torch.FloatTensor = None,
energy: torch.FloatTensor = None,
aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument
) -> Dict:
"""Model's forward pass.
@ -513,6 +591,7 @@ class ForwardTTS(BaseTTS):
y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
energy (torch.FloatTensor): Energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None.
aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
Shapes:
@ -556,6 +635,12 @@ class ForwardTTS(BaseTTS):
if self.args.use_pitch:
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
o_en = o_en + o_pitch_emb
# energy predictor pass
o_energy = None
avg_energy = None
if self.args.use_energy:
o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr)
o_en = o_en + o_energy_emb
# decoder pass
o_de, attn = self._forward_decoder(
o_en, dr, x_mask, y_lengths, g=None
@ -567,6 +652,8 @@ class ForwardTTS(BaseTTS):
"attn_durations": o_attn, # for visualization [B, T_en, T_de']
"pitch_avg": o_pitch,
"pitch_avg_gt": avg_pitch,
"energy_avg": o_energy,
"energy_avg_gt": avg_energy,
"alignments": attn, # [B, T_de, T_en]
"alignment_soft": alignment_soft,
"alignment_mas": alignment_mas,
@ -604,12 +691,18 @@ class ForwardTTS(BaseTTS):
if self.args.use_pitch:
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
o_en = o_en + o_pitch_emb
# energy predictor pass
o_energy = None
if self.args.use_energy:
o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask)
o_en = o_en + o_energy_emb
# decoder pass
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
outputs = {
"model_outputs": o_de,
"alignments": attn,
"pitch": o_pitch,
"energy": o_energy,
"durations_log": o_dr_log,
}
return outputs
@ -620,6 +713,7 @@ class ForwardTTS(BaseTTS):
mel_input = batch["mel_input"]
mel_lengths = batch["mel_lengths"]
pitch = batch["pitch"] if self.args.use_pitch else None
energy = batch["energy"] if self.args.use_energy else None
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]
@ -627,7 +721,14 @@ class ForwardTTS(BaseTTS):
# forward pass
outputs = self.forward(
text_input,
text_lengths,
mel_lengths,
y=mel_input,
dr=durations,
pitch=pitch,
energy=energy,
aux_input=aux_input,
)
# use aligner's output as the duration target
if self.use_aligner:
@ -643,6 +744,8 @@ class ForwardTTS(BaseTTS):
dur_target=durations,
pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
energy_output=outputs["energy_avg"] if self.use_energy else None,
energy_target=outputs["energy_avg_gt"] if self.use_energy else None,
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
alignment_soft=outputs["alignment_soft"],
@ -683,6 +786,17 @@ class ForwardTTS(BaseTTS):
}
figures.update(pitch_figures)
# plot energy figures
if self.args.use_energy:
energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy())
energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy())
chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
energy_figures = {
"energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False),
"energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False),
}
figures.update(energy_figures)
# plot the attention mask computed from the predicted durations
if "attn_durations" in outputs:
alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()

View File

@ -123,6 +123,39 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
return fig
def plot_avg_energy(energy, chars, fig_size=(30, 10), output_fig=False):
"""Plot energy curves on top of the input characters.
Args:
energy (np.array): energy values.
chars (str): Characters to place to the x-axis.
Shapes:
energy: :math:`(T,)`
"""
old_fig_size = plt.rcParams["figure.figsize"]
if fig_size is not None:
plt.rcParams["figure.figsize"] = fig_size
fig, ax = plt.subplots()
x = np.array(range(len(chars)))
my_xticks = chars
plt.xticks(x, my_xticks)
ax.set_xlabel("characters")
ax.set_ylabel("freq")
ax2 = ax.twinx()
ax2.plot(energy, linewidth=5.0, color="red")
ax2.set_ylabel("energy")
plt.rcParams["figure.figsize"] = old_fig_size
if not output_fig:
plt.close()
return fig
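A short usage sketch (not from the patch) of the plotting helper above; the values and characters are made up.

# Sketch: plot character-averaged energy, as the training figures in forward_tts.py do.
import numpy as np
from TTS.tts.utils.visual import plot_avg_energy

chars = "hello"
avg_energy = np.array([0.3, 0.8, 0.5, 0.4, 0.6])  # one value per character
fig = plot_avg_energy(avg_energy, chars, output_fig=True)
fig.savefig("avg_energy.png")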
def visualize(
alignment,
postnet_output,

View File

@ -4,7 +4,7 @@ import librosa
import numpy as np
import scipy
import soundfile as sf
from librosa import magphase, pyin
# For using kwargs
# pylint: disable=unused-argument
@ -303,6 +303,27 @@ def compute_f0(
return f0
def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
"""Compute energy of a waveform using the same parameters used for computing melspectrogram.
Args:
y (np.ndarray): Waveform. Shape :math:`[T_wav,]`
Returns:
np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length`
Examples:
>>> WAV_FILE = filename = librosa.util.example_audio_file()
>>> from TTS.config import BaseAudioConfig
>>> from TTS.utils.audio import AudioProcessor
>>> conf = BaseAudioConfig()
>>> ap = AudioProcessor(**conf)
>>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
>>> energy = compute_energy(wav)
"""
x = stft(y=y, **kwargs)
mag, _ = magphase(x)
energy = np.sqrt(np.sum(mag**2, axis=0))
return energy
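A plain-NumPy restatement (not from the patch) of the computation above: frame energy is the L2 norm of each magnitude-spectrogram frame. The STFT parameters here are illustrative.

# Sketch: per-frame energy as the L2 norm of the STFT magnitude.
import librosa
import numpy as np

wav = np.random.randn(22050).astype(np.float32)              # 1 second of dummy audio
mag = np.abs(librosa.stft(wav, n_fft=1024, hop_length=256))  # [1 + n_fft // 2, T_energy]
energy = np.sqrt(np.sum(mag**2, axis=0))                     # [T_energy,], roughly T_wav / hop_length frames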
### Audio Processing ###
def find_endpoint(
*,

View File

@ -0,0 +1,102 @@
import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.manage import ModelManager
output_path = os.path.dirname(os.path.abspath(__file__))
# init configs
dataset_config = BaseDatasetConfig(
formatter="ljspeech",
meta_file_train="metadata.csv",
# meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
path=os.path.join(output_path, "../LJSpeech-1.1/"),
)
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = Fastspeech2Config(
run_name="fastspeech2_ljspeech",
audio=audio_config,
batch_size=32,
eval_batch_size=16,
num_loader_workers=8,
num_eval_loader_workers=4,
compute_input_seq_cache=True,
compute_f0=True,
f0_cache_path=os.path.join(output_path, "f0_cache"),
compute_energy=True,
energy_cache_path=os.path.join(output_path, "energy_cache"),
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
precompute_num_workers=4,
print_step=50,
print_eval=False,
mixed_precision=False,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
)
# compute alignments
if not config.model_args.use_aligner:
manager = ModelManager()
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
# TODO: make compute_attention python callable
os.system(
f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
)
# INITIALIZE THE AUDIO PROCESSOR
# Audio processor is used for feature extraction and audio I/O.
# It mainly serves the dataloader and the training loggers.
ap = AudioProcessor.init_from_config(config)
# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
# If characters are not defined in the config, default characters are passed to the config
tokenizer, config = TTSTokenizer.init_from_config(config)
# LOAD DATA SAMPLES
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
# You can define your custom sample loader returning the list of samples.
# Or define your custom formatter and pass it to the `load_tts_samples`.
# Check `TTS.tts.datasets.load_tts_samples` for more details.
train_samples, eval_samples = load_tts_samples(
dataset_config,
eval_split=True,
eval_split_max_size=config.eval_split_max_size,
eval_split_size=config.eval_split_size,
)
# init the model
model = ForwardTTS(config, ap, tokenizer, speaker_manager=None)
# init the trainer and 🚀
trainer = Trainer(
TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()

View File

@ -0,0 +1,95 @@
import glob
import json
import os
import shutil
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
config_path = os.path.join(get_tests_output_path(), "fast_pitch_speaker_emb_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = Fastspeech2Config(
audio=audio_config,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
f0_cache_path="tests/data/ljspeech/f0_cache/",
compute_f0=True,
compute_energy=True,
energy_cache_path="tests/data/ljspeech/f0_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
use_speaker_embedding=True,
test_sentences=[
"Be a voice, not an echo.",
],
)
config.audio.do_trim_silence = True
config.use_speaker_embedding = True
config.model_args.use_speaker_embedding = True
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.formatter ljspeech_test "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
continue_speakers_path = os.path.join(continue_path, "speakers.json")
# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0
# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)

View File

@ -0,0 +1,94 @@
import glob
import json
import os
import shutil
from trainer import get_last_checkpoint
from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
config = Fastspeech2Config(
audio=audio_config,
batch_size=8,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
text_cleaner="english_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
f0_cache_path="tests/data/ljspeech/f0_cache/",
compute_f0=True,
compute_energy=True,
energy_cache_path="tests/data/ljspeech/f0_cache/",
run_eval=True,
test_delay_epochs=-1,
epochs=1,
print_step=1,
print_eval=True,
test_sentences=[
"Be a voice, not an echo.",
],
use_speaker_embedding=False,
)
config.audio.do_trim_silence = True
config.use_speaker_embedding = False
config.model_args.use_speaker_embedding = False
config.audio.trim_db = 60
config.save_json(config_path)
# train the model for one epoch
command_train = (
f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
f"--coqpit.output_path {output_path} "
"--coqpit.datasets.0.formatter ljspeech "
"--coqpit.datasets.0.meta_file_train metadata.csv "
"--coqpit.datasets.0.meta_file_val metadata.csv "
"--coqpit.datasets.0.path tests/data/ljspeech "
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
"--coqpit.test_delay_epochs 0"
)
run_cli(command_train)
# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0
# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)
# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)