mirror of https://github.com/coqui-ai/TTS.git

Run `make style`

This commit is contained in:
parent bd172dabbf
commit d6ea806469
@@ -168,9 +168,7 @@ class TTS(nn.Module):
         self.synthesizer = None
         self.model_name = model_name

-        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(
-            model_name
-        )
+        model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)

         # init synthesizer
         # None values are fetch from the model
@@ -224,7 +224,7 @@ def main():
         const=True,
         default=False,
     )
-
+
     # args for multi-speaker synthesis
     parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
     parser.add_argument("--language_ids_file_path", type=str, help="JSON file for multi-lingual model.", default=None)
@@ -17,9 +17,12 @@ def read_json_with_comments(json_path):
     with fsspec.open(json_path, "r", encoding="utf-8") as f:
         input_str = f.read()
     # handle comments but not urls with //
-    input_str = re.sub(r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str)
+    input_str = re.sub(
+        r"(\"(?:[^\"\\]|\\.)*\")|(/\*(?:.|[\\n\\r])*?\*/)|(//.*)", lambda m: m.group(1) or m.group(2) or "", input_str
+    )
     return json.loads(input_str)


 def register_config(model_name: str) -> Coqpit:
     """Find the right config for the given model name.

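For readers puzzling over that pattern: group 1 matches a quoted JSON string (so a `//` inside a URL value is left alone), while groups 2 and 3 match block and line comments. Below is a small standalone sketch of the same idea, simplified so that both comment styles are dropped, which is slightly stricter than the snippet above:

import json
import re

# group 1 = quoted string (kept); group 2 = /* ... */ comment and group 3 = // comment (both dropped)
COMMENT_RE = re.compile(r'("(?:[^"\\]|\\.)*")|(/\*(?:.|\n|\r)*?\*/)|(//.*)')

def strip_json_comments(text):
    return COMMENT_RE.sub(lambda m: m.group(1) or "", text)

raw = '{"url": "https://example.com/model.pth", /* block */ "workers": 4 // trailing\n}'
print(json.loads(strip_json_comments(raw)))
# {'url': 'https://example.com/model.pth', 'workers': 4}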
@@ -19,9 +19,10 @@ def list_audios(basePath, contains=None):
     # return the set of files that are valid
     return list_files(basePath, validExts=audio_types, contains=contains)

+
 def list_files(basePath, validExts=None, contains=None):
     # loop over the directory structure
-    for (rootDir, dirNames, filenames) in os.walk(basePath):
+    for rootDir, dirNames, filenames in os.walk(basePath):
         # loop over the filenames in the current directory
         for filename in filenames:
             # if the contains string is not none and the filename does not contain
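For context, `list_files` is a generator that walks a directory tree and yields paths whose extension is in a tuple of valid extensions; `str.endswith` accepting a tuple is what makes the one-line check work. A minimal sketch along the same lines (the folder name and extension list are illustrative, not taken from the repository):

import os

AUDIO_EXTS = (".wav", ".mp3", ".flac")

def iter_audio_files(base_path, valid_exts=AUDIO_EXTS):
    # walk the tree and yield only files whose extension matches
    for root_dir, _dir_names, filenames in os.walk(base_path):
        for filename in filenames:
            ext = filename[filename.rfind(".") :].lower()
            if ext.endswith(valid_exts):
                yield os.path.join(root_dir, filename)

for path in iter_audio_files("./my_dataset"):  # hypothetical folder
    print(path)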
@@ -30,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None):
                 continue

             # determine the file extension of the current file
-            ext = filename[filename.rfind("."):].lower()
+            ext = filename[filename.rfind(".") :].lower()

             # check to see if the file is an audio and should be processed
             if validExts is None or ext.endswith(validExts):
@@ -38,7 +39,16 @@ def list_files(basePath, validExts=None, contains=None):
                 audioPath = os.path.join(rootDir, filename)
                 yield audioPath

-def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
+
+def format_audio_list(
+    audio_files,
+    target_language="en",
+    out_path=None,
+    buffer=0.2,
+    eval_percentage=0.15,
+    speaker_name="coqui",
+    gradio_progress=None,
+):
     audio_total_size = 0
     # make sure that ooutput file exists
     os.makedirs(out_path, exist_ok=True)
@@ -63,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
             wav = torch.mean(wav, dim=0, keepdim=True)

         wav = wav.squeeze()
-        audio_total_size += (wav.size(-1) / sr)
+        audio_total_size += wav.size(-1) / sr

         segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
         segments = list(segments)
@@ -88,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
                     # get previous sentence end
                     previous_word_end = words_list[word_idx - 1].end
                     # add buffer or get the silence midle between the previous sentence and the current one
-                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
+                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)

                 sentence = word.word
                 first_word = False
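The boundary rule being reformatted here is easier to see in isolation: a sentence is started `buffer` seconds before its first word, but never earlier than the midpoint of the silence separating it from the previous word. A tiny sketch with invented timestamps:

buffer = 0.2  # seconds of padding, as in the function signature above

# invented word timings: previous word ends at 3.5 s, the new sentence's first word starts at 3.6 s
previous_word_end = 3.5
sentence_start = 3.6

# pad by `buffer`, but never reach past the midpoint of the gap to the previous word
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
print(sentence_start)  # 3.55: the 0.1 s gap is split in half instead of using the full 0.2 s buffer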
@@ -112,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0

                 # Average the current word end and next word start
                 word_end = min((word.end + next_word_start) / 2, word.end + buffer)

                 absoulte_path = os.path.join(out_path, audio_file)
                 os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
                 i += 1
                 first_word = True

-                audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
+                audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
                 # if the audio is too short ignore it (i.e < 0.33 seconds)
-                if audio.size(-1) >= sr/3:
-                    torchaudio.save(absoulte_path,
-                        audio,
-                        sr
-                    )
+                if audio.size(-1) >= sr / 3:
+                    torchaudio.save(absoulte_path, audio, sr)
                 else:
                     continue
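The slice above turns second-based boundaries into sample indices before saving the clip. A toy version of the same arithmetic (sample rate and timings invented for illustration):

import torch

sr = 22050                  # illustrative sample rate
wav = torch.zeros(10 * sr)  # 10 seconds of silent mono audio
sentence_start, word_end = 1.25, 3.80

audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)  # shape [1, n_samples]
print(audio.size(-1) / sr)       # ~2.55 seconds
print(audio.size(-1) >= sr / 3)  # True, so it would pass the 0.33 s minimum used above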
@@ -134,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0

     df = pandas.DataFrame(metadata)
     df = df.sample(frac=1)
-    num_val_samples = int(len(df)*eval_percentage)
+    num_val_samples = int(len(df) * eval_percentage)

     df_eval = df[:num_val_samples]
     df_train = df[num_val_samples:]

-    df_train = df_train.sort_values('audio_file')
+    df_train = df_train.sort_values("audio_file")
     train_metadata_path = os.path.join(out_path, "metadata_train.csv")
     df_train.to_csv(train_metadata_path, sep="|", index=False)

     eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
-    df_eval = df_eval.sort_values('audio_file')
+    df_eval = df_eval.sort_values("audio_file")
     df_eval.to_csv(eval_metadata_path, sep="|", index=False)

     # deallocate VRAM and RAM
     del asr_model, df_train, df_eval, df, metadata
     gc.collect()

-    return train_metadata_path, eval_metadata_path, audio_total_size
+    return train_metadata_path, eval_metadata_path, audio_total_size
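To make the split logic concrete: the metadata rows are shuffled, a fraction is held out for evaluation, and both splits are written as pipe-separated CSV files. A tiny self-contained sketch with dummy rows (only the `audio_file` column name is visible in the hunk above; the other columns are assumed):

import pandas

eval_percentage = 0.15
metadata = {
    "audio_file": ["wavs/a.wav", "wavs/b.wav", "wavs/c.wav", "wavs/d.wav"],
    "text": ["first clip", "second clip", "third clip", "fourth clip"],
    "speaker_name": ["coqui"] * 4,
}

df = pandas.DataFrame(metadata)
df = df.sample(frac=1)                            # shuffle the rows
num_val_samples = int(len(df) * eval_percentage)  # ~15% of the rows; 0 for this toy set

df_eval = df[:num_val_samples]
df_train = df[num_val_samples:]
df_train.sort_values("audio_file").to_csv("metadata_train.csv", sep="|", index=False)
df_eval.sort_values("audio_file").to_csv("metadata_eval.csv", sep="|", index=False)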
@@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     BATCH_SIZE = batch_size  # set here the batch size
     GRAD_ACUMM_STEPS = grad_acumm  # set here the grad accumulation steps

-
     # Define here the dataset that you want to use for the fine-tuning on.
     config_dataset = BaseDatasetConfig(
         formatter="coqui",
@@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
     os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)

-
     # DVAE files
     DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
     MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
@@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
     # download DVAE files if needed
     if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
         print(" > Downloading DVAE files!")
-        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
-
+        ModelManager._download_model_files(
+            [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
+        )

     # Download XTTS v2.0 checkpoint if needed
     TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
@@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,

     # get the longest text audio file to use as speaker reference
     samples_len = [len(item["text"].split(" ")) for item in train_samples]
-    longest_text_idx = samples_len.index(max(samples_len))
+    longest_text_idx = samples_len.index(max(samples_len))
     speaker_ref = train_samples[longest_text_idx]["audio_file"]

     trainer_out_path = trainer.output_path
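The context lines above pick the training sample with the most words and use its audio as the speaker reference. The same selection can be written with `max(..., key=...)`; a sketch with dummy samples:

train_samples = [
    {"text": "short one", "audio_file": "wavs/a.wav"},
    {"text": "this is the longest transcription in the toy set", "audio_file": "wavs/b.wav"},
    {"text": "medium length sample here", "audio_file": "wavs/c.wav"},
]

# equivalent to building samples_len and taking samples_len.index(max(samples_len))
speaker_ref = max(train_samples, key=lambda item: len(item["text"].split(" ")))["audio_file"]
print(speaker_ref)  # wavs/b.wav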
@@ -20,7 +20,10 @@ def clear_gpu_cache():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+
+
 XTTS_MODEL = None

+
 def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
     global XTTS_MODEL
     clear_gpu_cache()
@@ -37,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
     print("Model Loaded!")
     return "Model Loaded!"

+
 def run_tts(lang, tts_text, speaker_audio_file):
     if XTTS_MODEL is None or not speaker_audio_file:
         return "You need to run the previous step to load the model !!", None, None

-    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
+    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
+        audio_path=speaker_audio_file,
+        gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
+        max_ref_length=XTTS_MODEL.config.max_ref_len,
+        sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
+    )
     out = XTTS_MODEL.inference(
         text=tts_text,
         language=lang,
         gpt_cond_latent=gpt_cond_latent,
         speaker_embedding=speaker_embedding,
-        temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
+        temperature=XTTS_MODEL.config.temperature,  # Add custom parameters here
         length_penalty=XTTS_MODEL.config.length_penalty,
         repetition_penalty=XTTS_MODEL.config.repetition_penalty,
         top_k=XTTS_MODEL.config.top_k,
@@ -62,8 +71,6 @@ def run_tts(lang, tts_text, speaker_audio_file):
     return "Speech generated !", out_path, speaker_audio_file


-
-
 # define a logger to redirect
 class Logger:
     def __init__(self, filename="log.out"):
@@ -82,6 +89,7 @@ class Logger:
     def isatty(self):
         return False

+
 # redirect stdout and stderr to a file
 sys.stdout = Logger()
 sys.stderr = sys.stdout
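Only fragments of the `Logger` class appear in this diff. The idea is a stdout replacement that appends everything to a log file so the Gradio UI can poll it through `read_logs`. A self-contained sketch of that pattern (not the demo's exact implementation; the names are illustrative):

import sys

class TeeLogger:
    """Send stdout both to the terminal and to a log file."""

    def __init__(self, filename="log.out"):
        self.log_file = filename
        self.terminal = sys.stdout

    def write(self, message):
        self.terminal.write(message)
        with open(self.log_file, "a") as f:
            f.write(message)

    def flush(self):
        self.terminal.flush()

    def isatty(self):
        # some libraries check this before printing progress bars
        return False

sys.stdout = TeeLogger()
print("hello")  # shows up on the terminal and is appended to log.out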
@@ -90,13 +98,10 @@ sys.stderr = sys.stdout
 # logging.basicConfig(stream=sys.stdout, level=logging.INFO)

 logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s [%(levelname)s] %(message)s",
-    handlers=[
-        logging.StreamHandler(sys.stdout)
-    ]
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)]
 )

+
 def read_logs():
     sys.stdout.flush()
     with open(sys.stdout.log_file, "r") as f:
@@ -104,7 +109,6 @@ def read_logs():


 if __name__ == "__main__":
-
     parser = argparse.ArgumentParser(
         description="""XTTS fine-tuning demo\n\n"""
         """
@@ -187,12 +191,10 @@ if __name__ == "__main__":
                     "zh",
                     "hu",
                     "ko",
-                    "ja"
+                    "ja",
                 ],
             )
-            progress_data = gr.Label(
-                label="Progress:"
-            )
+            progress_data = gr.Label(label="Progress:")
             logs = gr.Textbox(
                 label="Logs:",
                 interactive=False,
@@ -200,20 +202,30 @@ if __name__ == "__main__":
             demo.load(read_logs, None, logs, every=1)

             prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")

             def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
                 clear_gpu_cache()
                 out_path = os.path.join(out_path, "dataset")
                 os.makedirs(out_path, exist_ok=True)
                 if audio_path is None:
-                    return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""
+                    return (
+                        "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!",
+                        "",
+                        "",
+                    )
                 else:
                     try:
-                        train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
+                        train_meta, eval_meta, audio_total_size = format_audio_list(
+                            audio_path, target_language=language, out_path=out_path, gradio_progress=progress
+                        )
                     except:
                         traceback.print_exc()
                         error = traceback.format_exc()
-                        return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
+                        return (
+                            f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}",
+                            "",
+                            "",
+                        )

                 clear_gpu_cache()
@@ -233,7 +245,7 @@ if __name__ == "__main__":
             eval_csv = gr.Textbox(
                 label="Eval CSV:",
             )
-            num_epochs = gr.Slider(
+            num_epochs = gr.Slider(
                 label="Number of epochs:",
                 minimum=1,
                 maximum=100,
@@ -261,9 +273,7 @@ if __name__ == "__main__":
                 step=1,
                 value=args.max_audio_length,
             )
-            progress_train = gr.Label(
-                label="Progress:"
-            )
+            progress_train = gr.Label(label="Progress:")
             logs_tts_train = gr.Textbox(
                 label="Logs:",
                 interactive=False,
@@ -271,18 +281,41 @@ if __name__ == "__main__":
             demo.load(read_logs, None, logs_tts_train, every=1)
             train_btn = gr.Button(value="Step 2 - Run the training")

-            def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
+            def train_model(
+                language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length
+            ):
                 clear_gpu_cache()
                 if not train_csv or not eval_csv:
-                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
+                    return (
+                        "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !",
+                        "",
+                        "",
+                        "",
+                        "",
+                    )
                 try:
                     # convert seconds to waveform frames
                     max_audio_length = int(max_audio_length * 22050)
-                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
+                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
+                        language,
+                        num_epochs,
+                        batch_size,
+                        grad_acumm,
+                        train_csv,
+                        eval_csv,
+                        output_path=output_path,
+                        max_audio_length=max_audio_length,
+                    )
                 except:
                     traceback.print_exc()
                     error = traceback.format_exc()
-                    return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
+                    return (
+                        f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}",
+                        "",
+                        "",
+                        "",
+                        "",
+                    )

                 # copy original files to avoid parameters changes issues
                 os.system(f"cp {config_path} {exp_path}")
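One detail worth calling out from the hunk above: the UI slider is in seconds, but `train_gpt` receives `max_audio_length` as a number of waveform samples, hence the multiplication by 22050 (a 22.05 kHz rate). A one-line check of the arithmetic:

SAMPLE_RATE = 22050  # the factor used in the demo code above

def seconds_to_samples(seconds, sample_rate=SAMPLE_RATE):
    return int(seconds * sample_rate)

print(seconds_to_samples(11))  # 242550 samples for an 11-second cap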
@@ -309,9 +342,7 @@ if __name__ == "__main__":
                     label="XTTS vocab path:",
                     value="",
                 )
-                progress_load = gr.Label(
-                    label="Progress:"
-                )
+                progress_load = gr.Label(label="Progress:")
                 load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")

             with gr.Column() as col2:
@@ -339,7 +370,7 @@ if __name__ == "__main__":
                         "hu",
                         "ko",
                         "ja",
-                    ]
+                    ],
                 )
                 tts_text = gr.Textbox(
                     label="Input Text.",
@@ -348,9 +379,7 @@ if __name__ == "__main__":
                 tts_btn = gr.Button(value="Step 4 - Inference")

             with gr.Column() as col3:
-                progress_gen = gr.Label(
-                    label="Progress:"
-                )
+                progress_gen = gr.Label(label="Progress:")
                 tts_output_audio = gr.Audio(label="Generated Audio.")
                 reference_audio = gr.Audio(label="Reference audio used.")
@@ -368,7 +397,6 @@ if __name__ == "__main__":
                 ],
             )

-
             train_btn.click(
                 fn=train_model,
                 inputs=[
@@ -383,14 +411,10 @@ if __name__ == "__main__":
                 ],
                 outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
             )


             load_btn.click(
                 fn=load_model,
-                inputs=[
-                    xtts_checkpoint,
-                    xtts_config,
-                    xtts_vocab
-                ],
+                inputs=[xtts_checkpoint, xtts_config, xtts_vocab],
                 outputs=[progress_load],
             )
@@ -404,9 +428,4 @@ if __name__ == "__main__":
                 outputs=[progress_gen, tts_output_audio, reference_audio],
             )

-    demo.launch(
-        share=True,
-        debug=False,
-        server_port=args.port,
-        server_name="0.0.0.0"
-    )
+    demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")
@@ -390,7 +390,7 @@ class GPTTrainer(BaseTTS):
         loader = DataLoader(
             dataset,
             sampler=sampler,
-            batch_size = config.eval_batch_size if is_eval else config.batch_size,
+            batch_size=config.eval_batch_size if is_eval else config.batch_size,
             collate_fn=dataset.collate_fn,
             num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
             pin_memory=False,
@@ -1,34 +1,35 @@
 import torch

-class SpeakerManager():
+
+class SpeakerManager:
     def __init__(self, speaker_file_path=None):
         self.speakers = torch.load(speaker_file_path)

     @property
     def name_to_id(self):
         return self.speakers.keys()

     @property
     def num_speakers(self):
         return len(self.name_to_id)

     @property
     def speaker_names(self):
         return list(self.name_to_id.keys())


-class LanguageManager():
+class LanguageManager:
     def __init__(self, config):
         self.langs = config["languages"]

     @property
     def name_to_id(self):
         return self.langs

     @property
     def num_languages(self):
         return len(self.name_to_id)

     @property
     def language_names(self):
         return list(self.name_to_id)
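These two wrappers simply expose dictionary contents through the manager interface the rest of XTTS expects. A quick usage sketch with an invented config (the import path matches the one that appears later in this diff):

from TTS.tts.layers.xtts.xtts_manager import LanguageManager

config = {"languages": ["en", "es", "de"]}  # minimal invented config
lang_manager = LanguageManager(config)

print(lang_manager.num_languages)   # 3
print(lang_manager.language_names)  # ['en', 'es', 'de']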
@@ -299,7 +299,7 @@ class ForwardTTS(BaseTTS):
         if config.use_d_vector_file:
             self.embedded_speaker_dim = config.d_vector_dim
             if self.args.d_vector_dim != self.args.hidden_channels:
-                #self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
+                # self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
                 self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
         # init speaker embedding layer
         if config.use_speaker_embedding and not config.use_d_vector_file:
@@ -404,13 +404,13 @@ class ForwardTTS(BaseTTS):
         # [B, T, C]
         x_emb = self.emb(x)
         # encoder pass
-        #o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+        # o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
         # speaker conditioning
         # TODO: try different ways of conditioning
-        if g is not None:
+        if g is not None:
             if hasattr(self, "proj_g"):
-                g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
+                g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
             o_en = o_en + g
         return o_en, x_mask, g, x_emb
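To make the shape handling in the speaker-conditioning branch concrete, here is a small self-contained sketch (the dimensions are invented for illustration): a d-vector is projected to the encoder's hidden size and broadcast-added over the time axis, mirroring the `proj_g` lines above.

import torch
import torch.nn as nn

batch, d_vector_dim, hidden_channels, time_steps = 2, 512, 192, 37

proj_g = nn.Linear(in_features=d_vector_dim, out_features=hidden_channels)

g = torch.randn(batch, d_vector_dim, 1)                 # speaker d-vector, shape [B, C_spk, 1]
o_en = torch.randn(batch, hidden_channels, time_steps)  # encoder output, shape [B, C, T]

g = proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)  # project to [B, C, 1]
o_en = o_en + g                                   # broadcast add over the time axis
print(o_en.shape)  # torch.Size([2, 192, 37])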
@@ -11,7 +11,7 @@ from TTS.tts.layers.xtts.gpt import GPT
 from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
 from TTS.tts.layers.xtts.stream_generator import init_stream_support
 from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
-from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
+from TTS.tts.layers.xtts.xtts_manager import LanguageManager, SpeakerManager
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.utils.io import load_fsspec
@@ -410,12 +410,14 @@ class Xtts(BaseTTS):
         if speaker_id is not None:
             gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
             return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
-        settings.update({
-            "gpt_cond_len": config.gpt_cond_len,
-            "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
-            "max_ref_len": config.max_ref_len,
-            "sound_norm_refs": config.sound_norm_refs,
-        })
+        settings.update(
+            {
+                "gpt_cond_len": config.gpt_cond_len,
+                "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
+                "max_ref_len": config.max_ref_len,
+                "sound_norm_refs": config.sound_norm_refs,
+            }
+        )
         return self.full_inference(text, speaker_wav, language, **settings)

     @torch.inference_mode()
@@ -335,7 +335,7 @@ class Synthesizer(nn.Module):
         # handle multi-lingual
         language_id = None
         if self.tts_languages_file or (
-            hasattr(self.tts_model, "language_manager")
+            hasattr(self.tts_model, "language_manager")
             and self.tts_model.language_manager is not None
             and not self.tts_config.model == "xtts"
         ):