handle multi-speaker and GST in Synthesizer class

kirianguiller 2021-03-01 15:17:15 +01:00 committed by Eren Gölge
parent a53958ae3a
commit 48ae52a9a3
5 changed files with 150 additions and 83 deletions

View File

@@ -10,6 +10,7 @@ from flask import Flask, render_template, request, send_file
 from TTS.utils.io import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
+from TTS.utils.generic_utils import style_wav_uri_to_dict

 def create_argparser():
@@ -81,13 +82,19 @@ synthesizer = Synthesizer(
     args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda
 )
+use_speaker_embedding = synthesizer.tts_config.get("use_external_speaker_embedding_file", False)
+use_gst = synthesizer.tts_config.get("use_gst", False)
 app = Flask(__name__)

 @app.route("/")
 def index():
-    return render_template("index.html", show_details=args.show_details)
+    return render_template(
+        'index.html',
+        show_details=args.show_details,
+        use_speaker_embedding=use_speaker_embedding,
+        use_gst=use_gst
+    )

 @app.route("/details")
 def details():
@@ -108,9 +115,13 @@ def details():
 @app.route("/api/tts", methods=["GET"])
 def tts():
-    text = request.args.get("text")
+    text = request.args.get('text')
+    speaker_json_key = request.args.get('speaker', "")
+    style_wav = request.args.get('style-wav', "")
+    style_wav = style_wav_uri_to_dict(style_wav)
     print(" > Model input: {}".format(text))
-    wavs = synthesizer.tts(text)
+    wavs = synthesizer.tts(text, speaker_json_key=speaker_json_key, style_wav=style_wav)
     out = io.BytesIO()
     synthesizer.save_wav(wavs, out)
     return send_file(out, mimetype="audio/wav")
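For context, a minimal client sketch against the updated /api/tts endpoint. The host/port and the example values are assumptions, not part of this commit; the speaker key must match an entry of the external speaker embedding file, and style-wav may be either a JSON dict of GST token weights or a path to a wav file on the server:

import requests

response = requests.get(
    "http://localhost:5002/api/tts",   # assumed demo-server address
    params={
        "text": "Hello world.",
        "speaker": "speaker_A_utt_001.wav",   # hypothetical key from the speaker embedding JSON
        "style-wav": '{"0": 0.2}',            # GST token weights, or a server-side path to a .wav
    },
)
with open("output.wav", "wb") as f:
    f.write(response.content)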

View File

@@ -60,6 +60,14 @@
                 <ul class="list-unstyled">
                 </ul>
+                {%if use_speaker_embedding%}
+                <input id="speaker-json-key" placeholder="speaker json key.." size=45 type="text" name="speaker-json-key">
+                {%endif%}
+                {%if use_gst%}
+                <input value='{"0": 0.1}' id="style-wav" placeholder="style wav (dict or path ot wav).." size=45 type="text" name="style-wav">
+                {%endif%}
                 <input id="text" placeholder="Type here..." size=45 type="text" name="text">
                 <button id="speak-button" name="speak">Speak</button><br/><br/>
                 {%if show_details%}
@@ -73,15 +81,24 @@
     <!-- Bootstrap core JavaScript -->
     <script>
+      function getTextValue(textId) {
+        const container = q(textId)
+        if (container) {
+          return container.value
+        }
+        return ""
+      }
      function q(selector) {return document.querySelector(selector)}
      q('#text').focus()
      function do_tts(e) {
-          text = q('#text').value
+          const text = q('#text').value
+          const speakerJsonKey = getTextValue('#speaker-json-key')
+          const styleWav = getTextValue('#style-wav')
          if (text) {
              q('#message').textContent = 'Synthesizing...'
              q('#speak-button').disabled = true
              q('#audio').hidden = true
-             synthesize(text)
+             synthesize(text, speakerJsonKey, styleWav)
          }
          e.preventDefault()
          return false
@@ -92,8 +109,8 @@
            do_tts(e)
          }
      })
-      function synthesize(text) {
-          fetch('/api/tts?text=' + encodeURIComponent(text), {cache: 'no-cache'})
+      function synthesize(text, speakerJsonKey="", styleWav="") {
+          fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker=${encodeURIComponent(speakerJsonKey)}&style-wav=${encodeURIComponent(styleWav)}`, {cache: 'no-cache'})
          .then(function(res) {
              if (!res.ok) throw Error(res.statusText)
              return res.blob()

View File

@@ -65,30 +65,27 @@ def compute_style_mel(style_wav, ap, cuda=False):

 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
-    if "tacotron" in CONFIG.model.lower():
-        if CONFIG.use_gst:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-            )
-        else:
-            if truncated:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-            else:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-    elif "glow" in CONFIG.model.lower():
+    speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
+    if 'tacotron' in CONFIG.model.lower():
+        if not CONFIG.use_gst:
+            style_mel = None
+        if truncated:
+            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
+        else:
+            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
+    elif 'glow' in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
             postnet_output, _, _, _, alignments, _, _ = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
             postnet_output, _, _, _, alignments, _, _ = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
@@ -99,11 +96,11 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         if hasattr(model, "module"):
             # distributed model
             postnet_output, alignments = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
             postnet_output, alignments = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
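The run_model_torch refactor above boils down to computing the speaker conditioning once instead of repeating the conditional at every call site; a minimal, self-contained illustration of that fallback (the variable values are dummies, not real model inputs):

speaker_id = None
speaker_embeddings = [0.01, -0.02, 0.07]  # dummy external speaker embedding
# glow-TTS and speedy-speech branches now receive this via g=speaker_embedding_g
speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings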

View File

@@ -1,10 +1,12 @@
 import datetime
 import glob
+import json
 import os
 import shutil
 import subprocess
 import sys
 from pathlib import Path
+from typing import Union

 def get_git_branch():
@@ -160,6 +162,21 @@ check_argument(
             is_valid = True
         assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
     elif val_type:
-        assert (
-            isinstance(c[name], val_type) or c[name] is None
-        ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
+        assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+
+
+def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
+    """Transform an uri style_wav, in either a string (path to wav file to be use for style transfer)
+    or a dict (gst tokens/values to be use for styling)
+
+    Args:
+        style_wav (str): uri
+
+    Returns:
+        Union[str, dict]: path to file (str) or gst style (dict)
+    """
+    if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
+        return style_wav  # style_wav is a .wav file located on the server
+
+    style_wav = json.loads(style_wav)
+    return style_wav  # style_wav is a gst dictionary with {token1_id : token1_weigth, ...}
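A hedged usage sketch for the new helper (the values are illustrative; a string that is neither an existing .wav path nor valid JSON would raise json.JSONDecodeError):

from TTS.utils.generic_utils import style_wav_uri_to_dict

style_wav_uri_to_dict('{"0": 0.2, "3": -0.1}')   # returns {'0': 0.2, '3': -0.1}, i.e. GST token weights
style_wav_uri_to_dict("/data/styles/happy.wav")  # returns the path unchanged if that .wav exists on the server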

View File

@@ -1,4 +1,5 @@
 import time
+from typing import List

 import numpy as np
 import pysbd
@@ -17,7 +18,14 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input, setup_gen

 class Synthesizer(object):
-    def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
+    def __init__(
+        self,
+        tts_checkpoint: str,
+        tts_config_path: str,
+        vocoder_checkpoint: str = "",
+        vocoder_config: str = "",
+        use_cuda: bool = False,
+    ) -> None:
         """General 🐸 TTS interface for inference. It takes a tts and a vocoder
         model and synthesize speech from the provided text.
@@ -27,27 +35,25 @@ class Synthesizer(object):
         If you have certain special characters in your text, you need to handle
         them before providing the text to Synthesizer.

-        TODO: handle multi-speaker and GST inference.
-
         Args:
             tts_checkpoint (str): path to the tts model file.
-            tts_config (str): path to the tts config file.
+            tts_config_path (str): path to the tts config file.
             vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
             vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
             use_cuda (bool, optional): enable/disable cuda. Defaults to False.
         """
         self.tts_checkpoint = tts_checkpoint
-        self.tts_config = tts_config
+        self.tts_config_path = tts_config_path
         self.vocoder_checkpoint = vocoder_checkpoint
         self.vocoder_config = vocoder_config
         self.use_cuda = use_cuda
-        self.wavernn = None
         self.vocoder_model = None
         self.num_speakers = 0
-        self.tts_speakers = None
-        self.speaker_embedding_dim = None
-        self.seg = self.get_segmenter("en")
+        self.tts_speakers = {}
+        self.speaker_embedding_dim = 0
+        self.seg = self._get_segmenter("en")
         self.use_cuda = use_cuda
         if self.use_cuda:
             assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
         self.load_tts(tts_checkpoint, tts_config, use_cuda)
@@ -57,38 +63,40 @@ class Synthesizer(object):
             self.output_sample_rate = self.vocoder_config.audio["sample_rate"]

     @staticmethod
-    def get_segmenter(lang):
+    def _get_segmenter(lang: str):
         return pysbd.Segmenter(language=lang, clean=True)

-    def load_speakers(self):
-        # load speakers
-        if self.model_config.use_speaker_embedding is not None:
-            self.tts_speakers = load_speaker_mapping(self.tts_config.tts_speakers_json)
-            self.num_speakers = len(self.tts_speakers)
-        else:
-            self.num_speakers = 0
-        # set external speaker embedding
-        if self.tts_config.use_external_speaker_embedding_file:
-            speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"]
-            self.speaker_embedding_dim = len(speaker_embedding)
-
-    def init_speaker(self, speaker_idx):
-        # load speakers
+    def _load_speakers(self) -> None:
+        print("Loading speakers ...")
+        self.tts_speakers = load_speaker_mapping(self.tts_config.external_speaker_embedding_file)
+        self.num_speakers = len(self.tts_speakers)
+        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][
+            "embedding"
+        ])
+
+    def _load_speaker_embedding(self, speaker_json_key: str = ""):
         speaker_embedding = None
-        if hasattr(self, "tts_speakers") and speaker_idx is not None:
-            assert speaker_idx < len(
-                self.tts_speakers
-            ), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
-            if self.tts_config.use_external_speaker_embedding_file:
-                speaker_embedding = self.tts_speakers[speaker_idx]["embedding"]
+        if self.tts_config.get("use_external_speaker_embedding_file") and not speaker_json_key:
+            raise ValueError("While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'")
+        if speaker_json_key != "":
+            assert self.tts_speakers
+            assert speaker_json_key in self.tts_speakers, f"speaker_json_key is not in self.tts_speakers keys : '{speaker_idx}'"
+            speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
         return speaker_embedding

-    def load_tts(self, tts_checkpoint, tts_config, use_cuda):
+    def _load_tts(
+        self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool
+    ) -> None:
         # pylint: disable=global-statement
         global symbols, phonemes
-        self.tts_config = load_config(tts_config)
+        self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
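For clarity, a hedged sketch of the mapping that _load_speakers() reads from tts_config.external_speaker_embedding_file and that _load_speaker_embedding() indexes by speaker_json_key; the key names and vector length are illustrative, only the {key: {"embedding": [...]}} shape is implied by this diff:

tts_speakers = {
    "speaker_A_utt_001.wav": {"embedding": [0.01, -0.02, 0.07]},
    "speaker_B_utt_014.wav": {"embedding": [0.00, 0.11, -0.05]},
}
# mirrors how _load_speakers() derives the embedding dimensionality
speaker_embedding_dim = len(next(iter(tts_speakers.values()))["embedding"])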
@@ -100,12 +108,22 @@ class Synthesizer(object):
         else:
             self.input_size = len(symbols)

-        self.tts_model = setup_model(self.input_size, num_speakers=self.num_speakers, c=self.tts_config)
-        self.tts_model.load_checkpoint(tts_config, tts_checkpoint, eval=True)
+        if self.tts_config.use_speaker_embedding is True:
+            self._load_speakers()
+
+        self.tts_model = setup_model(
+            self.input_size,
+            num_speakers=self.num_speakers,
+            c=self.tts_config,
+            speaker_embedding_dim=self.speaker_embedding_dim)
+        self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()

-    def load_vocoder(self, model_file, model_config, use_cuda):
+    def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         self.vocoder_config = load_config(model_config)
         self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
         self.vocoder_model = setup_generator(self.vocoder_config)
@@ -113,36 +131,36 @@ class Synthesizer(object):
         if use_cuda:
             self.vocoder_model.cuda()

-    def save_wav(self, wav, path):
+    def _split_into_sentences(self, text) -> List[str]:
+        return self.seg.segment(text)
+
+    def save_wav(self, wav: List[int], path: str) -> None:
         wav = np.array(wav)
         self.ap.save_wav(wav, path, self.output_sample_rate)

-    def split_into_sentences(self, text):
-        return self.seg.segment(text)
-
-    def tts(self, text, speaker_idx=None):
+    def tts(self, text: str, speaker_json_key: str = "", style_wav = None) -> List[int]:
         start_time = time.time()
         wavs = []
-        sens = self.split_into_sentences(text)
+        sens = self._split_into_sentences(text)
         print(" > Text splitted to sentences.")
         print(sens)

-        speaker_embedding = self.init_speaker(speaker_idx)
+        speaker_embedding = self._load_speaker_embedding(speaker_json_key)
         use_gl = self.vocoder_model is None

         for sen in sens:
             # synthesize voice
             waveform, _, _, mel_postnet_spec, _, _ = synthesis(
-                self.tts_model,
-                sen,
-                self.tts_config,
-                self.use_cuda,
-                self.ap,
-                speaker_idx,
-                None,
-                False,
-                self.tts_config.enable_eos_bos_chars,
-                use_gl,
+                model=self.tts_model,
+                text=sen,
+                CONFIG=self.tts_config,
+                use_cuda=self.use_cuda,
+                ap=self.ap,
+                speaker_id=None,
+                style_wav=style_wav,
+                truncated=False,
+                enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
+                use_griffin_lim=use_gl,
                 speaker_embedding=speaker_embedding,
             )
             if not use_gl:
@@ -152,12 +170,19 @@ class Synthesizer(object):
                 # renormalize spectrogram based on vocoder config
                 vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                 # compute scale factor for possible sample rate mismatch
-                scale_factor = [1, self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate]
+                scale_factor = [
+                    1,
+                    self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate,
+                ]
                 if scale_factor[1] != 1:
                     print(" > interpolating tts model output.")
-                    vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
+                    vocoder_input = interpolate_vocoder_input(
+                        scale_factor, vocoder_input
+                    )
                 else:
-                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
+                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(
+                        0
+                    )  # pylint: disable=not-callable
             # run vocoder model
             # [1, T, C]
             waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
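Putting the new interface together, a hedged end-to-end sketch of the Synthesizer API after this commit; the file paths and the speaker key are placeholders, and a speaker key / GST style is only meaningful for models configured for them:

from TTS.utils.synthesizer import Synthesizer

synthesizer = Synthesizer(
    tts_checkpoint="tts_model.pth.tar",        # placeholder paths
    tts_config_path="tts_config.json",
    vocoder_checkpoint="vocoder_model.pth.tar",
    vocoder_config="vocoder_config.json",
    use_cuda=False,
)
wav = synthesizer.tts(
    "Hello world.",
    speaker_json_key="speaker_A_utt_001.wav",  # only for multi-speaker models with an external embedding file
    style_wav={"0": 0.2},                      # GST token weights, or a path to a reference wav
)
synthesizer.save_wav(wav, "output.wav")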