mirror of https://github.com/coqui-ai/TTS.git
rebase fixes
This commit is contained in:
parent 07c961382f
commit 6a46339a43
@@ -10,21 +10,21 @@ import traceback
import torch
from torch.utils.data import DataLoader

from mozilla_voice_tts.generic_utils import count_parameters
from mozilla_voice_tts.speaker_encoder.dataset import MyDataset
from mozilla_voice_tts.speaker_encoder.generic_utils import save_best_model
from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss
from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss, AngleProtoLoss
from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder
from mozilla_voice_tts.speaker_encoder.visual import plot_embeddings
from mozilla_voice_tts.tts.datasets.preprocess import load_meta_data
from mozilla_voice_tts.tts.utils.audio import AudioProcessor
from mozilla_voice_tts.tts.utils.generic_utils import (
    create_experiment_folder, get_git_branch, remove_experiment_folder,
    set_init_dict)
from mozilla_voice_tts.tts.utils.io import copy_config_file, load_config
from mozilla_voice_tts.tts.utils.radam import RAdam
from mozilla_voice_tts.tts.utils.tensorboard_logger import TensorboardLogger
from mozilla_voice_tts.tts.utils.training import NoamLR, check_update
from mozilla_voice_tts.utils.audio import AudioProcessor
from mozilla_voice_tts.utils.generic_utils import count_parameters
from mozilla_voice_tts.utils.radam import RAdam
from mozilla_voice_tts.utils.tensorboard_logger import TensorboardLogger
from mozilla_voice_tts.utils.training import NoamLR, check_update

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
@@ -146,7 +146,7 @@ def main(args):  # pylint: disable=redefined-outer-name
    elif c.loss == "angleproto":
        criterion = AngleProtoLoss()
    else:
        raise Exception("The %s not is a loss supported" %c.loss)
        raise Exception("The %s not is a loss supported" % c.loss)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path)
@@ -192,6 +192,7 @@ def main(args):  # pylint: disable=redefined-outer-name
    _, global_step = train(model, criterion, optimizer, scheduler, ap,
                           global_step)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
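For orientation (not part of the diff): both speaker-encoder criteria imported above expect embeddings grouped per speaker as (N, M, D). A minimal sketch, assuming the encoder output is simply reshaped that way before the loss call; the shape values and the reshape step are illustrative:

import torch
from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss

# illustrative batch layout: N speakers, M utterances each, D-dim d-vectors
N, M, D = 4, 5, 256
outputs = torch.randn(N * M, D)              # flat encoder outputs
criterion = GE2ELoss(loss_method="softmax")  # the GE2E branch of the c.loss selection above
loss = criterion(outputs.view(N, M, D))      # scalar loss
loss.backward()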
@@ -1,163 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


# adapted from https://github.com/cvqluu/GE2E-Loss
class GE2ELoss(nn.Module):
    def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
        """
        Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1]
        Accepts an input of size (N, M, D)
            where N is the number of speakers in the batch,
            M is the number of utterances per speaker,
            and D is the dimensionality of the embedding vector (e.g. d-vector)
        Args:
            - init_w (float): defines the initial value of w in Equation (5) of [1]
            - init_b (float): definies the initial value of b in Equation (5) of [1]
        """
        super(GE2ELoss, self).__init__()
        # pylint: disable=E1102
        self.w = nn.Parameter(torch.tensor(init_w))
        # pylint: disable=E1102
        self.b = nn.Parameter(torch.tensor(init_b))
        self.loss_method = loss_method

        print('Initialised Generalized End-to-End loss')

        assert self.loss_method in ["softmax", "contrast"]

        if self.loss_method == "softmax":
            self.embed_loss = self.embed_loss_softmax
        if self.loss_method == "contrast":
            self.embed_loss = self.embed_loss_contrast

    # pylint: disable=R0201
    def calc_new_centroids(self, dvecs, centroids, spkr, utt):
        """
        Calculates the new centroids excluding the reference utterance
        """
        excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :]))
        excl = torch.mean(excl, 0)
        new_centroids = []
        for i, centroid in enumerate(centroids):
            if i == spkr:
                new_centroids.append(excl)
            else:
                new_centroids.append(centroid)
        return torch.stack(new_centroids)

    def calc_cosine_sim(self, dvecs, centroids):
        """
        Make the cosine similarity matrix with dims (N,M,N)
        """
        cos_sim_matrix = []
        for spkr_idx, speaker in enumerate(dvecs):
            cs_row = []
            for utt_idx, utterance in enumerate(speaker):
                new_centroids = self.calc_new_centroids(
                    dvecs, centroids, spkr_idx, utt_idx
                )
                # vector based cosine similarity for speed
                cs_row.append(
                    torch.clamp(
                        torch.mm(
                            utterance.unsqueeze(1).transpose(0, 1),
                            new_centroids.transpose(0, 1),
                        )
                        / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)),
                        1e-6,
                    )
                )
            cs_row = torch.cat(cs_row, dim=0)
            cos_sim_matrix.append(cs_row)
        return torch.stack(cos_sim_matrix)

    # pylint: disable=R0201
    def embed_loss_softmax(self, dvecs, cos_sim_matrix):
        """
        Calculates the loss on each embedding $L(e_{ji})$ by taking softmax
        """
        N, M, _ = dvecs.shape
        L = []
        for j in range(N):
            L_row = []
            for i in range(M):
                L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j])
            L_row = torch.stack(L_row)
            L.append(L_row)
        return torch.stack(L)

    # pylint: disable=R0201
    def embed_loss_contrast(self, dvecs, cos_sim_matrix):
        """
        Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid
        """
        N, M, _ = dvecs.shape
        L = []
        for j in range(N):
            L_row = []
            for i in range(M):
                centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
                excl_centroids_sigmoids = torch.cat(
                    (centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])
                )
                L_row.append(
                    1.0
                    - torch.sigmoid(cos_sim_matrix[j, i, j])
                    + torch.max(excl_centroids_sigmoids)
                )
            L_row = torch.stack(L_row)
            L.append(L_row)
        return torch.stack(L)

    def forward(self, dvecs):
        """
        Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
        """
        centroids = torch.mean(dvecs, 1)
        cos_sim_matrix = self.calc_cosine_sim(dvecs, centroids)
        torch.clamp(self.w, 1e-6)
        cos_sim_matrix = self.w * cos_sim_matrix + self.b
        L = self.embed_loss(dvecs, cos_sim_matrix)
        return L.mean()


# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
class AngleProtoLoss(nn.Module):
    """
    Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982
    Accepts an input of size (N, M, D)
        where N is the number of speakers in the batch,
        M is the number of utterances per speaker,
        and D is the dimensionality of the embedding vector
    Args:
        - init_w (float): defines the initial value of w
        - init_b (float): definies the initial value of b
    """
    def __init__(self, init_w=10.0, init_b=-5.0):
        super(AngleProtoLoss, self).__init__()
        # pylint: disable=E1102
        self.w = nn.Parameter(torch.tensor(init_w))
        # pylint: disable=E1102
        self.b = nn.Parameter(torch.tensor(init_b))
        self.criterion = torch.nn.CrossEntropyLoss()
        self.use_cuda = torch.cuda.is_available()

        print('Initialised Angular Prototypical loss')

    def forward(self, x):
        """
        Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats)
        """
        out_anchor = torch.mean(x[:,1:,:],1)
        out_positive = x[:,0,:]
        num_speakers = out_anchor.size()[0]

        cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1,-1,num_speakers),out_anchor.unsqueeze(-1).expand(-1,-1,num_speakers).transpose(0,2))
        torch.clamp(self.w, 1e-6)
        cos_sim_matrix = cos_sim_matrix * self.w + self.b
        label = torch.from_numpy(np.asarray(range(0,num_speakers)))
        if self.use_cuda:
            label = label.cuda()
        L = self.criterion(cos_sim_matrix, label)
        return L
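A usage sketch for the AngleProto criterion above (not part of the diff). Two caveats are visible in the file as shown: its forward uses np.asarray although numpy is not among the listed imports, and the label/device handling is what the following hunk adjusts; the sketch assumes a module where both details are resolved, and the shape values are illustrative:

import torch
from mozilla_voice_tts.speaker_encoder.losses import AngleProtoLoss

device = "cuda" if torch.cuda.is_available() else "cpu"

# N=4 speakers, M=5 utterances each, D=256-dim d-vectors; utterance 0 acts as the positive
dvecs = torch.randn(4, 5, 256, requires_grad=True, device=device)
criterion = AngleProtoLoss().to(device)
loss = criterion(dvecs)   # cross-entropy over the (N, N) cosine-similarity matrix
loss.backward()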
@@ -157,4 +157,4 @@ class AngleProtoLoss(nn.Module):
        cos_sim_matrix = cos_sim_matrix * self.w + self.b
        label = torch.from_numpy(np.asarray(range(0, num_speakers))).to(cos_sim_matrix.device)
        L = self.criterion(cos_sim_matrix, label)
        return L
        return L
@@ -93,9 +93,10 @@ def mozilla_de(root_path, meta_file):

def mailabs(root_path, meta_files=None):
    """Normalizes M-AI-Labs meta data files to TTS format"""
    speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
    speaker_regex = re.compile(
        "by_book/(male|female)/(?P<speaker_name>[^/]+)/")
    if meta_files is None:
        csv_files = glob(root_path+"/**/metadata.csv", recursive=True)
        csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
    else:
        csv_files = meta_files
        # meta_files = [f.strip() for f in meta_files.split(",")]
@@ -115,12 +116,15 @@ def mailabs(root_path, meta_files=None):
            if meta_files is None:
                wav_file = os.path.join(folder, 'wavs', cols[0] + '.wav')
            else:
                wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), 'wavs', cols[0] + '.wav')
                wav_file = os.path.join(root_path,
                                        folder.replace("metadata.csv", ""),
                                        'wavs', cols[0] + '.wav')
            if os.path.isfile(wav_file):
                text = cols[1].strip()
                items.append([text, wav_file, speaker_name])
            else:
                raise RuntimeError("> File %s does not exist!"%(wav_file))
                raise RuntimeError("> File %s does not exist!" %
                                   (wav_file))
    return items
@@ -185,7 +189,8 @@ def libri_tts(root_path, meta_files=None):
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    for item in items:
        assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
        assert os.path.exists(
            item[1]), f" [!] wav files don't exist - {item[1]}"
    return items
@@ -197,7 +202,8 @@ def custom_turkish(root_path, meta_file):
    with open(txt_file, 'r', encoding='utf-8') as ttf:
        for line in ttf:
            cols = line.split('|')
            wav_file = os.path.join(root_path, 'wavs', cols[0].strip() + '.wav')
            wav_file = os.path.join(root_path, 'wavs',
                                    cols[0].strip() + '.wav')
            if not os.path.exists(wav_file):
                skipped_files.append(wav_file)
                continue
@@ -206,6 +212,7 @@ def custom_turkish(root_path, meta_file):
    print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
    return items


# ToDo: add the dataset link when the dataset is released publicly
def brspeech(root_path, meta_file):
    '''BRSpeech 3.0 beta'''
@@ -223,20 +230,25 @@ def brspeech(root_path, meta_file):
            items.append([text, wav_file, speaker_name])
    return items


def vctk(root_path, meta_files=None, wavs_path='wav48'):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    test_speakers = meta_files
    items = []
    meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt",
                      recursive=True)
    for meta_file in meta_files:
        _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
        _, speaker_id, txt_file = os.path.relpath(meta_file,
                                                  root_path).split(os.sep)
        file_id = txt_file.split('.')[0]
        if isinstance(test_speakers, list):  # if is list ignore this speakers ids
        if isinstance(test_speakers,
                      list):  # if is list ignore this speakers ids
            if speaker_id in test_speakers:
                continue
        with open(meta_file) as file_text:
            text = file_text.readlines()[0]
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id+'.wav')
        wav_file = os.path.join(root_path, wavs_path, speaker_id,
                                file_id + '.wav')
        items.append([text, wav_file, speaker_id])

    return items
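Each formatter in this module returns a list of [text, wav_file, speaker_name] entries. A minimal sketch of calling one directly (not part of the diff; the corpus path is a placeholder):

from mozilla_voice_tts.tts.datasets.preprocess import vctk

# hypothetical path to an extracted VCTK corpus
items = vctk("/data/VCTK-Corpus", meta_files=None, wavs_path='wav48')
for text, wav_file, speaker_id in items[:3]:
    print(speaker_id, wav_file, text.strip())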
@@ -6,6 +6,7 @@ from mozilla_voice_tts.tts.layers.gst_layers import GST
from mozilla_voice_tts.tts.layers.tacotron import Decoder, Encoder, PostCBHG
from mozilla_voice_tts.tts.models.tacotron_abstract import TacotronAbstract


class Tacotron(TacotronAbstract):
    def __init__(self,
                 num_chars,
@@ -42,8 +43,8 @@ class Tacotron(TacotronAbstract):
                         forward_attn, trans_agent, forward_attn_mask,
                         location_attn, attn_K, separate_stopnet,
                         bidirectional_decoder, double_decoder_consistency,
                         ddc_r, encoder_in_features, decoder_in_features,
                         speaker_embedding_dim, gst, gst_embedding_dim,
                         ddc_r, encoder_in_features, decoder_in_features,
                         speaker_embedding_dim, gst, gst_embedding_dim,
                         gst_num_heads, gst_style_tokens)

        # speaker embedding layers
@@ -1,15 +1,9 @@
import torch
from torch import nn

<<<<<<< HEAD:mozilla_voice_tts/tts/models/tacotron2.py
from mozilla_voice_tts.tts.layers.gst_layers import GST
from mozilla_voice_tts.tts.layers.tacotron2 import Decoder, Encoder, Postnet
from mozilla_voice_tts.tts.models.tacotron_abstract import TacotronAbstract
=======
from TTS.tts.layers.gst_layers import GST
from TTS.tts.layers.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.tacotron_abstract import TacotronAbstract
>>>>>>> bugfix in DDC now DDC work on Tacotron1:TTS/tts/models/tacotron2.py

# TODO: match function arguments with tacotron
class Tacotron2(TacotronAbstract):
@@ -47,8 +41,8 @@ class Tacotron2(TacotronAbstract):
                         forward_attn, trans_agent, forward_attn_mask,
                         location_attn, attn_K, separate_stopnet,
                         bidirectional_decoder, double_decoder_consistency,
                         ddc_r, encoder_in_features, decoder_in_features,
                         speaker_embedding_dim, gst, gst_embedding_dim,
                         ddc_r, encoder_in_features, decoder_in_features,
                         speaker_embedding_dim, gst, gst_embedding_dim,
                         gst_num_heads, gst_style_tokens)

        # speaker embedding layer
@@ -61,7 +55,7 @@ class Tacotron2(TacotronAbstract):
        # speaker and gst embeddings is concat in decoder input
        if self.num_speakers > 1:
            self.decoder_in_features += speaker_embedding_dim  # add speaker embedding dim

        # embedding layer
        self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
@@ -28,8 +28,8 @@ class TacotronAbstract(ABC, nn.Module):
                 bidirectional_decoder=False,
                 double_decoder_consistency=False,
                 ddc_r=None,
                 encoder_in_features=512,
                 decoder_in_features=512,
                 encoder_in_features=512,
                 decoder_in_features=512,
                 speaker_embedding_dim=None,
                 gst=False,
                 gst_embedding_dim=512,
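For reference (not part of the diff), these constructor keywords surface in the concrete models; the GST unit test later in this commit builds a Tacotron2 roughly as in this sketch, where the num_chars, r and num_speakers values are illustrative:

from mozilla_voice_tts.tts.models.tacotron2 import Tacotron2

# gst=True enables the global style token layer; other kwargs keep their defaults
model = Tacotron2(num_chars=24,
                  r=2,
                  gst=True,
                  num_speakers=5)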
synthesize.py (182 lines)
@@ -1,182 +0,0 @@
# pylint: disable=redefined-outer-name, unused-argument
import os
import time
import argparse
import torch
import json
import string

from TTS.utils.synthesis import synthesis
from TTS.utils.generic_utils import setup_model
from TTS.utils.io import load_config
from TTS.utils.text.symbols import make_symbols, symbols, phonemes
from TTS.utils.audio import AudioProcessor


def tts(model,
        vocoder_model,
        C,
        VC,
        text,
        ap,
        ap_vocoder,
        use_cuda,
        batched_vocoder,
        speaker_id=None,
        figures=False):
    t_1 = time.time()
    use_vocoder_model = vocoder_model is not None
    waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
        model, text, C, use_cuda, ap, speaker_id, style_wav=C.gst['gst_style_input'],
        truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
        use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)

    if C.model == "Tacotron" and use_vocoder_model:
        postnet_output = ap.out_linear_to_mel(postnet_output.T).T
    # correct if there is a scale difference b/w two models
    if use_vocoder_model:
        postnet_output = ap._denormalize(postnet_output)
        postnet_output = ap_vocoder._normalize(postnet_output)
        vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
        waveform = vocoder_model.generate(
            vocoder_input.cuda() if use_cuda else vocoder_input,
            batched=batched_vocoder,
            target=8000,
            overlap=400)
    print(" > Run-time: {}".format(time.time() - t_1))
    return alignment, postnet_output, stop_tokens, waveform


if __name__ == "__main__":

    global symbols, phonemes

    parser = argparse.ArgumentParser()
    parser.add_argument('text', type=str, help='Text to generate speech.')
    parser.add_argument('config_path',
                        type=str,
                        help='Path to model config file.')
    parser.add_argument(
        'model_path',
        type=str,
        help='Path to model file.',
    )
    parser.add_argument(
        'out_path',
        type=str,
        help='Path to save final wav file. Wav file will be names as the text given.',
    )
    parser.add_argument('--use_cuda',
                        type=bool,
                        help='Run model on CUDA.',
                        default=False)
    parser.add_argument(
        '--vocoder_path',
        type=str,
        help=
        'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).',
        default="",
    )
    parser.add_argument('--vocoder_config_path',
                        type=str,
                        help='Path to vocoder model config file.',
                        default="")
    parser.add_argument(
        '--batched_vocoder',
        type=bool,
        help="If True, vocoder model uses faster batch processing.",
        default=True)
    parser.add_argument('--speakers_json',
                        type=str,
                        help="JSON file for multi-speaker model.",
                        default="")
    parser.add_argument(
        '--speaker_id',
        type=int,
        help="target speaker_id if the model is multi-speaker.",
        default=None)
    args = parser.parse_args()

    if args.vocoder_path != "":
        assert args.use_cuda, " [!] Enable cuda for vocoder."
        from WaveRNN.models.wavernn import Model as VocoderModel

    # load the config
    C = load_config(args.config_path)
    C.forward_attn_mask = True

    # load the audio processor
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if 'characters' in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load speakers
    if args.speakers_json != '':
        speakers = json.load(open(args.speakers_json, 'r'))
        num_speakers = len(speakers)
    else:
        num_speakers = 0

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    model = setup_model(num_chars, num_speakers, C)
    cp = torch.load(args.model_path)
    model.load_state_dict(cp['model'])
    model.eval()
    if args.use_cuda:
        model.cuda()
    model.decoder.set_r(cp['r'])

    # load vocoder model
    if args.vocoder_path != "":
        VC = load_config(args.vocoder_config_path)
        ap_vocoder = AudioProcessor(**VC.audio)
        bits = 10
        vocoder_model = VocoderModel(rnn_dims=512,
                                     fc_dims=512,
                                     mode=VC.mode,
                                     mulaw=VC.mulaw,
                                     pad=VC.pad,
                                     upsample_factors=VC.upsample_factors,
                                     feat_dims=VC.audio["num_mels"],
                                     compute_dims=128,
                                     res_out_dims=128,
                                     res_blocks=10,
                                     hop_length=ap.hop_length,
                                     sample_rate=ap.sample_rate,
                                     use_aux_net=True,
                                     use_upsample_net=True)

        check = torch.load(args.vocoder_path)
        vocoder_model.load_state_dict(check['model'])
        vocoder_model.eval()
        if args.use_cuda:
            vocoder_model.cuda()
    else:
        vocoder_model = None
        VC = None
        ap_vocoder = None

    # synthesize voice
    print(" > Text: {}".format(args.text))
    _, _, _, wav = tts(model,
                       vocoder_model,
                       C,
                       VC,
                       args.text,
                       ap,
                       ap_vocoder,
                       args.use_cuda,
                       args.batched_vocoder,
                       speaker_id=args.speaker_id,
                       figures=False)

    # save the results
    file_name = args.text.replace(" ", "_")
    file_name = file_name.translate(
        str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'
    out_path = os.path.join(args.out_path, file_name)
    print(" > Saving output to {}".format(out_path))
    ap.save_wav(wav, out_path)
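The deleted script was a standalone CLI; judging from its argparse definition above, an illustrative invocation (paths are placeholders, not part of the diff) would have been:

python synthesize.py "Hello world." /path/to/config.json /path/to/checkpoint.pth.tar /path/to/out_dir --use_cuda True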
@@ -76,61 +76,6 @@ class TacotronTrainTest(unittest.TestCase):
            count += 1


class TacotronGSTTrainTest(unittest.TestCase):
    def test_train_step(self):
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8, )).long().to(device)
        input_lengths = torch.sort(input_lengths, descending=True)[0]
        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
        mel_lengths[0] = 30
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)

        for idx in mel_lengths:
            stop_targets[:, int(idx.item()):, 0] = 1.0

        stop_targets = stop_targets.view(input_dummy.shape[0],
                                         stop_targets.size(1) // c.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()

        criterion = MSELossMasked(seq_len_norm=False).to(device)
        criterion_st = nn.BCEWithLogitsLoss().to(device)
        model = Tacotron2(num_chars=24,
                          gst=True,
                          r=c.r,
                          num_speakers=5).to(device)
        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(),
                                    model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        optimizer = optim.Adam(model.parameters(), lr=c.lr)
        for i in range(5):
            mel_out, mel_postnet_out, align, stop_tokens = model.forward(
                input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
            assert torch.sigmoid(stop_tokens).data.max() <= 1.0
            assert torch.sigmoid(stop_tokens).data.min() >= 0.0
            optimizer.zero_grad()
            loss = criterion(mel_out, mel_spec, mel_lengths)
            stop_loss = criterion_st(stop_tokens, stop_targets)
            loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss
            loss.backward()
            optimizer.step()
            # check parameter changes
            count = 0
            for param, param_ref in zip(model.parameters(),
                                        model_ref.parameters()):
                # ignore pre-higway layer since it works conditional
                # if count not in [145, 59]:
                assert (param != param_ref).any(
                ), "param {} with shape {} not updated!! \n{}\n{}".format(
                    count, param.shape, param, param_ref)
                count += 1

class MultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
@@ -185,8 +130,8 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
            count += 1

class TacotronGSTTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
    #pylint: disable=no-self-use
    def test_train_step(self):
        # with random gst mel style
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 128, (8, )).long().to(device)