mirror of https://github.com/coqui-ai/TTS.git
Fastspeech2 (#2073)
* added EnergyDataset * add energy to Dataset * add compute_energy * added energy params * added energy to forward_tts * added plot_avg_energy for visualisation * Update forward_tts.py * create file * added fastspeech2 recipe * add fastspeech2 config * removed energy from fast pitch * add energy loss to forward tts * Update fastspeech2_config.py * change run_name * Update numpy_transforms.py * fix typo * fix typo * fix typo * linting issues * use_energy default value --> False * Update numpy_transforms.py * linting fixes * fix typo * linting fix * linting fix * fix * fixes * fixes * lint fix * lint fixes * added training test * wrong import * wrong import * trailing whitespace * style fix * changed class name because of error * class name change * class name change * change class name * fixed styles
This commit is contained in:
parent
14d45b5347
commit
bc422f2f3c
@@ -100,6 +100,13 @@ class FastPitchConfig(BaseTTSConfig):
        max_seq_len (int):
            Maximum input sequence length to be used at training. Larger values result in more VRAM usage.

        # dataset configs
        compute_f0 (bool):
            Compute pitch. Defaults to True.

        f0_cache_path (str):
            Pitch cache path. Defaults to None.
    """

    model: str = "fast_pitch"
@@ -0,0 +1,198 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.forward_tts import ForwardTTSArgs


@dataclass
class Fastspeech2Config(BaseTTSConfig):
    """Configure `ForwardTTS` as a FastSpeech2 model.

    Example:

        >>> from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
        >>> config = Fastspeech2Config()

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `fastspeech2`.

        base_model (str):
            Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
            the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.

        model_args (Coqpit):
            Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.

        data_dep_init_steps (int):
            Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
            Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
            for the rest. Defaults to 10.

        speakers_file (str):
            Path to the file containing the list of speakers. Needed at inference for loading matching speaker ids to
            speaker names. Defaults to `None`.

        use_speaker_embedding (bool):
            Enable/disable using speaker embeddings for multi-speaker models. If set to True, the model is
            in the multi-speaker mode. Defaults to False.

        use_d_vector_file (bool):
            Enable/disable using external speaker embeddings in place of the learned embeddings. Defaults to False.

        d_vector_file (str):
            Path to the file including pre-computed speaker embeddings. Defaults to None.

        d_vector_dim (int):
            Dimension of the external speaker embeddings. Defaults to 0.

        optimizer (str):
            Name of the model optimizer. Defaults to `Adam`.

        optimizer_params (dict):
            Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.

        lr_scheduler (str):
            Name of the learning rate scheduler. Defaults to `NoamLR`.

        lr_scheduler_params (dict):
            Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.

        lr (float):
            Initial learning rate. Defaults to `1e-4`.

        grad_clip (float):
            Gradient norm clipping value. Defaults to `5.0`.

        spec_loss_type (str):
            Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.

        duration_loss_type (str):
            Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.

        use_ssim_loss (bool):
            Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.

        wd (float):
            Weight decay coefficient. Defaults to `1e-7`.

        ssim_loss_alpha (float):
            Weight for the SSIM loss. If set to 0, disables the SSIM loss. Defaults to 1.0.

        dur_loss_alpha (float):
            Weight for the duration predictor's loss. If set to 0, disables the duration loss. Defaults to 1.0.

        spec_loss_alpha (float):
            Weight for the L1 spectrogram loss. If set to 0, disables the L1 loss. Defaults to 1.0.

        pitch_loss_alpha (float):
            Weight for the pitch predictor's loss. If set to 0, disables the pitch predictor. Defaults to 1.0.

        energy_loss_alpha (float):
            Weight for the energy predictor's loss. If set to 0, disables the energy predictor. Defaults to 1.0.

        binary_align_loss_alpha (float):
            Weight for the binary loss. If set to 0, disables the binary loss. Defaults to 1.0.

        binary_loss_warmup_epochs (int):
            Number of epochs to gradually increase the binary loss impact. Defaults to 150.

        min_seq_len (int):
            Minimum input sequence length to be used at training.

        max_seq_len (int):
            Maximum input sequence length to be used at training. Larger values result in more VRAM usage.

        # dataset configs
        compute_f0 (bool):
            Compute pitch. Defaults to True.

        f0_cache_path (str):
            Pitch cache path. Defaults to None.

        compute_energy (bool):
            Compute energy. Defaults to True.

        energy_cache_path (str):
            Energy cache path. Defaults to None.
    """

    model: str = "fastspeech2"
    base_model: str = "forward_tts"

    # model specific params
    model_args: ForwardTTSArgs = ForwardTTSArgs()

    # multi-speaker settings
    num_speakers: int = 0
    speakers_file: str = None
    use_speaker_embedding: bool = False
    use_d_vector_file: bool = False
    d_vector_file: str = None
    d_vector_dim: int = 0

    # optimizer parameters
    optimizer: str = "Adam"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
    lr_scheduler: str = "NoamLR"
    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
    lr: float = 1e-4
    grad_clip: float = 5.0

    # loss params
    spec_loss_type: str = "mse"
    duration_loss_type: str = "mse"
    use_ssim_loss: bool = True
    ssim_loss_alpha: float = 1.0
    spec_loss_alpha: float = 1.0
    aligner_loss_alpha: float = 1.0
    pitch_loss_alpha: float = 0.1
    energy_loss_alpha: float = 0.1
    dur_loss_alpha: float = 0.1
    binary_align_loss_alpha: float = 0.1
    binary_loss_warmup_epochs: int = 150

    # overrides
    min_seq_len: int = 13
    max_seq_len: int = 200
    r: int = 1  # DO NOT CHANGE

    # dataset configs
    compute_f0: bool = True
    f0_cache_path: str = None
    compute_energy: bool = True
    energy_cache_path: str = None

    # testing
    test_sentences: List[str] = field(
        default_factory=lambda: [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist.",
            "Prior to November 22, 1963.",
        ]
    )

    def __post_init__(self):
        # Pass multi-speaker parameters to the model args as `model.init_multispeaker()` looks for it there.
        if self.num_speakers > 0:
            self.model_args.num_speakers = self.num_speakers

        # speaker embedding settings
        if self.use_speaker_embedding:
            self.model_args.use_speaker_embedding = True
        if self.speakers_file:
            self.model_args.speakers_file = self.speakers_file

        # d-vector settings
        if self.use_d_vector_file:
            self.model_args.use_d_vector_file = True
        if self.d_vector_dim is not None and self.d_vector_dim > 0:
            self.model_args.d_vector_dim = self.d_vector_dim
        if self.d_vector_file:
            self.model_args.d_vector_file = self.d_vector_file
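For context, a minimal usage sketch of the new config (not part of the commit; the cache directory names are placeholders). It enables the energy branch through `ForwardTTSArgs.use_energy`, which defaults to False as shown in the forward_tts.py hunks further below:

# Minimal usage sketch (illustrative only; cache paths are placeholders).
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
from TTS.tts.models.forward_tts import ForwardTTSArgs

config = Fastspeech2Config(
    model_args=ForwardTTSArgs(use_pitch=True, use_energy=True),  # turn the energy predictor on
    compute_f0=True,
    f0_cache_path="f0_cache",          # placeholder directory
    compute_energy=True,
    energy_cache_path="energy_cache",  # placeholder directory
)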
@@ -217,6 +217,9 @@ class BaseTTSConfig(BaseTrainingConfig):
        compute_f0 (bool):
            (Not in use yet).

        compute_energy (bool):
            (Not in use yet).

        compute_linear_spec (bool):
            If True data loader computes and returns linear spectrograms alongside the other data.
@@ -312,6 +315,7 @@ class BaseTTSConfig(BaseTrainingConfig):
    min_text_len: int = 1
    max_text_len: int = float("inf")
    compute_f0: bool = False
    compute_energy: bool = False
    compute_linear_spec: bool = False
    precompute_num_workers: int = 0
    use_noise_augment: bool = False
@@ -11,6 +11,7 @@ from torch.utils.data import Dataset

from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import compute_energy as calculate_energy

# to prevent too many open files error as suggested here
# https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936
@@ -50,7 +51,9 @@ class TTSDataset(Dataset):
        samples: List[Dict] = None,
        tokenizer: "TTSTokenizer" = None,
        compute_f0: bool = False,
        compute_energy: bool = False,
        f0_cache_path: str = None,
        energy_cache_path: str = None,
        return_wav: bool = False,
        batch_group_size: int = 0,
        min_text_len: int = 0,
@@ -84,8 +87,12 @@ class TTSDataset(Dataset):

            compute_f0 (bool): compute f0 if True. Defaults to False.

            compute_energy (bool): compute energy if True. Defaults to False.

            f0_cache_path (str): Path to store f0 cache. Defaults to None.

            energy_cache_path (str): Path to store energy cache. Defaults to None.

            return_wav (bool): Return the waveform of the sample. Defaults to False.

            batch_group_size (int): Range of batch randomization after sorting
@@ -128,7 +135,9 @@ class TTSDataset(Dataset):
        self.compute_linear_spec = compute_linear_spec
        self.return_wav = return_wav
        self.compute_f0 = compute_f0
        self.compute_energy = compute_energy
        self.f0_cache_path = f0_cache_path
        self.energy_cache_path = energy_cache_path
        self.min_audio_len = min_audio_len
        self.max_audio_len = max_audio_len
        self.min_text_len = min_text_len
@@ -155,7 +164,10 @@ class TTSDataset(Dataset):
            self.f0_dataset = F0Dataset(
                self.samples, self.ap, cache_path=f0_cache_path, precompute_num_workers=precompute_num_workers
            )
        if compute_energy:
            self.energy_dataset = EnergyDataset(
                self.samples, self.ap, cache_path=energy_cache_path, precompute_num_workers=precompute_num_workers
            )
        if self.verbose:
            self.print_logs()
@@ -211,6 +223,12 @@ class TTSDataset(Dataset):
        assert item["audio_unique_name"] == out_dict["audio_unique_name"]
        return out_dict

    def get_energy(self, idx):
        out_dict = self.energy_dataset[idx]
        item = self.samples[idx]
        assert item["audio_unique_name"] == out_dict["audio_unique_name"]
        return out_dict

    @staticmethod
    def get_attn_mask(attn_file):
        return np.load(attn_file)
@@ -252,12 +270,16 @@ class TTSDataset(Dataset):
        f0 = None
        if self.compute_f0:
            f0 = self.get_f0(idx)["f0"]
        energy = None
        if self.compute_energy:
            energy = self.get_energy(idx)["energy"]

        sample = {
            "raw_text": raw_text,
            "token_ids": token_ids,
            "wav": wav,
            "pitch": f0,
            "energy": energy,
            "attn": attn,
            "item_idx": item["audio_file"],
            "speaker_name": item["speaker_name"],
@@ -490,7 +512,13 @@ class TTSDataset(Dataset):
            pitch = torch.FloatTensor(pitch)[:, None, :].contiguous()  # B x 1 x T
        else:
            pitch = None

        # format energy
        if self.compute_energy:
            energy = prepare_data(batch["energy"])
            assert mel.shape[1] == energy.shape[1], f"[!] {mel.shape} vs {energy.shape}"
            energy = torch.FloatTensor(energy)[:, None, :].contiguous()  # B x 1 x T
        else:
            energy = None
        # format attention masks
        attns = None
        if batch["attn"][0] is not None:
@@ -519,6 +547,7 @@ class TTSDataset(Dataset):
            "waveform": wav_padded,
            "raw_text": batch["raw_text"],
            "pitch": pitch,
            "energy": energy,
            "language_ids": language_ids,
            "audio_unique_names": batch["audio_unique_name"],
        }
@@ -777,3 +806,155 @@ class F0Dataset:
        print("\n")
        print(f"{indent}> F0Dataset ")
        print(f"{indent}| > Number of instances : {len(self.samples)}")


class EnergyDataset:
    """EnergyDataset for computing energy from wav files on CPU.

    Pre-compute energy values for all the samples at initialization if `cache_path` is not None or already present. It
    also computes the mean and std of energy values if `normalize_energy` is True.

    Args:
        samples (Union[List[List], List[Dict]]):
            List of samples. Each sample is a list or a dict.

        ap (AudioProcessor):
            AudioProcessor to compute energy from wav files.

        cache_path (str):
            Path to cache energy values. If `cache_path` is already present or None, it skips the pre-computation.
            Defaults to None.

        precompute_num_workers (int):
            Number of workers used for pre-computing the energy values. Defaults to 0.

        normalize_energy (bool):
            Whether to normalize energy values by mean and std. Defaults to True.
    """

    def __init__(
        self,
        samples: Union[List[List], List[Dict]],
        ap: "AudioProcessor",
        verbose=False,
        cache_path: str = None,
        precompute_num_workers=0,
        normalize_energy=True,
    ):
        self.samples = samples
        self.ap = ap
        self.verbose = verbose
        self.cache_path = cache_path
        self.normalize_energy = normalize_energy
        self.pad_id = 0.0
        self.mean = None
        self.std = None
        if cache_path is not None and not os.path.exists(cache_path):
            os.makedirs(cache_path)
            self.precompute(precompute_num_workers)
        if normalize_energy:
            self.load_stats(cache_path)

    def __getitem__(self, idx):
        item = self.samples[idx]
        energy = self.compute_or_load(item["audio_file"])
        if self.normalize_energy:
            assert self.mean is not None and self.std is not None, " [!] Mean and STD are not available"
            energy = self.normalize(energy)
        return {"audio_file": item["audio_file"], "energy": energy}

    def __len__(self):
        return len(self.samples)

    def precompute(self, num_workers=0):
        print("[*] Pre-computing energies...")
        with tqdm.tqdm(total=len(self)) as pbar:
            batch_size = num_workers if num_workers > 0 else 1
            # we do not normalize at preprocessing
            normalize_energy = self.normalize_energy
            self.normalize_energy = False
            dataloader = torch.utils.data.DataLoader(
                batch_size=batch_size, dataset=self, shuffle=False, num_workers=num_workers, collate_fn=self.collate_fn
            )
            computed_data = []
            for batch in dataloader:
                energy = batch["energy"]
                computed_data.append([e for e in energy])
                pbar.update(batch_size)
            self.normalize_energy = normalize_energy

        if self.normalize_energy:
            computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
            energy_mean, energy_std = self.compute_energy_stats(computed_data)
            energy_stats = {"mean": energy_mean, "std": energy_std}
            np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)

    def get_pad_id(self):
        return self.pad_id

    @staticmethod
    def create_energy_file_path(wav_file, cache_path):
        file_name = os.path.splitext(os.path.basename(wav_file))[0]
        energy_file = os.path.join(cache_path, file_name + "_energy.npy")
        return energy_file

    @staticmethod
    def _compute_and_save_energy(ap, wav_file, energy_file=None):
        wav = ap.load_wav(wav_file)
        energy = calculate_energy(wav)
        if energy_file:
            np.save(energy_file, energy)
        return energy

    @staticmethod
    def compute_energy_stats(energy_vecs):
        nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs])
        mean, std = np.mean(nonzeros), np.std(nonzeros)
        return mean, std

    def load_stats(self, cache_path):
        stats_path = os.path.join(cache_path, "energy_stats.npy")
        stats = np.load(stats_path, allow_pickle=True).item()
        self.mean = stats["mean"].astype(np.float32)
        self.std = stats["std"].astype(np.float32)

    def normalize(self, energy):
        zero_idxs = np.where(energy == 0.0)[0]
        energy = energy - self.mean
        energy = energy / self.std
        energy[zero_idxs] = 0.0
        return energy

    def denormalize(self, energy):
        zero_idxs = np.where(energy == 0.0)[0]
        energy *= self.std
        energy += self.mean
        energy[zero_idxs] = 0.0
        return energy

    def compute_or_load(self, wav_file):
        """
        Compute energy and return a numpy array of energy values.
        """
        energy_file = self.create_energy_file_path(wav_file, self.cache_path)
        if not os.path.exists(energy_file):
            energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
        else:
            energy = np.load(energy_file)
        return energy.astype(np.float32)

    def collate_fn(self, batch):
        audio_file = [item["audio_file"] for item in batch]
        energys = [item["energy"] for item in batch]
        energy_lens = [len(item["energy"]) for item in batch]
        energy_lens_max = max(energy_lens)
        # energy values are floats; pad with `pad_id` (0.0) into a float tensor
        energys_torch = torch.FloatTensor(len(energys), energy_lens_max).fill_(self.get_pad_id())
        for i, energy_len in enumerate(energy_lens):
            energys_torch[i, :energy_len] = torch.FloatTensor(energys[i])
        return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens}

    def print_logs(self, level: int = 0) -> None:
        indent = "\t" * level
        print("\n")
        print(f"{indent}> EnergyDataset ")
        print(f"{indent}| > Number of instances : {len(self.samples)}")
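Assuming samples in the format returned by `load_tts_samples`, a rough sketch of driving `EnergyDataset` on its own to warm the cache (the sample values and audio settings below are illustrative, not from the commit):

# Rough standalone sketch (illustrative only).
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(sample_rate=22050)  # placeholder audio settings
samples = [{"audio_file": "wavs/0001.wav", "audio_unique_name": "0001"}]  # placeholder sample
energy_dataset = EnergyDataset(samples, ap, cache_path="energy_cache", precompute_num_workers=0)
item = energy_dataset[0]  # {"audio_file": ..., "energy": np.ndarray of shape (T_frames,)}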
@@ -801,6 +801,10 @@ class ForwardTTSLoss(nn.Module):
            self.pitch_loss = MSELossMasked(False)
            self.pitch_loss_alpha = c.pitch_loss_alpha

        if c.model_args.use_energy:
            self.energy_loss = MSELossMasked(False)
            self.energy_loss_alpha = c.energy_loss_alpha

        if c.use_ssim_loss:
            self.ssim = SSIMLoss() if c.use_ssim_loss else None
            self.ssim_loss_alpha = c.ssim_loss_alpha
@@ -826,6 +830,8 @@ class ForwardTTSLoss(nn.Module):
        dur_target,
        pitch_output,
        pitch_target,
        energy_output,
        energy_target,
        input_lens,
        alignment_logprob=None,
        alignment_hard=None,
@@ -855,6 +861,11 @@ class ForwardTTSLoss(nn.Module):
            loss = loss + self.pitch_loss_alpha * pitch_loss
            return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss

        if hasattr(self, "energy_loss") and self.energy_loss_alpha > 0:
            energy_loss = self.energy_loss(energy_output.transpose(1, 2), energy_target.transpose(1, 2), input_lens)
            loss = loss + self.energy_loss_alpha * energy_loss
            return_dict["loss_energy"] = self.energy_loss_alpha * energy_loss

        if hasattr(self, "aligner_loss") and self.aligner_loss_alpha > 0:
            aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
            loss = loss + self.aligner_loss_alpha * aligner_loss
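The energy term mirrors the pitch term: a length-masked MSE over per-character averages, transposed to `[B, T, 1]` before the loss. A hedged standalone illustration of the same masking idea (this is not the `MSELossMasked` implementation itself, just the concept):

# Standalone illustration of a length-masked MSE, as applied to the energy term.
import torch

def masked_mse(pred, target, lens):
    # pred/target: [B, T, 1]; lens: [B], valid lengths per sample
    mask = torch.arange(pred.shape[1])[None, :] < lens[:, None]  # [B, T] validity mask
    mask = mask.unsqueeze(-1).float()
    return ((pred - target) ** 2 * mask).sum() / mask.sum()  # average over valid steps only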
@@ -15,7 +15,7 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
-from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
+from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec
@@ -42,6 +42,9 @@ class ForwardTTSArgs(Coqpit):
        use_pitch (bool):
            Use pitch predictor to learn the pitch. Defaults to True.

        use_energy (bool):
            Use energy predictor to learn the energy. Defaults to False.

        duration_predictor_hidden_channels (int):
            Number of hidden channels in the duration predictor. Defaults to 256.
@@ -63,6 +66,18 @@ class ForwardTTSArgs(Coqpit):
        pitch_embedding_kernel_size (int):
            Kernel size of the projection layer in the pitch predictor. Defaults to 3.

        energy_predictor_hidden_channels (int):
            Number of hidden channels in the energy predictor. Defaults to 256.

        energy_predictor_dropout_p (float):
            Dropout rate for the energy predictor. Defaults to 0.1.

        energy_predictor_kernel_size (int):
            Kernel size of conv layers in the energy predictor. Defaults to 3.

        energy_embedding_kernel_size (int):
            Kernel size of the projection layer in the energy predictor. Defaults to 3.

        positional_encoding (bool):
            Whether to use positional encoding. Defaults to True.
@@ -114,14 +129,25 @@ class ForwardTTSArgs(Coqpit):
    out_channels: int = 80
    hidden_channels: int = 384
    use_aligner: bool = True
    # pitch params
    use_pitch: bool = True
    pitch_predictor_hidden_channels: int = 256
    pitch_predictor_kernel_size: int = 3
    pitch_predictor_dropout_p: float = 0.1
    pitch_embedding_kernel_size: int = 3

    # energy params
    use_energy: bool = False
    energy_predictor_hidden_channels: int = 256
    energy_predictor_kernel_size: int = 3
    energy_predictor_dropout_p: float = 0.1
    energy_embedding_kernel_size: int = 3

    # duration params
    duration_predictor_hidden_channels: int = 256
    duration_predictor_kernel_size: int = 3
    duration_predictor_dropout_p: float = 0.1

    positional_encoding: bool = True
    poisitonal_encoding_use_scale: bool = True
    length_scale: int = 1
@@ -158,7 +184,7 @@ class ForwardTTS(BaseTTS):
        - FastPitch
        - SpeedySpeech
        - FastSpeech
-        - TODO: FastSpeech2 (requires average speech energy predictor)
+        - FastSpeech2 (requires average speech energy predictor)

    Args:
        config (Coqpit): Model coqpit class.
@@ -187,6 +213,7 @@ class ForwardTTS(BaseTTS):
        self.max_duration = self.args.max_duration
        self.use_aligner = self.args.use_aligner
        self.use_pitch = self.args.use_pitch
        self.use_energy = self.args.use_energy
        self.binary_loss_weight = 0.0

        self.length_scale = (
@@ -234,6 +261,20 @@ class ForwardTTS(BaseTTS):
                padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
            )

        if self.args.use_energy:
            self.energy_predictor = DurationPredictor(
                self.args.hidden_channels + self.embedded_speaker_dim,
                self.args.energy_predictor_hidden_channels,
                self.args.energy_predictor_kernel_size,
                self.args.energy_predictor_dropout_p,
            )
            self.energy_emb = nn.Conv1d(
                1,
                self.args.hidden_channels,
                kernel_size=self.args.energy_embedding_kernel_size,
                padding=int((self.args.energy_embedding_kernel_size - 1) / 2),
            )

        if self.args.use_aligner:
            self.aligner = AlignmentNetwork(
                in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
@@ -440,6 +481,42 @@ class ForwardTTS(BaseTTS):
        o_pitch_emb = self.pitch_emb(o_pitch)
        return o_pitch_emb, o_pitch

    def _forward_energy_predictor(
        self,
        o_en: torch.FloatTensor,
        x_mask: torch.IntTensor,
        energy: torch.FloatTensor = None,
        dr: torch.IntTensor = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Energy predictor forward pass.

        1. Predict energy from encoder outputs.
        2. In training - Compute average energy values for each input character from the ground truth energy values.
        3. Embed average energy values.

        Args:
            o_en (torch.FloatTensor): Encoder output.
            x_mask (torch.IntTensor): Input sequence mask.
            energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None.
            dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.

        Returns:
            Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction.

        Shapes:
            - o_en: :math:`(B, C, T_{en})`
            - x_mask: :math:`(B, 1, T_{en})`
            - energy: :math:`(B, 1, T_{de})`
            - dr: :math:`(B, T_{en})`
        """
        o_energy = self.energy_predictor(o_en, x_mask)
        if energy is not None:
            # training: average the ground-truth frame-level energy over durations, then embed it
            avg_energy = average_over_durations(energy, dr)
            o_energy_emb = self.energy_emb(avg_energy)
            return o_energy_emb, o_energy, avg_energy
        # inference: embed the predicted energy
        o_energy_emb = self.energy_emb(o_energy)
        return o_energy_emb, o_energy

    def _forward_aligner(
        self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
    ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
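In training, the ground-truth frame-level energy is collapsed to one value per input token via `average_over_durations`. A toy illustration of that reduction in plain PyTorch (mimicking the helper's contract, not its implementation):

# Toy illustration: average frame-level values over per-token durations.
import torch

values = torch.tensor([[[1.0, 2.0, 3.0, 4.0, 5.0]]])  # [B=1, 1, T_de=5] frame-level energy
durs = torch.tensor([[2, 3]])                          # [B=1, T_en=2] frames per input token
ends = durs.cumsum(dim=1)
starts = ends - durs
avg = torch.stack([values[0, 0, s:e].mean() for s, e in zip(starts[0], ends[0])])
print(avg)  # tensor([1.5000, 4.0000]) -> one averaged energy value per input token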
@@ -502,6 +579,7 @@ class ForwardTTS(BaseTTS):
        y: torch.FloatTensor = None,
        dr: torch.IntTensor = None,
        pitch: torch.FloatTensor = None,
        energy: torch.FloatTensor = None,
        aux_input: Dict = {"d_vectors": None, "speaker_ids": None},  # pylint: disable=unused-argument
    ) -> Dict:
        """Model's forward pass.
@@ -513,6 +591,7 @@ class ForwardTTS(BaseTTS):
            y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
            dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
            pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
            energy (torch.FloatTensor): Energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None.
            aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.

        Shapes:
@@ -556,6 +635,12 @@ class ForwardTTS(BaseTTS):
        if self.args.use_pitch:
            o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
            o_en = o_en + o_pitch_emb
        # energy predictor pass
        o_energy = None
        avg_energy = None
        if self.args.use_energy:
            o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr)
            o_en = o_en + o_energy_emb
        # decoder pass
        o_de, attn = self._forward_decoder(
            o_en, dr, x_mask, y_lengths, g=None
@@ -567,6 +652,8 @@ class ForwardTTS(BaseTTS):
            "attn_durations": o_attn,  # for visualization [B, T_en, T_de']
            "pitch_avg": o_pitch,
            "pitch_avg_gt": avg_pitch,
            "energy_avg": o_energy,
            "energy_avg_gt": avg_energy,
            "alignments": attn,  # [B, T_de, T_en]
            "alignment_soft": alignment_soft,
            "alignment_mas": alignment_mas,
@@ -604,12 +691,18 @@ class ForwardTTS(BaseTTS):
        if self.args.use_pitch:
            o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
            o_en = o_en + o_pitch_emb
        # energy predictor pass
        o_energy = None
        if self.args.use_energy:
            o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask)
            o_en = o_en + o_energy_emb
        # decoder pass
        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
        outputs = {
            "model_outputs": o_de,
            "alignments": attn,
            "pitch": o_pitch,
            "energy": o_energy,
            "durations_log": o_dr_log,
        }
        return outputs
@@ -620,6 +713,7 @@ class ForwardTTS(BaseTTS):
        mel_input = batch["mel_input"]
        mel_lengths = batch["mel_lengths"]
        pitch = batch["pitch"] if self.args.use_pitch else None
        energy = batch["energy"] if self.args.use_energy else None
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        durations = batch["durations"]
@@ -627,7 +721,14 @@ class ForwardTTS(BaseTTS):

        # forward pass
        outputs = self.forward(
-            text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
+            text_input,
+            text_lengths,
+            mel_lengths,
+            y=mel_input,
+            dr=durations,
+            pitch=pitch,
+            energy=energy,
+            aux_input=aux_input,
        )
        # use aligner's output as the duration target
        if self.use_aligner:
@@ -643,6 +744,8 @@ class ForwardTTS(BaseTTS):
            dur_target=durations,
            pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
            pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
            energy_output=outputs["energy_avg"] if self.use_energy else None,
            energy_target=outputs["energy_avg_gt"] if self.use_energy else None,
            input_lens=text_lengths,
            alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
            alignment_soft=outputs["alignment_soft"],
@@ -683,6 +786,17 @@ class ForwardTTS(BaseTTS):
            }
            figures.update(pitch_figures)

        # plot energy figures
        if self.args.use_energy:
            energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy())
            energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy())
            chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
            energy_figures = {
                "energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False),
                "energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False),
            }
            figures.update(energy_figures)

        # plot the attention mask computed from the predicted durations
        if "attn_durations" in outputs:
            alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
@@ -123,6 +123,39 @@ def plot_avg_pitch(pitch, chars, fig_size=(30, 10), output_fig=False):
    return fig


def plot_avg_energy(energy, chars, fig_size=(30, 10), output_fig=False):
    """Plot energy curves on top of the input characters.

    Args:
        energy (np.array): energy values.
        chars (str): Characters to place to the x-axis.

    Shapes:
        energy: :math:`(T,)`
    """
    old_fig_size = plt.rcParams["figure.figsize"]
    if fig_size is not None:
        plt.rcParams["figure.figsize"] = fig_size

    fig, ax = plt.subplots()

    x = np.array(range(len(chars)))
    my_xticks = chars
    plt.xticks(x, my_xticks)

    ax.set_xlabel("characters")
    ax.set_ylabel("freq")

    ax2 = ax.twinx()
    ax2.plot(energy, linewidth=5.0, color="red")
    ax2.set_ylabel("energy")

    plt.rcParams["figure.figsize"] = old_fig_size
    if not output_fig:
        plt.close()
    return fig


def visualize(
    alignment,
    postnet_output,
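A quick, hedged usage sketch for the new plot helper with synthetic values (not from the commit; assumes a matplotlib backend that can render off-screen):

# Quick sketch with synthetic values.
import numpy as np

chars = "hello"
energy = np.random.rand(len(chars))  # one averaged energy value per character
fig = plot_avg_energy(energy, chars, output_fig=True)
fig.savefig("avg_energy.png")  # placeholder output path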
@@ -4,7 +4,7 @@ import librosa
import numpy as np
import scipy
import soundfile as sf
-from librosa import pyin
+from librosa import magphase, pyin

# For using kwargs
# pylint: disable=unused-argument
@@ -303,6 +303,27 @@ def compute_f0(
    return f0


def compute_energy(y: np.ndarray, **kwargs) -> np.ndarray:
    """Compute energy of a waveform using the same parameters used for computing melspectrogram.
    Args:
        y (np.ndarray): Waveform. Shape :math:`[T_wav,]`
    Returns:
        np.ndarray: energy. Shape :math:`[T_energy,]`. :math:`T_energy == T_wav / hop_length`
    Examples:
        >>> WAV_FILE = filename = librosa.util.example_audio_file()
        >>> from TTS.config import BaseAudioConfig
        >>> from TTS.utils.audio import AudioProcessor
        >>> conf = BaseAudioConfig()
        >>> ap = AudioProcessor(**conf)
        >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate]
        >>> energy = compute_energy(wav)
    """
    x = stft(y=y, **kwargs)
    mag, _ = magphase(x)
    energy = np.sqrt(np.sum(mag**2, axis=0))
    return energy


### Audio Processing ###
def find_endpoint(
    *,
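Since the new `compute_energy` is just the per-frame L2 norm of the STFT magnitude, it can be sanity-checked against librosa directly. A hedged sketch using librosa's default STFT parameters (the example clip is fetched by librosa and is not part of the commit):

# Sanity-check sketch: energy as the per-frame L2 norm of the STFT magnitude.
import librosa
import numpy as np

y, sr = librosa.load(librosa.ex("trumpet"), sr=22050)  # fetches a bundled example clip
mag = np.abs(librosa.stft(y))              # [n_fft/2 + 1, T_frames], default n_fft=2048, hop=512
energy = np.sqrt(np.sum(mag**2, axis=0))   # [T_frames,]
assert energy.shape[0] == mag.shape[1]     # one energy value per STFT frame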
@@ -0,0 +1,102 @@
import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig, BaseDatasetConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.forward_tts import ForwardTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.utils.manage import ModelManager

output_path = os.path.dirname(os.path.abspath(__file__))

# init configs
dataset_config = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
    path=os.path.join(output_path, "../LJSpeech-1.1/"),
)

audio_config = BaseAudioConfig(
    sample_rate=22050,
    do_trim_silence=True,
    trim_db=60.0,
    signal_norm=False,
    mel_fmin=0.0,
    mel_fmax=8000,
    spec_gain=1.0,
    log_func="np.log",
    ref_level_db=20,
    preemphasis=0.0,
)

config = Fastspeech2Config(
    run_name="fastspeech2_ljspeech",
    audio=audio_config,
    batch_size=32,
    eval_batch_size=16,
    num_loader_workers=8,
    num_eval_loader_workers=4,
    compute_input_seq_cache=True,
    compute_f0=True,
    f0_cache_path=os.path.join(output_path, "f0_cache"),
    compute_energy=True,
    energy_cache_path=os.path.join(output_path, "energy_cache"),
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    precompute_num_workers=4,
    print_step=50,
    print_eval=False,
    mixed_precision=False,
    max_seq_len=500000,
    output_path=output_path,
    datasets=[dataset_config],
)

# compute alignments
if not config.model_args.use_aligner:
    manager = ModelManager()
    model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
    # TODO: make compute_attention python callable
    os.system(
        f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
    )

# INITIALIZE THE AUDIO PROCESSOR
# Audio processor is used for feature extraction and audio I/O.
# It mainly serves the dataloader and the training loggers.
ap = AudioProcessor.init_from_config(config)

# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
# If characters are not defined in the config, default characters are passed to the config.
tokenizer, config = TTSTokenizer.init_from_config(config)

# LOAD DATA SAMPLES
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
# You can define your custom sample loader returning the list of samples.
# Or define your custom formatter and pass it to the `load_tts_samples`.
# Check `TTS.tts.datasets.load_tts_samples` for more details.
train_samples, eval_samples = load_tts_samples(
    dataset_config,
    eval_split=True,
    eval_split_max_size=config.eval_split_max_size,
    eval_split_size=config.eval_split_size,
)

# init the model
model = ForwardTTS(config, ap, tokenizer, speaker_manager=None)

# init the trainer and 🚀
trainer = Trainer(
    TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
)
trainer.fit()
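Assuming the recipe is saved under recipes/ljspeech/fastspeech2/ (the path is an assumption, not stated in the diff), it runs like the other LJSpeech recipes: CUDA_VISIBLE_DEVICES=0 python recipes/ljspeech/fastspeech2/train_fastspeech2.py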
@@ -0,0 +1,95 @@
import glob
import json
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config

config_path = os.path.join(get_tests_output_path(), "fastspeech2_speaker_emb_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

audio_config = BaseAudioConfig(
    sample_rate=22050,
    do_trim_silence=True,
    trim_db=60.0,
    signal_norm=False,
    mel_fmin=0.0,
    mel_fmax=8000,
    spec_gain=1.0,
    log_func="np.log",
    ref_level_db=20,
    preemphasis=0.0,
)

config = Fastspeech2Config(
    audio=audio_config,
    batch_size=8,
    eval_batch_size=8,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    f0_cache_path="tests/data/ljspeech/f0_cache/",
    compute_f0=True,
    compute_energy=True,
    # energy files use the `_energy.npy` suffix, so sharing the f0 cache directory is safe
    energy_cache_path="tests/data/ljspeech/f0_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    use_speaker_embedding=True,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
)
config.audio.do_trim_silence = True
config.use_speaker_embedding = True
config.model_args.use_speaker_embedding = True
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech_test "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)
run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")
speaker_id = "ljspeech-1"
continue_speakers_path = os.path.join(continue_path, "speakers.json")

# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
    config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0

# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --speaker_idx {speaker_id} --speakers_file_path {continue_speakers_path} --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)
@@ -0,0 +1,94 @@
import glob
import json
import os
import shutil

from trainer import get_last_checkpoint

from tests import get_device_id, get_tests_output_path, run_cli
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")

audio_config = BaseAudioConfig(
    sample_rate=22050,
    do_trim_silence=True,
    trim_db=60.0,
    signal_norm=False,
    mel_fmin=0.0,
    mel_fmax=8000,
    spec_gain=1.0,
    log_func="np.log",
    ref_level_db=20,
    preemphasis=0.0,
)

config = Fastspeech2Config(
    audio=audio_config,
    batch_size=8,
    eval_batch_size=8,
    num_loader_workers=0,
    num_eval_loader_workers=0,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path="tests/data/ljspeech/phoneme_cache/",
    f0_cache_path="tests/data/ljspeech/f0_cache/",
    compute_f0=True,
    compute_energy=True,
    energy_cache_path="tests/data/ljspeech/f0_cache/",
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1,
    print_step=1,
    print_eval=True,
    test_sentences=[
        "Be a voice, not an echo.",
    ],
    use_speaker_embedding=False,
)
config.audio.do_trim_silence = True
config.use_speaker_embedding = False
config.model_args.use_speaker_embedding = False
config.audio.trim_db = 60
config.save_json(config_path)

# train the model for one epoch
command_train = (
    f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --config_path {config_path} "
    f"--coqpit.output_path {output_path} "
    "--coqpit.datasets.0.formatter ljspeech "
    "--coqpit.datasets.0.meta_file_train metadata.csv "
    "--coqpit.datasets.0.meta_file_val metadata.csv "
    "--coqpit.datasets.0.path tests/data/ljspeech "
    "--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt "
    "--coqpit.test_delay_epochs 0"
)

run_cli(command_train)

# Find latest folder
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)

# Inference using TTS API
continue_config_path = os.path.join(continue_path, "config.json")
continue_restore_path, _ = get_last_checkpoint(continue_path)
out_wav_path = os.path.join(get_tests_output_path(), "output.wav")

# Check integrity of the config
with open(continue_config_path, "r", encoding="utf-8") as f:
    config_loaded = json.load(f)
assert config_loaded["characters"] is not None
assert config_loaded["output_path"] in continue_path
assert config_loaded["test_delay_epochs"] == 0

# Load the model and run inference
inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}"
run_cli(inference_command)

# restore the model and continue training for one more epoch
command_train = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' python TTS/bin/train_tts.py --continue_path {continue_path} "
run_cli(command_train)
shutil.rmtree(continue_path)