mirror of https://github.com/coqui-ai/TTS.git
handle multi-speaker and GST in Synthesizer class
This commit is contained in: parent a53958ae3a, commit 48ae52a9a3
@@ -10,6 +10,7 @@ from flask import Flask, render_template, request, send_file
 from TTS.utils.io import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
+from TTS.utils.generic_utils import style_wav_uri_to_dict


 def create_argparser():
@@ -81,13 +82,19 @@ synthesizer = Synthesizer(
     args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda
 )

+use_speaker_embedding = synthesizer.tts_config.get("use_external_speaker_embedding_file", False)
+use_gst = synthesizer.tts_config.get("use_gst", False)
 app = Flask(__name__)


 @app.route("/")
 def index():
-    return render_template("index.html", show_details=args.show_details)
+    return render_template(
+        'index.html',
+        show_details=args.show_details,
+        use_speaker_embedding=use_speaker_embedding,
+        use_gst=use_gst
+    )


 @app.route("/details")
 def details():
@@ -108,9 +115,13 @@ def details():

 @app.route("/api/tts", methods=["GET"])
 def tts():
-    text = request.args.get("text")
+    text = request.args.get('text')
+    speaker_json_key = request.args.get('speaker', "")
+    style_wav = request.args.get('style-wav', "")
+
+    style_wav = style_wav_uri_to_dict(style_wav)
     print(" > Model input: {}".format(text))
-    wavs = synthesizer.tts(text)
+    wavs = synthesizer.tts(text, speaker_json_key=speaker_json_key, style_wav=style_wav)
     out = io.BytesIO()
     synthesizer.save_wav(wavs, out)
     return send_file(out, mimetype="audio/wav")
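For reference, the extended endpoint can be exercised roughly like this (a sketch only: the host/port, speaker key and GST weights below are placeholders, not part of this commit):

import json
import urllib.parse
import urllib.request

# Placeholder values: adjust host/port, the speaker key and the GST weights to your setup.
params = urllib.parse.urlencode({
    "text": "Hello world.",
    "speaker": "p225",                    # key from the external speaker embeddings JSON file
    "style-wav": json.dumps({"0": 0.1}),  # or a path to a .wav file located on the server
})
with urllib.request.urlopen(f"http://localhost:5002/api/tts?{params}") as resp:
    with open("out.wav", "wb") as f:
        f.write(resp.read())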
@@ -60,6 +60,14 @@

         <ul class="list-unstyled">
         </ul>
+        {%if use_speaker_embedding%}
+        <input id="speaker-json-key" placeholder="speaker json key.." size=45 type="text" name="speaker-json-key">
+        {%endif%}
+
+        {%if use_gst%}
+        <input value='{"0": 0.1}' id="style-wav" placeholder="style wav (dict or path to wav).." size=45 type="text" name="style-wav">
+        {%endif%}
+
         <input id="text" placeholder="Type here..." size=45 type="text" name="text">
         <button id="speak-button" name="speak">Speak</button><br/><br/>
         {%if show_details%}
@@ -73,15 +81,24 @@

     <!-- Bootstrap core JavaScript -->
     <script>
+      function getTextValue(textId) {
+        const container = q(textId)
+        if (container) {
+          return container.value
+        }
+        return ""
+      }
       function q(selector) {return document.querySelector(selector)}
       q('#text').focus()
       function do_tts(e) {
-          text = q('#text').value
+          const text = q('#text').value
+          const speakerJsonKey = getTextValue('#speaker-json-key')
+          const styleWav = getTextValue('#style-wav')
           if (text) {
              q('#message').textContent = 'Synthesizing...'
              q('#speak-button').disabled = true
              q('#audio').hidden = true
-             synthesize(text)
+             synthesize(text, speakerJsonKey, styleWav)
          }
          e.preventDefault()
          return false
@@ -92,8 +109,8 @@
            do_tts(e)
          }
       })
-      function synthesize(text) {
-        fetch('/api/tts?text=' + encodeURIComponent(text), {cache: 'no-cache'})
+      function synthesize(text, speakerJsonKey="", styleWav="") {
+        fetch(`/api/tts?text=${encodeURIComponent(text)}&speaker=${encodeURIComponent(speakerJsonKey)}&style-wav=${encodeURIComponent(styleWav)}`, {cache: 'no-cache'})
           .then(function(res) {
             if (!res.ok) throw Error(res.statusText)
             return res.blob()
@@ -65,30 +65,27 @@ def compute_style_mel(style_wav, ap, cuda=False):


 def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
-    if "tacotron" in CONFIG.model.lower():
-        if CONFIG.use_gst:
-            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-            )
-        else:
-            if truncated:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-            else:
-                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
-                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
-                )
-    elif "glow" in CONFIG.model.lower():
+    speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
+    if 'tacotron' in CONFIG.model.lower():
+        if not CONFIG.use_gst:
+            style_mel = None
+        if truncated:
+            decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
+                inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
+        else:
+            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
+                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
+    elif 'glow' in CONFIG.model.lower():
         inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
         if hasattr(model, "module"):
             # distributed model
             postnet_output, _, _, _, alignments, _, _ = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
             postnet_output, _, _, _, alignments, _, _ = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
@@ -99,11 +96,11 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel
         if hasattr(model, "module"):
             # distributed model
             postnet_output, alignments = model.module.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         else:
             postnet_output, alignments = model.inference(
-                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
+                inputs, inputs_lengths, g=speaker_embedding_g
             )
         postnet_output = postnet_output.permute(0, 2, 1)
         # these only belong to tacotron models.
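The net effect of the run_model_torch changes is that speaker conditioning and the GST style input are now resolved once, up front. A minimal standalone sketch of that selection logic (the function name and the bare config object are stand-ins, not code from the repo):

def resolve_conditioning(config, speaker_id=None, speaker_embeddings=None, style_mel=None):
    # Glow-style models take a single conditioning tensor `g`:
    # prefer an explicit speaker id, otherwise fall back to an external embedding.
    speaker_embedding_g = speaker_id if speaker_id is not None else speaker_embeddings
    # Tacotron models only receive a style mel when GST is enabled.
    if not getattr(config, "use_gst", False):
        style_mel = None
    return speaker_embedding_g, style_mel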
@@ -1,10 +1,12 @@
 import datetime
 import glob
+import json
 import os
 import shutil
 import subprocess
 import sys
 from pathlib import Path
+from typing import Union


 def get_git_branch():
@@ -160,6 +162,21 @@ def check_argument(
             is_valid = True
         assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
     elif val_type:
-        assert (
-            isinstance(c[name], val_type) or c[name] is None
-        ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
+        assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+
+
+def style_wav_uri_to_dict(style_wav: str) -> Union[str, dict]:
+    """Transform a style_wav uri into either a string (path to a wav file to be used for style transfer)
+    or a dict (gst tokens/values to be used for styling)
+
+    Args:
+        style_wav (str): uri
+
+    Returns:
+        Union[str, dict]: path to file (str) or gst style (dict)
+    """
+    if os.path.isfile(style_wav) and style_wav.endswith(".wav"):
+        return style_wav  # style_wav is a .wav file located on the server
+
+    style_wav = json.loads(style_wav)
+    return style_wav  # style_wav is a gst dictionary with {token1_id : token1_weight, ...}
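Usage sketch for the new helper (the path below is a placeholder; the dict form must be valid JSON):

from TTS.utils.generic_utils import style_wav_uri_to_dict

# A JSON dict of GST token weights is parsed into a Python dict:
style_wav_uri_to_dict('{"0": 0.3, "1": -0.1}')        # -> {"0": 0.3, "1": -0.1}
# A path is returned unchanged, provided the .wav file exists on the server:
style_wav_uri_to_dict("/data/styles/reference.wav")   # -> "/data/styles/reference.wav"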
@@ -1,4 +1,5 @@
 import time
+from typing import List

 import numpy as np
 import pysbd
@@ -17,7 +18,14 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input, setup_gen


 class Synthesizer(object):
-    def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
+    def __init__(
+        self,
+        tts_checkpoint: str,
+        tts_config_path: str,
+        vocoder_checkpoint: str = "",
+        vocoder_config: str = "",
+        use_cuda: bool = False,
+    ) -> None:
         """General 🐸 TTS interface for inference. It takes a tts and a vocoder
         model and synthesize speech from the provided text.

@@ -27,27 +35,25 @@ class Synthesizer(object):
         If you have certain special characters in your text, you need to handle
         them before providing the text to Synthesizer.

-        TODO: handle multi-speaker and GST inference.
-
         Args:
             tts_checkpoint (str): path to the tts model file.
-            tts_config (str): path to the tts config file.
+            tts_config_path (str): path to the tts config file.
             vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
             vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
             use_cuda (bool, optional): enable/disable cuda. Defaults to False.
         """
         self.tts_checkpoint = tts_checkpoint
-        self.tts_config = tts_config
+        self.tts_config_path = tts_config_path
         self.vocoder_checkpoint = vocoder_checkpoint
         self.vocoder_config = vocoder_config
         self.use_cuda = use_cuda
-        self.wavernn = None
         self.vocoder_model = None
         self.num_speakers = 0
-        self.tts_speakers = None
-        self.speaker_embedding_dim = None
-        self.seg = self.get_segmenter("en")
+        self.tts_speakers = {}
+        self.speaker_embedding_dim = 0
+        self.seg = self._get_segmenter("en")
         self.use_cuda = use_cuda

         if self.use_cuda:
             assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
         self.load_tts(tts_checkpoint, tts_config, use_cuda)
@@ -57,38 +63,40 @@ class Synthesizer(object):
             self.output_sample_rate = self.vocoder_config.audio["sample_rate"]

     @staticmethod
-    def get_segmenter(lang):
+    def _get_segmenter(lang: str):
         return pysbd.Segmenter(language=lang, clean=True)

-    def load_speakers(self):
-        # load speakers
-        if self.model_config.use_speaker_embedding is not None:
-            self.tts_speakers = load_speaker_mapping(self.tts_config.tts_speakers_json)
-            self.num_speakers = len(self.tts_speakers)
-        else:
-            self.num_speakers = 0
-        # set external speaker embedding
-        if self.tts_config.use_external_speaker_embedding_file:
-            speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"]
-            self.speaker_embedding_dim = len(speaker_embedding)
-
-    def init_speaker(self, speaker_idx):
-        # load speakers
+    def _load_speakers(self) -> None:
+        print("Loading speakers ...")
+        self.tts_speakers = load_speaker_mapping(self.tts_config.external_speaker_embedding_file)
+        self.num_speakers = len(self.tts_speakers)
+        self.speaker_embedding_dim = len(self.tts_speakers[list(self.tts_speakers.keys())[0]][
+            "embedding"
+        ])
+
+    def _load_speaker_embedding(self, speaker_json_key: str = ""):
+
         speaker_embedding = None
-        if hasattr(self, "tts_speakers") and speaker_idx is not None:
-            assert speaker_idx < len(
-                self.tts_speakers
-            ), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
-            if self.tts_config.use_external_speaker_embedding_file:
-                speaker_embedding = self.tts_speakers[speaker_idx]["embedding"]
+
+        if self.tts_config.get("use_external_speaker_embedding_file") and not speaker_json_key:
+            raise ValueError("While 'use_external_speaker_embedding_file', you must pass a 'speaker_json_key'")
+
+        if speaker_json_key != "":
+            assert self.tts_speakers
+            assert speaker_json_key in self.tts_speakers, f"speaker_json_key is not in self.tts_speakers keys : '{speaker_json_key}'"
+            speaker_embedding = self.tts_speakers[speaker_json_key]["embedding"]
+
         return speaker_embedding

-    def load_tts(self, tts_checkpoint, tts_config, use_cuda):
+    def _load_tts(
+        self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool
+    ) -> None:
         # pylint: disable=global-statement

         global symbols, phonemes

-        self.tts_config = load_config(tts_config)
+        self.tts_config = load_config(tts_config_path)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)

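The two methods above assume the external speaker embedding file maps each speaker key to a dict holding an "embedding" vector, roughly as in this illustrative sketch (keys and values are made up; real vectors are much longer):

# Shape of the external speakers file, expressed as a Python dict (illustrative only):
speakers = {
    "p225": {"embedding": [0.018, -0.042, 0.057]},
    "p226": {"embedding": [0.011, 0.003, -0.029]},
}
speaker_embedding = speakers["p225"]["embedding"]  # what _load_speaker_embedding("p225") returns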
@@ -100,12 +108,22 @@ class Synthesizer(object):
         else:
             self.input_size = len(symbols)

-        self.tts_model = setup_model(self.input_size, num_speakers=self.num_speakers, c=self.tts_config)
-        self.tts_model.load_checkpoint(tts_config, tts_checkpoint, eval=True)
+        if self.tts_config.use_speaker_embedding is True:
+            self._load_speakers()
+
+        self.tts_model = setup_model(
+            self.input_size,
+            num_speakers=self.num_speakers,
+            c=self.tts_config,
+            speaker_embedding_dim=self.speaker_embedding_dim)
+
+        self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
         if use_cuda:
             self.tts_model.cuda()

-    def load_vocoder(self, model_file, model_config, use_cuda):
+    def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> None:
         self.vocoder_config = load_config(model_config)
         self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
         self.vocoder_model = setup_generator(self.vocoder_config)
@@ -113,36 +131,36 @@
         if use_cuda:
             self.vocoder_model.cuda()

-    def save_wav(self, wav, path):
+    def _split_into_sentences(self, text) -> List[str]:
+        return self.seg.segment(text)
+
+    def save_wav(self, wav: List[int], path: str) -> None:
         wav = np.array(wav)
         self.ap.save_wav(wav, path, self.output_sample_rate)

-    def split_into_sentences(self, text):
-        return self.seg.segment(text)
-
-    def tts(self, text, speaker_idx=None):
+    def tts(self, text: str, speaker_json_key: str = "", style_wav=None) -> List[int]:
         start_time = time.time()
         wavs = []
-        sens = self.split_into_sentences(text)
+        sens = self._split_into_sentences(text)
         print(" > Text splitted to sentences.")
         print(sens)
+        speaker_embedding = self._load_speaker_embedding(speaker_json_key)

-        speaker_embedding = self.init_speaker(speaker_idx)
         use_gl = self.vocoder_model is None

         for sen in sens:
             # synthesize voice
             waveform, _, _, mel_postnet_spec, _, _ = synthesis(
-                self.tts_model,
-                sen,
-                self.tts_config,
-                self.use_cuda,
-                self.ap,
-                speaker_idx,
-                None,
-                False,
-                self.tts_config.enable_eos_bos_chars,
-                use_gl,
+                model=self.tts_model,
+                text=sen,
+                CONFIG=self.tts_config,
+                use_cuda=self.use_cuda,
+                ap=self.ap,
+                speaker_id=None,
+                style_wav=style_wav,
+                truncated=False,
+                enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
+                use_griffin_lim=use_gl,
                 speaker_embedding=speaker_embedding,
             )
             if not use_gl:
@@ -152,12 +170,19 @@ class Synthesizer(object):
                 # renormalize spectrogram based on vocoder config
                 vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                 # compute scale factor for possible sample rate mismatch
-                scale_factor = [1, self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate]
+                scale_factor = [
+                    1,
+                    self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate,
+                ]
                 if scale_factor[1] != 1:
                     print(" > interpolating tts model output.")
-                    vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
+                    vocoder_input = interpolate_vocoder_input(
+                        scale_factor, vocoder_input
+                    )
                 else:
-                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
+                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(
+                        0
+                    )  # pylint: disable=not-callable
                 # run vocoder model
                 # [1, T, C]
                 waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
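Taken together, the reworked interface can be driven from Python roughly as follows (a minimal sketch: all paths, the speaker key and the GST weights are placeholders, and the vocoder arguments can be omitted to fall back to Griffin-Lim):

from TTS.utils.synthesizer import Synthesizer

# Placeholder paths and keys; use your own checkpoints, configs and speakers file.
synthesizer = Synthesizer(
    tts_checkpoint="tts_model.pth.tar",
    tts_config_path="tts_config.json",
    vocoder_checkpoint="vocoder_model.pth.tar",
    vocoder_config="vocoder_config.json",
    use_cuda=False,
)
wav = synthesizer.tts(
    "Hello world.",
    speaker_json_key="p225",   # key into the external speaker embeddings file
    style_wav={"0": 0.1},      # GST token weights, or a path to a reference wav on disk
)
synthesizer.save_wav(wav, "out.wav")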