Run `make style`

Author: Aarni Koskela
Date:   2023-12-04 10:38:07 +02:00
Parent: bd172dabbf
Commit: d6ea806469

11 changed files with 121 additions and 92 deletions


@@ -168,9 +168,7 @@ class TTS(nn.Module):
         self.synthesizer = None
         self.model_name = model_name
-        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
-            model_name
-        )
+        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
         # init synthesizer
         # None values are fetch from the model
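For orientation, a minimal usage sketch of the high-level TTS class whose constructor is reformatted above; the model name and file paths are illustrative assumptions, not taken from the diff.

from TTS.api import TTS

# Hypothetical example: load a named model and synthesize speech to a file.
tts = TTS(model_name="tts_models/multilingual/multi-dataset/xtts_v2")
tts.tts_to_file(
    text="Hello world!",
    file_path="output.wav",
    speaker_wav="reference.wav",  # short reference clip for voice-cloning models
    language="en",
)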


@@ -224,7 +224,7 @@ def main():
         const=True,
         default=False,
     )
     # args for multi-speaker synthesis
     parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
     parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)


@@ -17,9 +17,12 @@ def read_json_with_comments(json_path):
     with fsspec.open(json_path, "r", encoding="utf-8") as f:
         input_str = f.read()
     # handle comments but not urls with //
-    input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
+    input_str = re.sub(
+        r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
+    )
     return json.loads(input_str)
 def register_config(model_name: str) -> Coqpit:
     """Find the right config for the given model name.


@@ -19,9 +19,10 @@ def list_audios(basePath, contains=None):
     # return the set of files that are valid
     return list_files(basePath, validExts=audio_types, contains=contains)
 def list_files(basePath, validExts=None, contains=None):
     # loop over the directory structure
-    for (rootDir, dirNames, filenames) in os.walk(basePath):
+    for rootDir, dirNames, filenames in os.walk(basePath):
         # loop over the filenames in the current directory
         for filename in filenames:
             # if the contains string is not none and the filename does not contain
@@ -30,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None):
                 continue
             # determine the file extension of the current file
-            ext = filename[filename.rfind("."):].lower()
+            ext = filename[filename.rfind(".") :].lower()
             # check to see if the file is an audio and should be processed
             if validExts is None or ext.endswith(validExts):
@@ -38,7 +39,16 @@ def list_files(basePath, validExts=None, contains=None):
                 audioPath = os.path.join(rootDir, filename)
                 yield audioPath
-def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+def format_audio_list(
+    audio_files,
+    target_language="en",
+    out_path=None,
+    buffer=0.2,
+    eval_percentage=0.15,
+    speaker_name="coqui",
+    gradio_progress=None,
+):
     audio_total_size = 0
     # make sure that ooutput file exists
     os.makedirs(out_path, exist_ok=True)
@@ -63,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
         wav = torch.mean(wav, dim=0, keepdim=True)
         wav = wav.squeeze()
-        audio_total_size += (wav.size(-1) / sr)
+        audio_total_size += wav.size(-1) / sr
         segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
         segments = list(segments)
@@ -88,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                     # get previous sentence end
                     previous_word_end = words_list[word_idx - 1].end
                     # add buffer or get the silence midle between the previous sentence and the current one
-                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
+                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
                 sentence = word.word
                 first_word = False
@@ -112,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                 # Average the current word end and next word start
                 word_end = min((word.end + next_word_start) / 2, word.end + buffer)
                 absoulte_path = os.path.join(out_path, audio_file)
                 os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
                 i += 1
                 first_word = True
-                audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
+                audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
                 # if the audio is too short ignore it (i.e < 0.33 seconds)
-                if audio.size(-1) >= sr/3:
-                    torchaudio.save(absoulte_path,
-                        audio,
-                        sr
-                    )
+                if audio.size(-1) >= sr / 3:
+                    torchaudio.save(absoulte_path, audio, sr)
                 else:
                     continue
@@ -134,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
     df = pandas.DataFrame(metadata)
     df = df.sample(frac=1)
-    num_val_samples = int(len(df)*eval_percentage)
+    num_val_samples = int(len(df) * eval_percentage)
     df_eval = df[:num_val_samples]
     df_train = df[num_val_samples:]
-    df_train = df_train.sort_values('audio_file')
+    df_train = df_train.sort_values("audio_file")
     train_metadata_path = os.path.join(out_path, "metadata_train.csv")
     df_train.to_csv(train_metadata_path, sep="|", index=False)
     eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
-    df_eval = df_eval.sort_values('audio_file')
+    df_eval = df_eval.sort_values("audio_file")
     df_eval.to_csv(eval_metadata_path, sep="|", index=False)
     # deallocate VRAM and RAM
     del asr_model, df_train, df_eval, df, metadata
     gc.collect()
     return train_metadata_path, eval_metadata_path, audio_total_size
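For illustration (not part of the commit), a small sketch of the seconds-to-samples arithmetic used above when cutting a sentence out of the waveform; the sample rate and timestamps are made-up values.

import torch

sr = 22050                              # assumed sample rate in Hz
wav = torch.zeros(10 * sr)              # stand-in for a 10 s mono waveform
sentence_start, word_end = 1.25, 3.80   # hypothetical word timestamps in seconds

# Convert seconds to sample indices and slice, as format_audio_list does.
audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
print(audio.size(-1) / sr)              # clip duration in seconds (~2.55)
print(audio.size(-1) >= sr / 3)         # True: longer than 1/3 s, so it would be saved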


@@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     BATCH_SIZE = batch_size  # set here the batch size
     GRAD_ACUMM_STEPS = grad_acumm  # set here the grad accumulation steps
     # Define here the dataset that you want to use for the fine-tuning on.
     config_dataset = BaseDatasetConfig(
         formatter="coqui",
@@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
     os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
     # DVAE files
     DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
     MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
@@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     # download DVAE files if needed
     if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
         print(" > Downloading DVAE files!")
-        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
+        ModelManager._download_model_files(
+            [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+        )
     # Download XTTS v2.0 checkpoint if needed
     TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
@@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     # get the longest text audio file to use as speaker reference
     samples_len = [len(item["text"].split(" ")) for item in train_samples]
     longest_text_idx = samples_len.index(max(samples_len))
     speaker_ref = train_samples[longest_text_idx]["audio_file"]
     trainer_out_path = trainer.output_path
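For illustration (not part of the commit), the speaker-reference selection above picks the training sample with the most words; with made-up samples:

# Hypothetical samples, for illustration only.
train_samples = [
    {"text": "hello there", "audio_file": "wavs/clip_a.wav"},
    {"text": "a much longer sentence that will be chosen as the reference", "audio_file": "wavs/clip_b.wav"},
]
samples_len = [len(item["text"].split(" ")) for item in train_samples]
longest_text_idx = samples_len.index(max(samples_len))
speaker_ref = train_samples[longest_text_idx]["audio_file"]
print(speaker_ref)  # wavs/clip_b.wav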


@@ -20,7 +20,10 @@ def clear_gpu_cache():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 XTTS_MODEL = None
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
     global XTTS_MODEL
     clear_gpu_cache()
@@ -37,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
     print("Model Loaded!")
     return "Model Loaded!"
 def run_tts(lang, tts_text, speaker_audio_file):
     if XTTS_MODEL is None or not speaker_audio_file:
         return "You need to run the previous step to load the model !!", None, None
-    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
+    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+        audio_path=speaker_audio_file,
+        gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+        max_ref_length=XTTS_MODEL.config.max_ref_len,
+        sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+    )
     out = XTTS_MODEL.inference(
         text=tts_text,
         language=lang,
         gpt_cond_latent=gpt_cond_latent,
         speaker_embedding=speaker_embedding,
         temperature=XTTS_MODEL.config.temperature,  # Add custom parameters here
         length_penalty=XTTS_MODEL.config.length_penalty,
         repetition_penalty=XTTS_MODEL.config.repetition_penalty,
         top_k=XTTS_MODEL.config.top_k,
@@ -62,8 +71,6 @@ def run_tts(lang, tts_text, speaker_audio_file):
     return "Speech generated !", out_path, speaker_audio_file
 # define a logger to redirect
 class Logger:
     def __init__(self, filename="log.out"):
@@ -82,6 +89,7 @@ class Logger:
     def isatty(self):
         return False
 # redirect stdout and stderr to a file
 sys.stdout = Logger()
 sys.stderr = sys.stdout
@@ -90,13 +98,10 @@ sys.stderr = sys.stdout
 # logging.basicConfig(stream=sys.stdout, level=logging.INFO)
 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s",
-    handlers=[
-        logging.StreamHandler(sys.stdout)
-    ]
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
 )
 def read_logs():
     sys.stdout.flush()
     with open(sys.stdout.log_file, "r") as f:
@@ -104,7 +109,6 @@ def read_logs():
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="""XTTS fine-tuning demo\n\n"""
         """
@@ -187,12 +191,10 @@ if __name__ == "__main__":
                     "zh",
                     "hu",
                     "ko",
-                    "ja"
+                    "ja",
                 ],
             )
-            progress_data = gr.Label(
-                label="Progress:"
-            )
+            progress_data = gr.Label(label="Progress:")
             logs = gr.Textbox(
                 label="Logs:",
                 interactive=False,
@@ -200,20 +202,30 @@ if __name__ == "__main__":
             demo.load(read_logs, None, logs, every=1)
             prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")
             def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
                 clear_gpu_cache()
                 out_path = os.path.join(out_path, "dataset")
                 os.makedirs(out_path, exist_ok=True)
                 if audio_path is None:
-                    return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
+                    return (
+                        "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
+                        "",
+                        "",
+                    )
                 else:
                     try:
-                        train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+                        train_meta, eval_meta, audio_total_size = format_audio_list(
+                            audio_path, target_language=language, out_path=out_path, gradio_progress=progress
+                        )
                     except:
                         traceback.print_exc()
                         error = traceback.format_exc()
-                        return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
+                        return (
+                            f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
+                            "",
+                            "",
+                        )
                 clear_gpu_cache()
@@ -233,7 +245,7 @@ if __name__ == "__main__":
             eval_csv = gr.Textbox(
                 label="Eval CSV:",
             )
             num_epochs = gr.Slider(
                 label="Number of epochs:",
                 minimum=1,
                 maximum=100,
@@ -261,9 +273,7 @@ if __name__ == "__main__":
                 step=1,
                 value=args.max_audio_length,
             )
-            progress_train = gr.Label(
-                label="Progress:"
-            )
+            progress_train = gr.Label(label="Progress:")
             logs_tts_train = gr.Textbox(
                 label="Logs:",
                 interactive=False,
@@ -271,18 +281,41 @@ if __name__ == "__main__":
             demo.load(read_logs, None, logs_tts_train, every=1)
             train_btn = gr.Button(value="Step 2 - Run the training")
-            def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
+            def train_model(
+                language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
+            ):
                 clear_gpu_cache()
                 if not train_csv or not eval_csv:
-                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
+                    return (
+                        "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
+                        "",
+                        "",
+                        "",
+                        "",
+                    )
                 try:
                     # convert seconds to waveform frames
                     max_audio_length = int(max_audio_length * 22050)
-                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
+                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
+                        language,
+                        num_epochs,
+                        batch_size,
+                        grad_acumm,
+                        train_csv,
+                        eval_csv,
+                        output_path=output_path,
+                        max_audio_length=max_audio_length,
+                    )
                 except:
                     traceback.print_exc()
                     error = traceback.format_exc()
-                    return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
+                    return (
+                        f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
+                        "",
+                        "",
+                        "",
+                        "",
+                    )
                 # copy original files to avoid parameters changes issues
                 os.system(f"cp {config_path} {exp_path}")
@@ -309,9 +342,7 @@ if __name__ == "__main__":
                         label="XTTS vocab path:",
                         value="",
                     )
-                    progress_load = gr.Label(
-                        label="Progress:"
-                    )
+                    progress_load = gr.Label(label="Progress:")
                     load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
                 with gr.Column() as col2:
@@ -339,7 +370,7 @@ if __name__ == "__main__":
                             "hu",
                             "ko",
                             "ja",
-                        ]
+                        ],
                     )
                     tts_text = gr.Textbox(
                         label="Input Text.",
@@ -348,9 +379,7 @@ if __name__ == "__main__":
                     tts_btn = gr.Button(value="Step 4 - Inference")
                 with gr.Column() as col3:
-                    progress_gen = gr.Label(
-                        label="Progress:"
-                    )
+                    progress_gen = gr.Label(label="Progress:")
                     tts_output_audio = gr.Audio(label="Generated Audio.")
                     reference_audio = gr.Audio(label="Reference audio used.")
@@ -368,7 +397,6 @@ if __name__ == "__main__":
                 ],
             )
             train_btn.click(
                 fn=train_model,
                 inputs=[
@@ -383,14 +411,10 @@ if __name__ == "__main__":
                 ],
                 outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
             )
             load_btn.click(
                 fn=load_model,
-                inputs=[
-                    xtts_checkpoint,
-                    xtts_config,
-                    xtts_vocab
-                ],
+                inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
                 outputs=[progress_load],
             )
@@ -404,9 +428,4 @@ if __name__ == "__main__":
                 outputs=[progress_gen, tts_output_audio, reference_audio],
             )
-    demo.launch(
-        share=True,
-        debug=False,
-        server_port=args.port,
-        server_name="0.0.0.0"
-    )
+    demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
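For illustration (not part of the commit), a self-contained sketch of the Gradio wiring pattern that recurs above: a button click runs a function and writes its return value into a Label. The stub function and values are hypothetical stand-ins for the demo's real callbacks.

import gradio as gr

def load_model_stub(checkpoint_path):
    # stand-in for the demo's load_model(); just echoes its input
    return f"Loaded {checkpoint_path}"

with gr.Blocks() as demo:
    xtts_checkpoint = gr.Textbox(label="XTTS checkpoint path:", value="model.pth")
    progress_load = gr.Label(label="Progress:")
    load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
    load_btn.click(fn=load_model_stub, inputs=[xtts_checkpoint], outputs=[progress_load])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)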


@@ -390,7 +390,7 @@ class GPTTrainer(BaseTTS):
             loader = DataLoader(
                 dataset,
                 sampler=sampler,
-                batch_size = config.eval_batch_size if is_eval else config.batch_size,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                 collate_fn=dataset.collate_fn,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,


@@ -1,34 +1,35 @@
 import torch
-class SpeakerManager():
+class SpeakerManager:
     def __init__(self, speaker_file_path=None):
         self.speakers = torch.load(speaker_file_path)
     @property
     def name_to_id(self):
         return self.speakers.keys()
     @property
     def num_speakers(self):
         return len(self.name_to_id)
     @property
     def speaker_names(self):
         return list(self.name_to_id.keys())
-class LanguageManager():
+class LanguageManager:
     def __init__(self, config):
         self.langs = config["languages"]
     @property
     def name_to_id(self):
         return self.langs
     @property
     def num_languages(self):
         return len(self.name_to_id)
     @property
     def language_names(self):
         return list(self.name_to_id)


@@ -299,7 +299,7 @@ class ForwardTTS(BaseTTS):
         if config.use_d_vector_file:
             self.embedded_speaker_dim = config.d_vector_dim
             if self.args.d_vector_dim != self.args.hidden_channels:
-                #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                # self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
                 self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
@@ -404,13 +404,13 @@ class ForwardTTS(BaseTTS):
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
-        #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        # o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
         # speaker conditioning
         # TODO: try different ways of conditioning
         if g is not None:
             if hasattr(self, "proj_g"):
                 g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
             o_en = o_en + g
         return o_en, x_mask, g, x_emb
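For illustration (not part of the commit), a minimal sketch of the speaker-conditioning step in the second hunk: a d-vector is projected to the encoder channel size and broadcast-added over time. Dimensions are made-up example values.

import torch
import torch.nn as nn

d_vector_dim, hidden_channels, T = 512, 192, 37   # hypothetical sizes
proj_g = nn.Linear(in_features=d_vector_dim, out_features=hidden_channels)

o_en = torch.randn(2, hidden_channels, T)  # encoder output [B, C, T]
g = torch.randn(2, d_vector_dim, 1)        # speaker d-vector [B, D, 1]

g = proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)  # -> [B, hidden_channels, 1]
o_en = o_en + g                                   # broadcast over the time axis
print(o_en.shape)  # torch.Size([2, 192, 37])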


@@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
-from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
+from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
@@ -410,12 +410,14 @@ class Xtts(BaseTTS):
         if speaker_id is not None:
             gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
             return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
-        settings.update({
-            "gpt_cond_len": config.gpt_cond_len,
-            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
-            "max_ref_len": config.max_ref_len,
-            "sound_norm_refs": config.sound_norm_refs,
-        })
+        settings.update(
+            {
+                "gpt_cond_len": config.gpt_cond_len,
+                "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
+                "max_ref_len": config.max_ref_len,
+                "sound_norm_refs": config.sound_norm_refs,
+            }
+        )
         return self.full_inference(text, speaker_wav, language, **settings)
     @torch.inference_mode()


@@ -335,7 +335,7 @@ class Synthesizer(nn.Module):
         # handle multi-lingual
         language_id = None
         if self.tts_languages_file or (
             hasattr(self.tts_model, "language_manager")
             and self.tts_model.language_manager is not None
             and not self.tts_config.model == "xtts"
         ):