mirror of https://github.com/coqui-ai/TTS.git
Add test sentences during the training
commit c4ceaabe2c (parent 2f868dd5c2)
TTS/tts/layers/xtts/trainer/gpt_trainer.py:

@@ -12,7 +12,8 @@ import sys
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
 from TTS.tts.layers.xtts.gpt import GPT
-from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig
+from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig, Xtts
+from TTS.tts.configs.xtts_config import XttsConfig

 from TTS.tts.models.base_tts import BaseTTS
 from coqpit import Coqpit
@@ -25,20 +26,21 @@ from TTS.tts.datasets.dataset import TTSDataset
 from trainer.torch import DistributedSampler
 from trainer.trainer_utils import get_optimizer, get_scheduler


 from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
 from TTS.utils.io import load_fsspec

 from TTS.tts.layers.xtts.dvae import DiscreteVAE

+from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder


 @dataclass
-class GPTConfig(TortoiseConfig):
+class GPTTrainerConfig(XttsConfig):
     lr: float = 5e-06
     training_seed: int = 1
     optimizer_wd_only_on_weights: bool = False
     weighted_loss_attrs: dict = field(default_factory=lambda: {})
     weighted_loss_multipliers: dict = field(default_factory=lambda: {})
+    test_sentences: List[dict] = field(default_factory=lambda: [])


 @dataclass
 class XttsAudioConfig(XttsAudioConfig):
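
For reference, each entry in the new test_sentences field is a plain dict whose keys match the lookups done in GPTTrainer.test_run further down in this diff. A minimal sketch of the expected shape (the wav path is a placeholder, not from the commit):

    # Sketch: keys are read as s_info["text"], s_info["speaker_wav"], and
    # s_info["language"] in GPTTrainer.test_run; the path below is a placeholder.
    test_sentences = [
        {
            "text": "A sentence to synthesize at every evaluation pass.",
            "speaker_wav": "/path/to/reference.wav",  # reference clip for voice cloning
            "language": "en",                         # language id passed to Xtts.synthesize
        },
    ]
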
@@ -58,7 +60,8 @@ class GPTArgs(XttsArgs):
     tokenizer_file: str = ""
     mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
     dvae_checkpoint: str = ""
-    gpt_checkpoint: str = ""
+    xtts_checkpoint: str = ""
+    gpt_checkpoint: str = ""  # if defined it will replace the gpt weights on xtts model
     vocoder: str = ""  # overide vocoder key on the config to avoid json write issues

@@ -80,28 +83,18 @@ class GPTTrainer(BaseTTS):
         """
         super().__init__(config, ap=None, tokenizer=None)
         self.config = config
+        # init XTTS model
+        self.xtts = Xtts(self.config)
+        # create the tokenizer with the target vocabulary
+        self.xtts.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
+        # init gpt encoder and hifigan decoder
+        self.xtts.init_models()
+        # set mel stats
+        if self.args.mel_norm_file:
+            self.xtts.mel_stats = load_fsspec(self.args.mel_norm_file)

-        self.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)
-        self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
-        self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
-        self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
-
-        self.gpt = GPT(
-            layers=self.args.gpt_layers,
-            model_dim=self.args.gpt_n_model_channels,
-            start_text_token=self.args.gpt_start_text_token,
-            stop_text_token=self.args.gpt_stop_text_token,
-            heads=self.args.gpt_n_heads,
-            max_text_tokens=self.args.gpt_max_text_tokens,
-            max_mel_tokens=self.args.gpt_max_audio_tokens,
-            max_prompt_tokens=self.args.gpt_max_prompt_tokens,
-            number_text_tokens=self.args.gpt_number_text_tokens,
-            num_audio_tokens=self.args.gpt_num_audio_tokens,
-            start_audio_token=self.args.gpt_start_audio_token,
-            stop_audio_token=self.args.gpt_stop_audio_token,
-        ).cuda()
+        if self.args.xtts_checkpoint:
+            self.load_checkpoint(self.config, self.args.xtts_checkpoint, eval=False, strict=False)

         # load GPT if available
         if self.args.gpt_checkpoint:
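
With this refactor the trainer wraps a complete Xtts model instead of a bare GPT, and setting GPTArgs.xtts_checkpoint loads the whole model for fine-tuning from inside __init__. A minimal sketch of that entry point, assuming placeholder paths (GPTArgs, GPTTrainerConfig, and GPTTrainer are the classes defined in this file):

    # Sketch only: paths are placeholders. __init__ builds the Xtts model,
    # attaches the tokenizer, and, because xtts_checkpoint is set, calls
    # self.load_checkpoint(..., eval=False, strict=False).
    model_args = GPTArgs(
        tokenizer_file="/path/to/vocab.json",
        xtts_checkpoint="/path/to/model.pth",
    )
    config = GPTTrainerConfig(model_args=model_args)
    model = GPTTrainer.init_from_config(config)
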
@@ -122,8 +115,8 @@ class GPTTrainer(BaseTTS):
                     del gpt_checkpoint[key]

             # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
-            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.gpt.text_embedding.weight.shape:
-                num_new_tokens = self.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
+            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape:
+                num_new_tokens = self.xtts.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
                 print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

                 # add new tokens to a linear layer (text_head)
@@ -137,7 +130,7 @@ class GPTTrainer(BaseTTS):
                 # add new weights to the linear layer (text_head)
                 text_head_weight = gpt_checkpoint["text_head.weight"]
                 start_token_row = text_head_weight[-1, :]
-                new_entry = torch.randn(num_new_tokens, self.gpt.text_head.weight.shape[1])
+                new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
                 text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
                 text_head_weight[-1, :] = start_token_row
                 gpt_checkpoint["text_head.weight"] = text_head_weight
@@ -150,10 +143,8 @@ class GPTTrainer(BaseTTS):
                 text_head_bias[-1] = start_token_row
                 gpt_checkpoint["text_head.bias"] = text_head_bias

-            self.gpt.load_state_dict(gpt_checkpoint, strict=True)
+            self.xtts.gpt.load_state_dict(gpt_checkpoint, strict=True)
             print(">> GPT weights restored from:", self.args.gpt_checkpoint)
-        else:
-            print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint")

         # Mel spectrogram extractor for conditioning
         self.torch_mel_spectrogram_style_encoder = TorchMelSpectrogram(
@@ -195,6 +186,7 @@ class GPTTrainer(BaseTTS):
         # Mel spectrogram extractor for DVAE
         self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)
+

     @property
     def device(self):
         return next(self.parameters()).device
@@ -211,12 +203,30 @@ class GPTTrainer(BaseTTS):
             cond_mels: MEL float tensor, (b, num_samples, 80,t_m)
             cond_idxs: cond start and end indexs, (b, 2)
         """
-        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
+        losses = self.xtts.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_idxs=cond_idxs)
         return losses

     @torch.no_grad()
     def test_run(self, assets) -> Tuple[Dict, Dict]:  # pylint: disable=W0613
-        return {}, {}
+        if self.config.test_sentences:
+            # init gpt for inference mode
+            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
+            self.xtts.gpt.eval()
+            test_audios = {}
+            print(" | > Synthesizing test sentences.")
+            for idx, s_info in enumerate(self.config.test_sentences):
+                wav = self.xtts.synthesize(s_info["text"], self.config, s_info["speaker_wav"], s_info["language"])["wav"]
+                test_audios["{}-audio".format(idx)] = wav
+
+            # delete inference layers
+            del self.xtts.gpt.gpt_inference
+            del self.xtts.gpt.gpt.wte
+            return {"audios": test_audios}
+
+    def test_log(
+        self, outputs: dict, logger: "Logger", assets: dict, steps: int  # pylint: disable=unused-argument
+    ) -> None:
+        logger.test_audios(steps, outputs["audios"], self.args.output_sample_rate)

     def format_batch(self, batch: Dict) -> Dict:
         return batch
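
Note that test_run builds the inference-mode GPT (with kv-cache) only for the duration of the synthesis loop and then tears the inference layers back down (del gpt_inference, del gpt.wte), so training continues with the original training-mode module. A hedged sketch of how the two new hooks chain together; the Trainer normally drives these calls during evaluation, and logger stands in for its dashboard logger:

    # Illustrative only, not part of the commit.
    outputs = model.test_run(assets={})  # {"audios": {"0-audio": wav, "1-audio": wav, ...}}
    model.test_log(outputs, logger, assets={}, steps=0)  # logged at args.output_sample_rate
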
@@ -323,7 +333,7 @@ class GPTTrainer(BaseTTS):
             loader = None
         else:
             # init dataloader
-            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate, is_eval)
+            dataset = XTTSDataset(self.config, samples, self.xtts.tokenizer, config.audio.sample_rate, is_eval)

             # wait all the DDP process to be ready
             if num_gpus > 1:
@@ -362,7 +372,7 @@ class GPTTrainer(BaseTTS):
         # ToDo: deal with multi GPU training
         if self.config.optimizer_wd_only_on_weights:
             # parameters to only GPT model
-            net = self.gpt
+            net = self.xtts.gpt

             # normalizations
             norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
@@ -410,7 +420,7 @@ class GPTTrainer(BaseTTS):
                 self.config.optimizer_params,
                 self.config.lr,
                 # optimize only for the GPT model
-                parameters=self.gpt.parameters(),
+                parameters=self.xtts.gpt.parameters(),
             )

     def get_scheduler(self, optimizer) -> List:
@@ -432,21 +442,21 @@ class GPTTrainer(BaseTTS):
         target_options={"anon": True},
     ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))["model"]
         # load the model weights
-        self.gpt.load_state_dict(state, strict=strict)
+        self.xtts.load_state_dict(state, strict=strict)

         if eval:
+            self.xtts.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=False)
             self.eval()
-            self.set_inference()
             assert not self.training

     @staticmethod
-    def init_from_config(config: "GPTConfig", samples: Union[List[List], List[Dict]] = None):
+    def init_from_config(config: "GPTTrainerConfig", samples: Union[List[List], List[Dict]] = None):
         """Initiate model from config

         Args:
-            config (GPTConfig): Model config.
+            config (GPTTrainerConfig): Model config.
             samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                 Defaults to None.
         """
TTS/tts/models/xtts.py:

@@ -387,7 +387,7 @@ class Xtts(BaseTTS):
         audio = load_audio(audio_path)
         audio = audio[:, : 22050 * length]
         mel = wav_to_mel_cloning(audio, mel_norms=self.mel_stats.cpu())
-        cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
+        cond_latent = self.gpt.get_style_emb(mel.to(self.device))
         return cond_latent.transpose(1, 2)

     @torch.inference_mode()
XTTS GPT training recipe:

@@ -3,7 +3,7 @@ from trainer import Trainer, TrainerArgs
 from TTS.config.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
-from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTConfig
+from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTTrainerConfig


 config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
@@ -265,21 +265,21 @@ def main():
         debug_loading_failures=False,
         max_wav_length=255995,  # ~11.6 seconds
         max_text_length=200,
-        tokenizer_file="/raid/datasets/xtts_models/vocab.json",
         mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
         dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
-        gpt_checkpoint="/raid/datasets/xtts_models/gpt.pth",
+        tokenizer_file="/raid/datasets/xtts_models/vocab.json",  # vocab path of the model that you want to fine-tune
+        xtts_checkpoint="https://huggingface.co/coqui/XTTS-v1/resolve/hifigan/model.pth",  # checkpoint path of the model that you want to fine-tune
         gpt_num_audio_tokens=8194,
         gpt_start_audio_token=8192,
         gpt_stop_audio_token=8193,
     )
     audio_config = XttsAudioConfig(
-        sample_rate=22050,  # autoregressive SR
+        sample_rate=22050,  # GPT SR
         dvae_sample_rate=22050,
         diffusion_sample_rate=24000,
         output_sample_rate=24000
     )
-    config = GPTConfig(
+    config = GPTTrainerConfig(
         output_path=OUT_PATH,
         model_args=model_args,
         run_name=RUN_NAME,
@@ -313,6 +313,10 @@ def main():
         lr_scheduler="MultiStepLR",
         # it was adjusted accordly for the new step scheme
         lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
+        test_sentences=[
+            {"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
+            {"text": "This cake is great. It's so delicious and moist.", "speaker_wav": "/raid/edresson/dev/ref.wav", "language": "en"},
+        ]
     )

     # init the model from config
@@ -341,7 +345,7 @@ def main():

 if __name__ == "__main__":
     RUN_NAME = "GPT_XTTS"
-    PROJECT_NAME = "XTTS"
+    PROJECT_NAME = "XTTS_trainer"
     OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_style_emb/"
     DASHBOARD_LOGGER = "clearml"
     LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
@@ -352,12 +356,11 @@ if __name__ == "__main__":
     GRAD_ACUMM_STEPS = 28

     # debug
-    DASHBOARD_LOGGER = "tensorboard"
-    LOGGER_URI = None
-    RESTORE_PATH = None
-    BATCH_SIZE = 10
+    # DASHBOARD_LOGGER = "tensorboard"
+    # LOGGER_URI = None
+    # RESTORE_PATH = None
+    BATCH_SIZE = 2
     GRAD_ACUMM_STEPS = 1
-    NUM_LOADERS = 1