mirror of https://github.com/coqui-ai/TTS.git
code styling
parent 25328aad00
commit 47e356cb48
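Note: the hunks below are formatting-only changes: single quotes become double quotes, short call sites are collapsed onto one line, long asserts are wrapped in parentheses, and imports are re-ordered. This is consistent with black/isort-style formatting, though the tool and its line-length setting are not named in this commit view. A standalone illustration of those conventions (not code from this commit):

# Illustration only, not part of this commit: the conventions the hunks apply.
model_name = "tacotron"  # double quotes instead of single quotes
assert (
    "tacotron" in model_name or "glow" in model_name
), f" [!] unexpected model name: {model_name}"  # long asserts wrapped in parentheses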
@@ -7,10 +7,10 @@ from pathlib import Path
 
 from flask import Flask, render_template, request, send_file
 
+from TTS.utils.generic_utils import style_wav_uri_to_dict
 from TTS.utils.io import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
-from TTS.utils.generic_utils import style_wav_uri_to_dict
 
 
 def create_argparser():
@@ -90,11 +90,9 @@ app = Flask(__name__)
 @app.route("/")
 def index():
     return render_template(
-        'index.html',
-        show_details=args.show_details,
-        use_speaker_embedding=use_speaker_embedding,
-        use_gst = use_gst
-    )
+        "index.html", show_details=args.show_details, use_speaker_embedding=use_speaker_embedding, use_gst=use_gst
+    )
 
+
 @app.route("/details")
 def details():
@@ -115,9 +113,9 @@ def details():
 
 @app.route("/api/tts", methods=["GET"])
 def tts():
-    text = request.args.get('text')
-    speaker_json_key = request.args.get('speaker', "")
-    style_wav = request.args.get('style-wav', "")
+    text = request.args.get("text")
+    speaker_json_key = request.args.get("speaker", "")
+    style_wav = request.args.get("style-wav", "")
 
     style_wav = style_wav_uri_to_dict(style_wav)
     print(" > Model input: {}".format(text))
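Note: for reference, a minimal client call against the GET /api/tts route above. Only the query parameter names ("text", "speaker", "style-wav") come from the hunk; the host, port, and use of the requests library are assumptions.

# Hypothetical client for the /api/tts endpoint; host and port are assumed.
import requests

resp = requests.get(
    "http://localhost:5002/api/tts",
    params={"text": "Hello world.", "speaker": "", "style-wav": ""},
)
resp.raise_for_status()
with open("tts_output.wav", "wb") as f:
    f.write(resp.content)  # the server returns the synthesized audio (it imports send_file)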
@@ -66,17 +66,19 @@ def compute_style_mel(style_wav, ap, cuda=False):
 
 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
     speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
-    if 'tacotron' in CONFIG.model.lower():
+    if "tacotron" in CONFIG.model.lower():
         if not CONFIG.use_gst:
             style_mel = None
 
         if truncated:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
+                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+            )
         else:
             decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
-    elif 'glow' in CONFIG.model.lower():
+                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
+            )
+    elif "glow" in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
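Note: run_model_torch dispatches on a substring of CONFIG.model, which is why only the quoting of "tacotron" and "glow" changes in this hunk. A trivial sketch of that check with a made-up config object:

# Sketch of the substring dispatch in run_model_torch; DummyConfig is made up.
class DummyConfig:
    model = "Tacotron2"
    use_gst = False

CONFIG = DummyConfig()
print("tacotron" in CONFIG.model.lower())  # True, so the tacotron branch runs
print("glow" in CONFIG.model.lower())  # False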
@@ -84,9 +86,7 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
                 inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
-            postnet_output, _, _, _, alignments, _, _ = model.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -95,13 +95,9 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
-            postnet_output, alignments = model.module.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         else:
-            postnet_output, alignments = model.inference(
-                inputs, inputs_lengths, g=speaker_embedding_g
-            )
+            postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_embedding_g)
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
         decoder_output = None
@@ -162,7 +162,9 @@ def check_argument(
                 is_valid = True
         assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
     elif val_type:
-        assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+        assert (
+            isinstance(c[name], val_type) or c[name] is None
+        ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
 
 
 def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
@@ -176,7 +178,7 @@ def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
         Union[str, dict]: path to file (str) or gst style (dict)
     """
     if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
         return style_wav  # style_wav is a .wav file located on the server
 
     style_wav = json.loads(style_wav)
-    return style_wav # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
+    return style_wav  # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
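Note: per the comments in this hunk, style_wav_uri_to_dict accepts either a path to a .wav file that exists on the server or a JSON dictionary of GST token weights. The concrete weights below are illustrative only.

# Illustrative inputs for style_wav_uri_to_dict; the token weights are made up.
from TTS.utils.generic_utils import style_wav_uri_to_dict

# A string that is not an existing .wav path is parsed with json.loads,
# i.e. a GST dict of the form {token_id: token_weight, ...}.
gst_style = style_wav_uri_to_dict('{"0": 0.3, "1": -0.1, "4": 0.2}')

# A path to a .wav file that actually exists on the server would be
# returned unchanged, e.g. style_wav_uri_to_dict("/srv/styles/reference.wav").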
@@ -68,14 +68,11 @@ class Synthesizer(object):
     def _get_segmenter(lang: str):
         return pysbd.Segmenter(language=lang, clean=True)
 
-
     def _load_speakers(self, speaker_file: str) -> None:
         print("Loading speakers ...")
         self.tts_speakers = load_speaker_mapping(speaker_file)
         self.num_speakers = len(self.tts_speakers)
-        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][
-            "embedding"
-        ])
+        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"])
 
     def _load_speaker_embedding(self, speaker_json_key: str = ""):
 
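Note: a sketch of the speaker mapping that _load_speakers reads: a dict whose values carry an "embedding" vector, from which the embedding dimensionality is taken. Only the "embedding" key comes from the hunk; keys and values below are illustrative.

# Minimal sketch of the structure _load_speakers indexes; values are made up.
tts_speakers = {
    "speaker_a/clip_001.wav": {"embedding": [0.01, -0.12, 0.07]},
    "speaker_b/clip_004.wav": {"embedding": [0.05, 0.02, -0.09]},
}
num_speakers = len(tts_speakers)
speaker_embedding_dim = len(tts_speakers[list(tts_speakers.keys())[0]]["embedding"])
assert (num_speakers, speaker_embedding_dim) == (2, 3)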
@@ -86,14 +83,14 @@ class Synthesizer(object):
 
         if speaker_json_key != "":
             assert self.tts_speakers
-            assert speaker_json_key in self.tts_speakers, f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
+            assert (
+                speaker_json_key in self.tts_speakers
+            ), f" [!] speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
             speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
 
         return speaker_embedding
 
-    def _load_tts(
-        self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool
-    ) -> None:
+    def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
         # pylint: disable=global-statement
 
         global symbols, phonemes
@@ -111,20 +108,19 @@ class Synthesizer(object):
             self.input_size = len(symbols)
 
         if self.tts_config.use_speaker_embedding is True:
-            self._load_speakers(self.tts_config.get('external_speaker_embedding_file', self.tts_speakers_file))
+            self._load_speakers(self.tts_config.get("external_speaker_embedding_file", self.tts_speakers_file))
 
         self.tts_model = setup_model(
             self.input_size,
             num_speakers=self.num_speakers,
             c=self.tts_config,
-            speaker_embedding_dim=self.speaker_embedding_dim)
+            speaker_embedding_dim=self.speaker_embedding_dim,
+        )
 
         self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()
 
-
-
     def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         self.vocoder_config = load_config(model_config)
         self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
@@ -140,7 +136,7 @@ class Synthesizer(object):
         wav = np.array(wav)
         self.ap.save_wav(wav, path, self.output_sample_rate)
 
-    def tts(self, text: str, speaker_json_key: str = "", style_wav = None) -> List[int]:
+    def tts(self, text: str, speaker_json_key: str = "", style_wav=None) -> List[int]:
         start_time = time.time()
         wavs = []
         sens = self._split_into_sentences(text)
|
@ -178,13 +174,9 @@ class Synthesizer(object):
|
||||||
]
|
]
|
||||||
if scale_factor[1] != 1:
|
if scale_factor[1] != 1:
|
||||||
print(" > interpolating tts model output.")
|
print(" > interpolating tts model output.")
|
||||||
vocoder_input = interpolate_vocoder_input(
|
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
|
||||||
scale_factor, vocoder_input
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
vocoder_input = torch.tensor(vocoder_input).unsqueeze(
|
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
|
||||||
0
|
|
||||||
) # pylint: disable=not-callable
|
|
||||||
# run vocoder model
|
# run vocoder model
|
||||||
# [1, T, C]
|
# [1, T, C]
|
||||||
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
|
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
|
||||||
|
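Note: context for the interpolation branch above. When the vocoder was trained at a different sample rate than the TTS model, the spectrogram is rescaled along the time axis before vocoding. A simplified sketch of how scale_factor can differ from 1; the sample rates below are assumed, and in the repo they come from the TTS and vocoder audio configs.

# Simplified sketch of the scale-factor check; the rates below are illustrative.
tts_sample_rate = 22050
vocoder_sample_rate = 24000
scale_factor = [1, vocoder_sample_rate / tts_sample_rate]
if scale_factor[1] != 1:
    print(" > interpolating tts model output.")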