mirror of https://github.com/coqui-ai/TTS.git
Add XTTS base training code
This commit is contained in:
parent 1e152692ed
commit a32961bcb4
@@ -0,0 +1,202 @@
import os
import random
import sys
import numpy as np

import torch
import torch.nn.functional as F
import torch.utils.data
import torchaudio
from torchaudio.backend.sox_io_backend import load as torchaudio_sox_load
from torchaudio.backend.soundfile_backend import load as torchaudio_soundfile_load

torch.set_num_threads(1)
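# Presumably this single-thread setting keeps the DataLoader worker processes
# from oversubscribing CPU cores while they decode and resample audio.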


def key_samples_by_col(samples, col):
    """Returns a dictionary of samples keyed by the value of `col` (e.g. the language)."""
    samples_by_col = {}
    for sample in samples:
        col_val = sample[col]
        assert isinstance(col_val, str)
        if col_val not in samples_by_col:
            samples_by_col[col_val] = []
        samples_by_col[col_val].append(sample)
    return samples_by_col
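
# Illustrative example with hypothetical samples:
#   key_samples_by_col([{"language": "en", ...}, {"language": "de", ...}], "language")
#   -> {"en": [...], "de": [...]}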


def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate):
    rel_clip = load_audio(gt_path, sample_rate)
    sample_length = random.randint(min_sample_length, max_sample_length)
    gap = rel_clip.shape[-1] - sample_length
    if gap < 0:
        sample_length = rel_clip.shape[-1] // 2
        gap = rel_clip.shape[-1] - sample_length
    rand_start = random.randint(0, gap)
    rand_end = rand_start + sample_length
    rel_clip = rel_clip[:, rand_start:rand_end]
    rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
    cond_idxs = [rand_start, rand_end]
    return rel_clip, rel_clip.shape[-1], cond_idxs
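
# Note: the returned length is taken after padding, so it always equals
# max_sample_length, while cond_idxs keeps the [start, end] sample indices of
# the un-padded slice inside the original clip.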


def load_audio(audiopath, sampling_rate):
    # better load setting following: https://github.com/faroit/python_audio_loading_benchmark
    if audiopath[-4:] == '.mp3':
        # use the torchaudio sox backend to load mp3 files
        audio, lsr = torchaudio_sox_load(audiopath)
    else:
        # use the torchaudio soundfile backend to load all other file types
        audio, lsr = torchaudio_soundfile_load(audiopath)

    # stereo to mono if needed
    if audio.size(0) != 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # Check some assumptions about the audio range. This should be automatically fixed in load_wav_to_torch, but might not be in some edge cases, where we should squawk.
    # '10' is arbitrarily chosen since it seems like audio will often "overdrive" the [-1,1] bounds.
    if torch.any(audio > 10) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    # clip invalid audio values
    audio.clip_(-1, 1)
    return audio


class XTTSDataset(torch.utils.data.Dataset):
    def __init__(self, config, samples, tokenizer, sample_rate):
        self.config = config
        model_args = config.model_args
        self.failed_samples = set()
        self.debug_failures = model_args.debug_loading_failures
        self.max_conditioning_length = model_args.max_conditioning_length
        self.min_conditioning_length = model_args.min_conditioning_length

        # self.samples = []
        # cache the samples and add type "0" for all samples
        # ToDo: find a better way to deal with type
        # for item in samples:
        #     self.samples.append([item['audio_file'], item["text"], 0])
        self.samples = samples
        random.seed(config.training_seed)
        # random.shuffle(self.samples)
        random.shuffle(self.samples)
        # order by language
        self.samples = key_samples_by_col(self.samples, "language")
        print(" > Sampling by language:", self.samples.keys())

        # always use the output sampling rate to load audio at the highest quality
        self.sample_rate = sample_rate
        self.max_wav_len = model_args.max_wav_length
        self.max_text_len = model_args.max_text_length
        assert self.max_wav_len is not None and self.max_text_len is not None

        # load the specific vocabulary
        self.tokenizer = tokenizer

    def get_text(self, text, lang):
        tokens = self.tokenizer.encode(text, lang)
        tokens = torch.IntTensor(tokens)
        assert not torch.any(tokens == 1), f"UNK token found in {text} -> {self.tokenizer.decode(tokens)}"
        # The stop token should always be sacred.
        assert not torch.any(tokens == 0), f"Stop token found in {text}"
        return tokens

    def load_item(self, sample):
        text = str(sample['text'])
        tseq = self.get_text(text, sample["language"])
        audiopath = sample['audio_file']
        wav = load_audio(audiopath, self.sample_rate)
        if text is None or len(text.strip()) == 0:
            raise ValueError
        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
            # Ultra short clips are also useless (and can cause problems within some models).
            raise ValueError

        # get a slice from GT to condition the model
        cond, cond_len, cond_idxs = get_prompt_slice(audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate)

        return tseq, audiopath, wav, cond, cond_len, cond_idxs

    def __getitem__(self, index):
        # select a random language
        lang = random.choice(list(self.samples.keys()))
        # select a random sample
        index = random.randint(0, len(self.samples[lang]) - 1)
        sample = self.samples[lang][index]
        # a unique id for each sample to deal with loading failures
        sample_id = lang + "_" + str(index)

        # ignore samples that we already know are not valid
        if sample_id in self.failed_samples:
            if self.debug_failures:
                print(f"Ignoring sample {sample['audio_file']} because it was already ignored before !!")
            # call __getitem__ again to get another sample
            return self[1]

        # try to load the sample; if it fails, add it to the failed samples list
        try:
            tseq, audiopath, wav, cond, cond_len, cond_idxs = self.load_item(sample)
        except:
            if self.debug_failures:
                print(f"error loading {sample['audio_file']} {sys.exc_info()}")
            self.failed_samples.add(sample_id)
            return self[1]

        # check the audio and text size limits; if the sample is out of bounds, add it to failed_samples
        if wav is None or \
                (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len) or \
                (self.max_text_len is not None and tseq.shape[0] > self.max_text_len):
            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
            # It's hard to handle this situation properly. The best bet is to return another valid sample and skew the dataset somewhat as a result.
            if self.debug_failures and wav is not None and tseq is not None:
                print(f"error loading {sample['audio_file']}: ranges are out of bounds; {wav.shape[-1]}, {tseq.shape[0]}")
            self.failed_samples.add(sample_id)
            return self[1]

        res = {
            # 'real_text': text,
            'text': tseq,
            'text_lengths': torch.tensor(tseq.shape[0], dtype=torch.long),
            'wav': wav,
            'wav_lengths': torch.tensor(wav.shape[-1], dtype=torch.long),
            'filenames': audiopath,
            'conditioning': cond.unsqueeze(1),
            'cond_lens': torch.tensor(cond_len, dtype=torch.long),
            'cond_idxs': torch.tensor(cond_idxs),
        }
        return res

    def __len__(self):
        return sum([len(v) for v in self.samples.values()])

    def collate_fn(self, batch):
        # convert a list of dicts to a dict of lists
        B = len(batch)
        batch = {k: [dic[k] for dic in batch] for k in batch[0]}

        # stack the features that already have the same shape
        batch["wav_lengths"] = torch.stack(batch["wav_lengths"])
        batch["text_lengths"] = torch.stack(batch["text_lengths"])
        batch["conditioning"] = torch.stack(batch["conditioning"])
        batch["cond_lens"] = torch.stack(batch["cond_lens"])
        batch["cond_idxs"] = torch.stack(batch["cond_idxs"])
        max_text_len = batch["text_lengths"].max()
        max_wav_len = batch["wav_lengths"].max()

        # create padding tensors
        text_padded = torch.IntTensor(B, max_text_len)
        wav_padded = torch.FloatTensor(B, 1, max_wav_len)

        # initialize tensors for zero padding
        text_padded = text_padded.zero_()
        wav_padded = wav_padded.zero_()
        for i in range(B):
            text = batch["text"][i]
            text_padded[i, : batch["text_lengths"][i]] = torch.IntTensor(text)
            wav = batch['wav'][i]
            wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)

        batch["wav"] = wav_padded
        batch["padded_text"] = text_padded

        return batch
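
# Illustrative usage, mirroring how GPTTrainer.get_data_loader wires this class up:
#   dataset = XTTSDataset(config, samples, tokenizer, config.audio.sample_rate)
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=config.batch_size, collate_fn=dataset.collate_fn
#   )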
@@ -0,0 +1,448 @@
import os
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torchaudio
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
import sys


from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.models.xtts import XttsArgs, XttsAudioConfig

from TTS.tts.models.base_tts import BaseTTS
from coqpit import Coqpit

from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.layers.tortoise.arch_utils import TorchMelSpectrogram

from TTS.tts.datasets.dataset import TTSDataset

from trainer.torch import DistributedSampler
from trainer.trainer_utils import get_optimizer, get_scheduler


from TTS.tts.layers.xtts.trainer.dataset import XTTSDataset
from TTS.utils.io import load_fsspec

from TTS.tts.layers.xtts.dvae import DiscreteVAE


@dataclass
class GPTConfig(TortoiseConfig):
    lr: float = 5e-06
    training_seed: int = 1
    optimizer_wd_only_on_weights: bool = False
    use_weighted_loss: bool = False  # TODO: move it to the base config
    weighted_loss_attrs: dict = field(default_factory=lambda: {})
    weighted_loss_multipliers: dict = field(default_factory=lambda: {})


@dataclass
class XttsAudioConfig(XttsAudioConfig):
    dvae_sample_rate: int = 22050


@dataclass
class GPTArgs(XttsArgs):
    min_conditioning_length: int = 66150
    max_conditioning_length: int = 132300
    gpt_loss_text_ce_weight: float = 0.01
    gpt_loss_mel_ce_weight: float = 1.0
    gpt_num_audio_tokens: int = 8194
    debug_loading_failures: bool = False
    max_wav_length: int = 255995  # ~11.6 seconds
    max_text_length: int = 200
    tokenizer_file: str = ""
    mel_norm_file: str = "https://coqui.gateway.scarf.sh/v0.14.0_models/mel_norms.pth"
    dvae_checkpoint: str = ""
    gpt_checkpoint: str = ""
    vocoder: str = ""  # override the vocoder key in the config to avoid JSON write issues
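
    # Note: gpt_num_audio_tokens (8194) presumably covers the 8192 DVAE codebook
    # entries plus the start/stop audio tokens; the DVAE below is built with
    # gpt_num_audio_tokens - 2 tokens and the recipe sets start/stop to 8192/8193.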


def callback_clearml_load_save(operation_type, model_info):
    # Returning None skips the file upload/log; returning model_info continues with the log/upload.
    # You can also change the upload destination file name via model_info.upload_filename, or check the local file size with Path(model_info.local_model_path).stat().st_size.
    assert operation_type in ('load', 'save')
    # print(operation_type, model_info.__dict__)

    if "similarities.pth" in model_info.__dict__['local_model_path']:
        return None

    return model_info


class GPTTrainer(BaseTTS):
    def __init__(self, config: Coqpit):
        """
        Tortoise GPT training class
        """
        super().__init__(config, ap=None, tokenizer=None)
        self.config = config

        self.tokenizer = VoiceBpeTokenizer(self.args.tokenizer_file)

        self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
        self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
        self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")

        self.gpt = GPT(
            layers=self.args.gpt_layers,
            model_dim=self.args.gpt_n_model_channels,
            start_text_token=self.args.gpt_start_text_token,
            stop_text_token=self.args.gpt_stop_text_token,
            heads=self.args.gpt_n_heads,
            max_text_tokens=self.args.gpt_max_text_tokens,
            max_mel_tokens=self.args.gpt_max_audio_tokens,
            max_prompt_tokens=self.args.gpt_max_prompt_tokens,
            number_text_tokens=self.args.gpt_number_text_tokens,
            num_audio_tokens=self.args.gpt_num_audio_tokens,
            start_audio_token=self.args.gpt_start_audio_token,
            stop_audio_token=self.args.gpt_stop_audio_token,
        ).cuda()

        # load GPT if available
        if self.args.gpt_checkpoint:
            gpt_checkpoint = torch.load(
                self.args.gpt_checkpoint, map_location=torch.device("cpu")
            )
            # deal with a coqui Trainer exported model
            if "model" in gpt_checkpoint.keys() and "config" in gpt_checkpoint.keys():
                print("Coqui Trainer checkpoint detected! Converting it!")
                gpt_checkpoint = gpt_checkpoint["model"]
                states_keys = list(gpt_checkpoint.keys())
                for key in states_keys:
                    if "gpt." in key:
                        new_key = key.replace("gpt.", "")
                        gpt_checkpoint[new_key] = gpt_checkpoint[key]
                        del gpt_checkpoint[key]
                    else:
                        del gpt_checkpoint[key]

            # edit the checkpoint if the number of tokens has changed, to ensure the best possible transfer learning
            if "text_embedding.weight" in gpt_checkpoint and gpt_checkpoint["text_embedding.weight"].shape != self.gpt.text_embedding.weight.shape:
                num_new_tokens = self.gpt.text_embedding.weight.shape[0] - gpt_checkpoint["text_embedding.weight"].shape[0]
                print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")

                # add new tokens to the embedding layer (text_embedding)
                emb_g = gpt_checkpoint["text_embedding.weight"]
                new_row = torch.randn(num_new_tokens, emb_g.shape[1])
                start_token_row = emb_g[-1, :]
                emb_g = torch.cat([emb_g, new_row], axis=0)
                emb_g[-1, :] = start_token_row
                gpt_checkpoint["text_embedding.weight"] = emb_g

                # add new weights to the linear layer (text_head)
                text_head_weight = gpt_checkpoint["text_head.weight"]
                start_token_row = text_head_weight[-1, :]
                new_entry = torch.randn(num_new_tokens, self.gpt.text_head.weight.shape[1])
                text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
                text_head_weight[-1, :] = start_token_row
                gpt_checkpoint["text_head.weight"] = text_head_weight

                # add new biases to the linear layer (text_head)
                text_head_bias = gpt_checkpoint["text_head.bias"]
                start_token_row = text_head_bias[-1]
                new_bias_entry = torch.zeros(num_new_tokens)
                text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0)
                text_head_bias[-1] = start_token_row
                gpt_checkpoint["text_head.bias"] = text_head_bias

            self.gpt.load_state_dict(gpt_checkpoint, strict=True)
            print(">> GPT weights restored from:", self.args.gpt_checkpoint)
        else:
            print(">> GPT weights randomly initialized! If you want you can specify a checkpoint in config.model_args.gpt_checkpoint")

        # Mel spectrogram extractor for conditioning
        self.torch_mel_spectrogram = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.sample_rate)

        # Load DVAE
        self.dvae = DiscreteVAE(
            channels=80,
            normalization=None,
            positional_dims=1,
            num_tokens=self.args.gpt_num_audio_tokens - 2,
            codebook_dim=512,
            hidden_dim=512,
            num_resnet_blocks=3,
            kernel_size=3,
            num_layers=2,
            use_transposed_convs=False,
        )

        self.dvae.eval()
        if self.args.dvae_checkpoint:
            dvae_checkpoint = torch.load(
                self.args.dvae_checkpoint, map_location=torch.device("cpu")
            )
            self.dvae.load_state_dict(dvae_checkpoint, strict=False)
            print(">> DVAE weights restored from:", self.args.dvae_checkpoint)
        else:
            raise RuntimeError("You need to specify config.model_args.dvae_checkpoint path to be able to train the GPT decoder!!")

        # Mel spectrogram extractor for the DVAE
        self.torch_mel_spectrogram_dvae = TorchMelSpectrogram(mel_norm_file=self.args.mel_norm_file, sampling_rate=config.audio.dvae_sample_rate)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens):
        """
        Forward pass through the GPT model, conditioning on both the text and the reference voice prompts.

        text_inputs: long tensor, (b, t)
        text_lengths: long tensor, (b,)
        audio_codes: long tensor, (b, m)
        wav_lengths: long tensor, (b,)
        cond_mels: mel float tensor, (b, num_samples, 80, t_m)
        cond_lens: long tensor, (b,)
        """
        losses = self.gpt(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels=cond_mels, cond_lens=cond_lens)
        return losses

    @torch.no_grad()
    def test_run(self, assets) -> Tuple[Dict, Dict]:  # pylint: disable=W0613
        return {}, {}

    def format_batch(self, batch: Dict) -> Dict:
        return batch

    @torch.no_grad()  # no grad to avoid gradients from the pre-processing and DVAE code extraction
    def format_batch_on_device(self, batch):
        """Compute spectrograms on the device."""
        batch["text_lengths"] = batch["text_lengths"]
        batch["wav_lengths"] = batch["wav_lengths"]
        batch["text_inputs"] = batch["padded_text"]
        batch["cond_lens"] = batch["cond_lens"]
        batch["cond_idxs"] = batch["cond_idxs"]
        # compute conditioning mel specs
        # reshape waves from torch.Size([B, num_cond_samples, 1, T]) to torch.Size([B * num_cond_samples, 1, T]) because it is faster than iterating over the tensor
        B, num_cond_samples, C, T = batch["conditioning"].size()
        conditioning_reshaped = batch["conditioning"].view(B * num_cond_samples, C, T)
        paired_conditioning_mel = self.torch_mel_spectrogram(conditioning_reshaped)
        # reshape torch.Size([B * num_cond_samples, n_mel, T_mel]) into torch.Size([B, num_cond_samples, n_mel, T_mel])
        n_mel = self.torch_mel_spectrogram.n_mel_channels  # paired_conditioning_mel.size(1)
        T_mel = paired_conditioning_mel.size(2)
        paired_conditioning_mel = paired_conditioning_mel.view(B, num_cond_samples, n_mel, T_mel)
        # get the conditioning embeddings
        batch["cond_mels"] = paired_conditioning_mel
        # compute codes using the DVAE
        if self.config.audio.sample_rate != self.config.audio.dvae_sample_rate:
            dvae_wav = torchaudio.functional.resample(
                batch["wav"],
                orig_freq=self.config.audio.sample_rate,
                new_freq=self.config.audio.dvae_sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method="kaiser_window",
                beta=14.769656459379492,
            )
        else:
            dvae_wav = batch["wav"]
        dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
        codes = self.dvae.get_codebook_indices(dvae_mel_spec)
        batch["audio_codes"] = codes
        # delete batch tensors that are no longer needed
        del batch["padded_text"]
        del batch["wav"]
        del batch["conditioning"]

        return batch
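
    # After this step the batch holds text_inputs [B, max_text_len], text_lengths [B],
    # audio_codes [B, T_codes] from the DVAE, wav_lengths [B],
    # cond_mels [B, num_cond_samples, n_mel, T_mel], cond_lens [B] and cond_idxs [B, 2].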

    def train_step(self, batch, criterion):
        loss_dict = {}
        cond_mels = batch["cond_mels"]
        text_inputs = batch["text_inputs"]
        text_lengths = batch["text_lengths"]
        audio_codes = batch["audio_codes"]
        wav_lengths = batch["wav_lengths"]

        cond_lens = batch["cond_lens"]
        # Todo: implement masking on the cond slice
        cond_idxs = batch["cond_idxs"]

        loss_text, loss_mel, _ = self.forward(text_inputs, text_lengths, audio_codes, wav_lengths, cond_mels, cond_lens)
        loss_dict["loss_text_ce"] = loss_text * self.args.gpt_loss_text_ce_weight
        loss_dict["loss_mel_ce"] = loss_mel * self.args.gpt_loss_mel_ce_weight
        loss_dict["loss"] = loss_dict["loss_text_ce"] + loss_dict["loss_mel_ce"]
        return {"model_outputs": None}, loss_dict

    def eval_step(self, batch, criterion):
        return self.train_step(batch, criterion)

    def on_epoch_start(self, trainer):  # pylint: disable=W0613
        # guarantee that the DVAE stays in eval mode after .train() is called at the end of evaluation
        self.dvae = self.dvae.eval()

    def on_init_end(self, trainer):  # pylint: disable=W0613
        # ignore similarities.pth on clearml save/upload
        if self.config.dashboard_logger.lower() == "clearml":
            from clearml.binding.frameworks import WeightsFileHandler
            WeightsFileHandler.add_pre_callback(callback_clearml_load_save)

    @torch.no_grad()
    def inference(
        self,
        x,
        aux_input=None,
    ):  # pylint: disable=dangerous-default-value
        return None

    @staticmethod
    def get_criterion():
        return None

    def get_sampler(self, dataset: TTSDataset, num_gpus=1):
        # sampler for DDP
        batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        return batch_sampler

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":  # pylint: disable=W0613
        if is_eval and not config.run_eval:
            loader = None
        else:
            # Todo: remove the randomness of the dataset when it is eval
            # init dataloader
            dataset = XTTSDataset(self.config, samples, self.tokenizer, config.audio.sample_rate)

            # wait for all DDP processes to be ready
            if num_gpus > 1:
                torch.distributed.barrier()

            # sort input sequences from short to long
            # dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(dataset, num_gpus)

            # ignore the sampler during eval because changing the sampler parameters would make runs incomparable with previous ones
            if sampler is None or is_eval:
                loader = DataLoader(
                    dataset,
                    batch_size=config.eval_batch_size if is_eval else config.batch_size,
                    shuffle=False,
                    drop_last=False,
                    collate_fn=dataset.collate_fn,
                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                    pin_memory=False,
                )
            else:
                loader = DataLoader(
                    dataset,
                    batch_sampler=sampler,
                    collate_fn=dataset.collate_fn,
                    num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                    pin_memory=False,
                )
        return loader

    def get_optimizer(self) -> List:
        """Initiate and return the optimizer based on the config parameters."""
        # ToDo: deal with multi GPU training
        if self.config.optimizer_wd_only_on_weights:
            # only the GPT model parameters are optimized
            net = self.gpt

            # normalizations
            norm_modules = (nn.BatchNorm2d, nn.InstanceNorm2d, nn.BatchNorm1d, nn.InstanceNorm1d,
                            nn.BatchNorm3d, nn.InstanceNorm3d, nn.GroupNorm, nn.LayerNorm)
            # nn.Embedding
            emb_modules = (nn.Embedding, nn.EmbeddingBag)

            param_names_notweights = set()
            all_param_names = set()
            param_map = {}
            for mn, m in net.named_modules():
                for k, v in m.named_parameters():
                    v.is_bias = k.endswith(".bias")
                    v.is_weight = k.endswith(".weight")
                    v.is_norm = isinstance(m, norm_modules)
                    v.is_emb = isinstance(m, emb_modules)

                    fpn = '%s.%s' % (mn, k) if mn else k  # full param name
                    all_param_names.add(fpn)
                    param_map[fpn] = v
                    if v.is_bias or v.is_norm or v.is_emb:
                        param_names_notweights.add(fpn)

            params_names_notweights = sorted(list(param_names_notweights))
            params_notweights = [param_map[k] for k in params_names_notweights]
            params_names_weights = sorted(list(all_param_names ^ param_names_notweights))
            params_weights = [param_map[k] for k in params_names_weights]
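
            # Since param_names_notweights is a subset of all_param_names, the symmetric
            # difference above is simply "all parameter names minus the non-weight ones".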

            groups = [
                {'params': params_weights, 'weight_decay': self.config.optimizer_params["weight_decay"]},
                {'params': params_notweights, 'weight_decay': 0}
            ]
            # torch.optim.AdamW
            opt = get_optimizer(
                self.config.optimizer,
                self.config.optimizer_params,
                self.config.lr,
                parameters=groups,
            )
            opt._group_names = [params_names_weights, params_names_notweights]
            return opt

        return get_optimizer(
            self.config.optimizer,
            self.config.optimizer_params,
            self.config.lr,
            # optimize only the GPT model
            parameters=self.gpt.parameters(),
        )

    def get_scheduler(self, optimizer) -> List:
        """Set the scheduler for the optimizer.

        Args:
            optimizer: `torch.optim.Optimizer`.
        """
        return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, optimizer)

    def load_checkpoint(
        self,
        config,
        checkpoint_path,
        eval=False,
        strict=True,
        cache_storage="/tmp/tts_cache",
        target_protocol="s3",
        target_options={"anon": True},
    ):  # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
        """Load the model checkpoint and set it up for training or inference."""
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
        # load the model weights
        self.gpt.load_state_dict(state, strict=strict)

        if eval:
            self.eval()
            self.set_inference()
            assert not self.training

    @staticmethod
    def init_from_config(config: "GPTConfig", samples: Union[List[List], List[Dict]] = None):
        """Initiate the model from config.

        Args:
            config (GPTConfig): Model config.
            samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
                Defaults to None.
        """
        return GPTTrainer(config)
@@ -0,0 +1,364 @@
from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples

from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTTrainer, GPTArgs, XttsAudioConfig, GPTConfig


config_coqui_MLS_metadata_train_with_previous_audio_key_de = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/MLS/mls_german",
    meta_file_train="metadata_train_with_previous_audio_key.csv",
    language="de",
)


config_coqui_MLS_metadata_test_with_previous_audio_key_de = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/MLS/mls_german",
    meta_file_train="metadata_test_with_previous_audio_key.csv",
    language="de",
)


config_coqui_MLS_metadata_dev_with_previous_audio_key_de = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/MLS/mls_german",
    meta_file_train="metadata_dev_with_previous_audio_key.csv",
    language="de",
)


config_coqui_mls_french_metadata_with_previous_audio_key_fr = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/MLS/mls_french/",
    meta_file_train="metadata_with_previous_audio_key.csv",
    language="fr",
)


config_coqui_mls_spanish_metadata_with_previous_audio_key_es = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/MLS/mls_spanish/",
    meta_file_train="/raid/datasets/MLS/mls_spanish/metadata_with_previous_audio_key.csv",
    language="es",
)


config_coqui_mls_italian_metadata_with_previous_audio_key_it = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/MLS/mls_italian/",
    meta_file_train="/raid/datasets/MLS/mls_italian/metadata_with_previous_audio_key.csv",
    language="it",
)


config_coqui_mls_portuguese_metadata_with_previous_audio_key_pt = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/MLS/mls_portuguese/",
    meta_file_train="/raid/datasets/MLS/mls_portuguese/metadata_with_previous_audio_key.csv",
    language="pt",
)


config_coqui_mls_polish_metadata_with_previous_audio_key_pl = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/MLS/mls_polish/",
    meta_file_train="/raid/datasets/MLS/mls_polish/metadata_with_previous_audio_key.csv",
    language="pl",
)


config_coqui_common_voice_metafile_it_train_with_scores_it = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_it_train_with_scores.csv",
    language="it",
)


config_coqui_common_voice_metafile_it_test_with_scores_it = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_it_test_with_scores.csv",
    language="it",
)


config_coqui_common_voice_metafile_it_dev_with_scores_it = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_it_dev_with_scores.csv",
    language="it",
)


config_coqui_common_voice_metafile_pt_train_with_scores_pt = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_pt_train_with_scores.csv",
    language="pt",
)


config_coqui_common_voice_metafile_pt_test_with_scores_pt = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_pt_test_with_scores.csv",
    language="pt",
)


config_coqui_common_voice_metafile_pt_dev_with_scores_pt = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_pt_dev_with_scores.csv",
    language="pt",
)


config_coqui_common_voice_metafile_en_train_en = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_en_train.csv",
    language="en",
)


config_coqui_common_voice_metafile_en_test_en = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_en_test.csv",
    language="en",
)


config_coqui_common_voice_metafile_en_dev_en = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_en_dev.csv",
    language="en",
)


config_coqui_common_voice_metafile_tr_validated_tr = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_tr_validated.csv",
    language="tr",
)


config_coqui_common_voice_metafile_ru_validated_ru = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_ru_validated.csv",
    language="ru",
)


config_coqui_common_voice_metafile_nl_validated_nl = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_nl_validated.csv",
    language="nl",
)


config_coqui_common_voice_metafile_cs_validated_cs = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_cs_validated.csv",
    language="cs",
)


config_coqui_common_voice_metafile_fr_validated_fr = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_fr_validated.csv",
    language="fr",
)


config_coqui_common_voice_metafile_es_validated_es = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_es_validated.csv",
    language="es",
)


config_coqui_common_voice_metafile_pl_validated_pl = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_pl_validated.csv",
    language="pl",
)


config_coqui_common_voice_metafile_ar_validated_ar = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_ar_validated.csv",
    language="ar",
)


config_coqui_common_voice_metafile_zh_CN_validated_zh_cn = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_zh-CN_validated.csv",
    language="zh-cn",
)


config_coqui_common_voice_metafile_ja_validated_ja = BaseDatasetConfig(
    formatter="coqui",
    dataset_name="coqui",
    path="/raid/datasets/common_voice/",
    meta_file_train="/raid/datasets/common_voice/metafile_ja_validated.csv",
    language="ja",
)

# DATASETS_CONFIG_LIST=[config_coqui_MLS_metadata_train_with_previous_audio_key_de, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_MLS_metadata_dev_with_previous_audio_key_de, config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it, config_coqui_mls_portuguese_metadata_with_previous_audio_key_pt, config_coqui_mls_polish_metadata_with_previous_audio_key_pl, config_coqui_common_voice_metafile_it_train_with_scores_it, config_coqui_common_voice_metafile_it_test_with_scores_it, config_coqui_common_voice_metafile_it_dev_with_scores_it, config_coqui_common_voice_metafile_pt_train_with_scores_pt, config_coqui_common_voice_metafile_pt_test_with_scores_pt, config_coqui_common_voice_metafile_pt_dev_with_scores_pt, config_coqui_common_voice_metafile_en_train_en, config_coqui_common_voice_metafile_en_test_en, config_coqui_common_voice_metafile_en_dev_en, config_coqui_common_voice_metafile_tr_validated_tr, config_coqui_common_voice_metafile_ru_validated_ru, config_coqui_common_voice_metafile_nl_validated_nl, config_coqui_common_voice_metafile_cs_validated_cs, config_coqui_common_voice_metafile_fr_validated_fr, config_coqui_common_voice_metafile_es_validated_es, config_coqui_common_voice_metafile_pl_validated_pl, config_coqui_common_voice_metafile_ar_validated_ar, config_coqui_common_voice_metafile_zh_CN_validated_zh_cn, config_coqui_common_voice_metafile_ja_validated_ja]

# DATASETS_CONFIG_LIST = [config_coqui_mls_french_metadata_with_previous_audio_key_fr, config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_spanish_metadata_with_previous_audio_key_es, config_coqui_mls_italian_metadata_with_previous_audio_key_it]

DATASETS_CONFIG_LIST = [config_coqui_MLS_metadata_test_with_previous_audio_key_de, config_coqui_mls_italian_metadata_with_previous_audio_key_it]
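
# Only two datasets are enabled here; the commented-out lists above hold the full
# multilingual MLS + Common Voice configuration.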


def freeze_layers(trainer):
    pass
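
# Placeholder: it is registered below as the Trainer's `on_epoch_start` callback,
# but no layer freezing is implemented yet.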


def main():
    # init args and config
    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 secs
        min_conditioning_length=66150,  # 3 secs
        debug_loading_failures=True,
        max_wav_length=255995,  # ~11.6 seconds
        max_text_length=200,
        tokenizer_file="/raid/datasets/xtts_models/vocab.json",
        mel_norm_file="/raid/datasets/xtts_models/mel_stats.pth",
        dvae_checkpoint="/raid/datasets/xtts_models/dvae.pth",
        gpt_checkpoint="/raid/datasets/xtts_models/gpt.pth",
        gpt_num_audio_tokens=8194,
        gpt_start_audio_token=8192,
        gpt_stop_audio_token=8193,
    )
    audio_config = XttsAudioConfig(
        sample_rate=22050,  # autoregressive SR
        dvae_sample_rate=22050,
        diffusion_sample_rate=24000,
        output_sample_rate=24000
    )
    config = GPTConfig(
        output_path=OUT_PATH,
        model_args=model_args,
        run_name=RUN_NAME,
        project_name=PROJECT_NAME,
        run_description="""
            GPT XTTS training
            """,
        dashboard_logger=DASHBOARD_LOGGER,
        logger_uri=LOGGER_URI,
        audio=audio_config,
        batch_size=BATCH_SIZE,
        batch_group_size=48,
        eval_batch_size=BATCH_SIZE,
        num_loader_workers=8,
        eval_split_max_size=256,
        print_step=50,
        plot_step=100,
        log_model_step=1000,
        save_step=10000,
        save_n_checkpoints=1,
        save_checkpoints=True,
        # target_loss="loss",
        print_eval=False,
        # Optimizer values like Tortoise. However, Tortoise used a PyTorch implementation modified to not apply weight decay to non-weight parameters; here we use the default PyTorch one.
        optimizer="AdamW",
        optimizer_wd_only_on_weights=True,
        optimizer_params={"betas": [.9, .96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=5e-06,  # learning rate
        # lr=1e-4,  # learning rate
        # ToDo: implement a 500-step warmup like Tortoise, and EMA weights replacing LR decay with rate .999
        lr_scheduler="MultiStepLR",
        # it was adjusted accordingly for the new step scheme
        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
    )

    # init the model from config
    model = GPTTrainer.init_from_config(config)

    # load training samples
    train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )

    # init the trainer and 🚀
    trainer = Trainer(
        TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=START_WITH_EVAL, grad_accum_steps=GRAD_ACUMM_STEPS),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        callbacks={"on_epoch_start": freeze_layers}
    )
    trainer.fit()


if __name__ == "__main__":
    RUN_NAME = "GPT_XTTS"
    PROJECT_NAME = "XTTS"
    OUT_PATH = "/raid/edresson/dev/Checkpoints/XTTS_style_emb/"
    DASHBOARD_LOGGER = "clearml"
    LOGGER_URI = "s3://coqui-ai-models/TTS/Checkpoints/XTTS_style_emb/"
    RESTORE_PATH = None
    SKIP_TRAIN_EPOCH = False
    START_WITH_EVAL = True
    BATCH_SIZE = 9
    GRAD_ACUMM_STEPS = 28
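
    # With these values the effective batch size is BATCH_SIZE * GRAD_ACUMM_STEPS = 9 * 28 = 252
    # (before the debug overrides below shrink it to 2 * 1 = 2).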

    # debug
    DASHBOARD_LOGGER = "tensorboard"
    LOGGER_URI = None
    RESTORE_PATH = None
    BATCH_SIZE = 2
    GRAD_ACUMM_STEPS = 1
    NUM_LOADERS = 1


    main()