Run `make style`

Aarni Koskela 2023-12-04 10:38:07 +02:00
parent bd172dabbf
commit d6ea806469
11 changed files with 121 additions and 92 deletions

View File

@@ -168,9 +168,7 @@ class TTS(nn.Module):
self.synthesizer = None
self.model_name = model_name
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
model_name
)
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
# init synthesizer
# None values are fetched from the model
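
For context, a minimal sketch of how this constructor path is typically driven from user code (the model name is illustrative; any name returned by TTS().list_models() would do):

from TTS.api import TTS

# Constructing with a model name resolves model, config, and vocoder paths
# internally via download_model_by_name -- the call being re-wrapped above.
tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")
tts.tts_to_file(text="Hello world.", file_path="out.wav")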

View File

@@ -224,7 +224,7 @@ def main():
const=True,
default=False,
)
# args for multi-speaker synthesis
parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)

View File

@@ -17,9 +17,12 @@ def read_json_with_comments(json_path):
with fsspec.open(json_path, "r", encoding="utf-8") as f:
input_str = f.read()
# handle comments but not urls with //
input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
input_str = re.sub(
r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
)
return json.loads(input_str)
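
To see what this regex does, here is a self-contained sketch (the helper name strip_json_comments is hypothetical; the character class is spelled [\n\r] here so block comments can span real newlines):

import json
import re

def strip_json_comments(input_str):
    # Group 1 captures string literals so "//" inside a quoted URL survives;
    # groups 2 and 3 match /* */ blocks and // line comments, which are dropped.
    return re.sub(
        r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\n\r])*?\*/)|(//.*)",
        lambda m: m.group(1) or m.group(2) or "",
        input_str,
    )

text = '{"url": "https://example.com", // comment\n"n": 1 /* block */}'
print(json.loads(strip_json_comments(text)))  # {'url': 'https://example.com', 'n': 1}
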
def register_config(model_name: str) -> Coqpit:
"""Find the right config for the given model name.

View File

@@ -19,9 +19,10 @@ def list_audios(basePath, contains=None):
# return the set of files that are valid
return list_files(basePath, validExts=audio_types, contains=contains)
def list_files(basePath, validExts=None, contains=None):
# loop over the directory structure
for (rootDir, dirNames, filenames) in os.walk(basePath):
for rootDir, dirNames, filenames in os.walk(basePath):
# loop over the filenames in the current directory
for filename in filenames:
# if the contains string is not none and the filename does not contain
@@ -30,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None):
continue
# determine the file extension of the current file
ext = filename[filename.rfind("."):].lower()
ext = filename[filename.rfind(".") :].lower()
# check to see if the file is an audio and should be processed
if validExts is None or ext.endswith(validExts):
@@ -38,7 +39,16 @@ def list_files(basePath, validExts=None, contains=None):
audioPath = os.path.join(rootDir, filename)
yield audioPath
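
The extension check above relies on str.endswith accepting a tuple; a small sketch (audio_types is assumed to be a tuple of lowercase suffixes, as implied by list_audios):

audio_types = (".wav", ".mp3", ".flac")  # assumed shape of validExts

def has_valid_ext(filename, validExts=audio_types):
    # mirror the slicing above: everything from the last dot, lowercased
    ext = filename[filename.rfind(".") :].lower()
    return validExts is None or ext.endswith(validExts)

print(has_valid_ext("clip.WAV"))   # True
print(has_valid_ext("notes.txt"))  # False
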
def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
def format_audio_list(
audio_files,
target_language="en",
out_path=None,
buffer=0.2,
eval_percentage=0.15,
speaker_name="coqui",
gradio_progress=None,
):
audio_total_size = 0
# make sure that the output directory exists
os.makedirs(out_path, exist_ok=True)
@@ -63,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
wav = torch.mean(wav, dim=0, keepdim=True)
wav = wav.squeeze()
audio_total_size += (wav.size(-1) / sr)
audio_total_size += wav.size(-1) / sr
segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
segments = list(segments)
@@ -88,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
# get previous sentence end
previous_word_end = words_list[word_idx - 1].end
# add the buffer or use the middle of the silence between the previous sentence and the current one
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
sentence = word.word
first_word = False
@@ -112,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
# Average the current word end and next word start
word_end = min((word.end + next_word_start) / 2, word.end + buffer)
absoulte_path = os.path.join(out_path, audio_file)
os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
i += 1
first_word = True
audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
# if the audio is too short, ignore it (i.e. < 0.33 seconds)
if audio.size(-1) >= sr/3:
torchaudio.save(absoulte_path,
audio,
sr
)
if audio.size(-1) >= sr / 3:
torchaudio.save(absoulte_path, audio, sr)
else:
continue
@@ -134,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
df = pandas.DataFrame(metadata)
df = df.sample(frac=1)
num_val_samples = int(len(df)*eval_percentage)
num_val_samples = int(len(df) * eval_percentage)
df_eval = df[:num_val_samples]
df_train = df[num_val_samples:]
df_train = df_train.sort_values('audio_file')
df_train = df_train.sort_values("audio_file")
train_metadata_path = os.path.join(out_path, "metadata_train.csv")
df_train.to_csv(train_metadata_path, sep="|", index=False)
eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
df_eval = df_eval.sort_values('audio_file')
df_eval = df_eval.sort_values("audio_file")
df_eval.to_csv(eval_metadata_path, sep="|", index=False)
# deallocate VRAM and RAM
del asr_model, df_train, df_eval, df, metadata
gc.collect()
return train_metadata_path, eval_metadata_path, audio_total_size
return train_metadata_path, eval_metadata_path, audio_total_size
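
The segment-start rule reformatted in this file is worth isolating: the start is pulled back by buffer seconds, but never past the midpoint of the gap between the previous word's end and the sentence start. A sketch with made-up timestamps:

def segment_start(sentence_start, previous_word_end, buffer=0.2):
    # back off by `buffer`, but never cross the midpoint of the preceding gap
    return max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)

print(segment_start(10.0, 9.9))  # 9.95 -> tight gap, midpoint wins
print(segment_start(10.0, 8.0))  # 9.8  -> wide gap, full buffer wins
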

View File

@@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
BATCH_SIZE = batch_size # set here the batch size
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
# Define here the dataset that you want to use for fine-tuning.
config_dataset = BaseDatasetConfig(
formatter="coqui",
@@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
# DVAE files
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
@@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
# download DVAE files if needed
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
print(" > Downloading DVAE files!")
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
ModelManager._download_model_files(
[MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
)
# Download XTTS v2.0 checkpoint if needed
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
@@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
# get the longest text audio file to use as speaker reference
samples_len = [len(item["text"].split(" ")) for item in train_samples]
longest_text_idx = samples_len.index(max(samples_len))
longest_text_idx = samples_len.index(max(samples_len))
speaker_ref = train_samples[longest_text_idx]["audio_file"]
trainer_out_path = trainer.output_path
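
The speaker-reference pick above reduces to "longest transcript wins"; a sketch with toy samples:

train_samples = [  # toy data standing in for the loaded training samples
    {"text": "short line", "audio_file": "a.wav"},
    {"text": "a much longer transcript with many more words", "audio_file": "b.wav"},
]
samples_len = [len(item["text"].split(" ")) for item in train_samples]
longest_text_idx = samples_len.index(max(samples_len))
print(train_samples[longest_text_idx]["audio_file"])  # b.wav
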

View File

@@ -20,7 +20,10 @@ def clear_gpu_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
XTTS_MODEL = None
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
global XTTS_MODEL
clear_gpu_cache()
@@ -37,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
print("Model Loaded!")
return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file):
if XTTS_MODEL is None or not speaker_audio_file:
return "You need to run the previous step to load the model !!", None, None
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
audio_path=speaker_audio_file,
gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
max_ref_length=XTTS_MODEL.config.max_ref_len,
sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
)
out = XTTS_MODEL.inference(
text=tts_text,
language=lang,
gpt_cond_latent=gpt_cond_latent,
speaker_embedding=speaker_embedding,
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
length_penalty=XTTS_MODEL.config.length_penalty,
repetition_penalty=XTTS_MODEL.config.repetition_penalty,
top_k=XTTS_MODEL.config.top_k,
@@ -62,8 +71,6 @@ def run_tts(lang, tts_text, speaker_audio_file):
return "Speech generated !", out_path, speaker_audio_file
# define a logger to redirect
class Logger:
def __init__(self, filename="log.out"):
@@ -82,6 +89,7 @@ class Logger:
def isatty(self):
return False
# redirect stdout and stderr to a file
sys.stdout = Logger()
sys.stderr = sys.stdout
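
Only __init__ and isatty of Logger appear in this diff; a minimal self-contained sketch of the redirect pattern (the write and flush bodies are assumptions, consistent with read_logs below re-opening sys.stdout.log_file):

import sys

class Logger:
    def __init__(self, filename="log.out"):
        self.log_file = filename
        self.terminal = sys.stdout

    def write(self, message):
        # assumed body: tee to the real terminal and append to the log file
        self.terminal.write(message)
        with open(self.log_file, "a") as f:
            f.write(message)

    def flush(self):
        self.terminal.flush()

    def isatty(self):
        return False

sys.stdout = Logger()
sys.stderr = sys.stdout
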
@@ -90,13 +98,10 @@ sys.stderr = sys.stdout
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.StreamHandler(sys.stdout)
]
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
)
def read_logs():
sys.stdout.flush()
with open(sys.stdout.log_file, "r") as f:
@@ -104,7 +109,6 @@ def read_logs():
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="""XTTS fine-tuning demo\n\n"""
"""
@@ -187,12 +191,10 @@ if __name__ == "__main__":
"zh",
"hu",
"ko",
"ja"
"ja",
],
)
progress_data = gr.Label(
label="Progress:"
)
progress_data = gr.Label(label="Progress:")
logs = gr.Textbox(
label="Logs:",
interactive=False,
@@ -200,20 +202,30 @@ if __name__ == "__main__":
demo.load(read_logs, None, logs, every=1)
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")
def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
clear_gpu_cache()
out_path = os.path.join(out_path, "dataset")
os.makedirs(out_path, exist_ok=True)
if audio_path is None:
return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
return (
"You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
"",
"",
)
else:
try:
train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
train_meta, eval_meta, audio_total_size = format_audio_list(
audio_path, target_language=language, out_path=out_path, gradio_progress=progress
)
except:
traceback.print_exc()
error = traceback.format_exc()
return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
return (
f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
"",
"",
)
clear_gpu_cache()
@@ -233,7 +245,7 @@ if __name__ == "__main__":
eval_csv = gr.Textbox(
label="Eval CSV:",
)
num_epochs = gr.Slider(
num_epochs = gr.Slider(
label="Number of epochs:",
minimum=1,
maximum=100,
@@ -261,9 +273,7 @@ if __name__ == "__main__":
step=1,
value=args.max_audio_length,
)
progress_train = gr.Label(
label="Progress:"
)
progress_train = gr.Label(label="Progress:")
logs_tts_train = gr.Textbox(
label="Logs:",
interactive=False,
@@ -271,18 +281,41 @@ if __name__ == "__main__":
demo.load(read_logs, None, logs_tts_train, every=1)
train_btn = gr.Button(value="Step 2 - Run the training")
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
def train_model(
language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
):
clear_gpu_cache()
if not train_csv or not eval_csv:
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
return (
"You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
"",
"",
"",
"",
)
try:
# convert seconds to waveform frames
max_audio_length = int(max_audio_length * 22050)
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
language,
num_epochs,
batch_size,
grad_acumm,
train_csv,
eval_csv,
output_path=output_path,
max_audio_length=max_audio_length,
)
except:
traceback.print_exc()
error = traceback.format_exc()
return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
return (
f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
"",
"",
"",
"",
)
# copy original files to avoid parameter-change issues
os.system(f"cp {config_path} {exp_path}")
@@ -309,9 +342,7 @@ if __name__ == "__main__":
label="XTTS vocab path:",
value="",
)
progress_load = gr.Label(
label="Progress:"
)
progress_load = gr.Label(label="Progress:")
load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
with gr.Column() as col2:
@@ -339,7 +370,7 @@ if __name__ == "__main__":
"hu",
"ko",
"ja",
]
],
)
tts_text = gr.Textbox(
label="Input Text.",
@@ -348,9 +379,7 @@ if __name__ == "__main__":
tts_btn = gr.Button(value="Step 4 - Inference")
with gr.Column() as col3:
progress_gen = gr.Label(
label="Progress:"
)
progress_gen = gr.Label(label="Progress:")
tts_output_audio = gr.Audio(label="Generated Audio.")
reference_audio = gr.Audio(label="Reference audio used.")
@@ -368,7 +397,6 @@ if __name__ == "__main__":
],
)
train_btn.click(
fn=train_model,
inputs=[
@@ -383,14 +411,10 @@ if __name__ == "__main__":
],
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
)
load_btn.click(
fn=load_model,
inputs=[
xtts_checkpoint,
xtts_config,
xtts_vocab
],
inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
outputs=[progress_load],
)
@@ -404,9 +428,4 @@ if __name__ == "__main__":
outputs=[progress_gen, tts_output_audio, reference_audio],
)
demo.launch(
share=True,
debug=False,
server_port=args.port,
server_name="0.0.0.0"
)
demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
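
As a usage note, the collapsed launch call binds to all interfaces and requests a public share link; a standalone sketch (the port value is illustrative):

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("placeholder UI")

# 0.0.0.0 exposes the app on every interface; share=True also requests
# a temporary public gradio.live URL.
demo.launch(share=True, debug=False, server_port=5003, server_name="0.0.0.0")
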

View File

@@ -390,7 +390,7 @@ class GPTTrainer(BaseTTS):
loader = DataLoader(
dataset,
sampler=sampler,
batch_size = config.eval_batch_size if is_eval else config.batch_size,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
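
The hunk above only drops the spaces around `=`, but the surrounding pattern is a common one; a sketch of a loader whose batch size and worker count flip on is_eval (config field names mirror the trainer's):

from torch.utils.data import DataLoader

def make_loader(dataset, config, is_eval, sampler=None):
    # one constructor, with eval/train knobs chosen inline as in GPTTrainer
    return DataLoader(
        dataset,
        sampler=sampler,
        batch_size=config.eval_batch_size if is_eval else config.batch_size,
        num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
        pin_memory=False,
    )
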

View File

@@ -1,34 +1,35 @@
import torch
class SpeakerManager():
class SpeakerManager:
def __init__(self, speaker_file_path=None):
self.speakers = torch.load(speaker_file_path)
@property
def name_to_id(self):
return self.speakers.keys()
@property
def num_speakers(self):
return len(self.name_to_id)
@property
def speaker_names(self):
return list(self.name_to_id.keys())
class LanguageManager():
class LanguageManager:
def __init__(self, config):
self.langs = config["languages"]
@property
def name_to_id(self):
return self.langs
@property
def num_languages(self):
return len(self.name_to_id)
@property
def language_names(self):
return list(self.name_to_id)
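
A usage sketch for the two managers (the speakers file path is illustrative; the loaded file is expected to map speaker names to their stored embeddings):

from TTS.tts.layers.xtts.xtts_manager import SpeakerManager

manager = SpeakerManager("speakers_xtts.pth")  # illustrative path
print(manager.num_speakers)   # number of entries in the file
print(manager.speaker_names)  # keys of the stored mapping
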

View File

@@ -299,7 +299,7 @@ class ForwardTTS(BaseTTS):
if config.use_d_vector_file:
self.embedded_speaker_dim = config.d_vector_dim
if self.args.d_vector_dim != self.args.hidden_channels:
#self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
# self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
@@ -404,13 +404,13 @@ class ForwardTTS(BaseTTS):
# [B, T, C]
x_emb = self.emb(x)
# encoder pass
#o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
# o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
# speaker conditioning
# TODO: try different ways of conditioning
if g is not None:
if g is not None:
if hasattr(self, "proj_g"):
g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
o_en = o_en + g
return o_en, x_mask, g, x_emb
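
Shape-wise, the conditioning step above projects the d-vector to the encoder width and broadcast-adds it over time; a sketch with illustrative dimensions:

import torch
import torch.nn as nn

B, d_vector_dim, hidden_channels, T = 2, 512, 192, 50  # illustrative sizes
proj_g = nn.Linear(in_features=d_vector_dim, out_features=hidden_channels)

g = torch.randn(B, d_vector_dim, 1)        # speaker d-vector
o_en = torch.randn(B, hidden_channels, T)  # encoder output [B, C, T]

g = proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)  # -> [B, hidden_channels, 1]
o_en = o_en + g  # broadcasts over the time axis
print(o_en.shape)  # torch.Size([2, 192, 50])
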

View File

@@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec
@@ -410,12 +410,14 @@ class Xtts(BaseTTS):
if speaker_id is not None:
gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
settings.update({
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
})
settings.update(
{
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
}
)
return self.full_inference(text, speaker_wav, language, **settings)
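
The reflowed settings.update above just folds the config's conditioning values into whatever generation settings the caller passed; a dict-level sketch (the numeric values are illustrative, not necessarily the shipped defaults):

settings = {"temperature": 0.65}  # caller-supplied generation kwargs (illustrative)
settings.update(
    {
        "gpt_cond_len": 30,  # illustrative config values
        "gpt_cond_chunk_len": 4,
        "max_ref_len": 10,
        "sound_norm_refs": False,
    }
)
print(settings)  # the caller's and the config's keys, ready for full_inference
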
@torch.inference_mode()

View File

@@ -335,7 +335,7 @@ class Synthesizer(nn.Module):
# handle multi-lingual
language_id = None
if self.tts_languages_file or (
hasattr(self.tts_model, "language_manager")
hasattr(self.tts_model, "language_manager")
and self.tts_model.language_manager is not None
and not self.tts_config.model == "xtts"
):