mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #453 from Edresson/dev
Script for spectrogram extraction using teacher forcing and Glow-TTS inference with MAS.
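For context, the new script is invoked from the command line; a typical call, modeled on the one used in the tests added by this PR (the config, checkpoint and output paths below are placeholders), looks like:

    python TTS/bin/extract_tts_spectrograms.py \
        --config_path /path/to/config.json \
        --checkpoint_path /path/to/checkpoint.pth.tar \
        --output_path /path/to/output_dir

The optional --save_audio, --debug and --quantized flags additionally save the source wavs, mel-reconstructed debug audio, and quantized wavs, as defined by the script's argument parser.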
commit f7582107da

TTS/bin/extract_tts_spectrograms.py
@@ -0,0 +1,280 @@
#!/usr/bin/env python3
"""Extract Mel spectrograms with teacher forcing."""

import os
import argparse
import numpy as np
from tqdm import tqdm
import torch

from torch.utils.data import DataLoader

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.io import load_config
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters

use_cuda = torch.cuda.is_available()


def setup_loader(ap, r, verbose=False):
    dataset = MyDataset(
        r,
        c.text_cleaner,
        compute_linear_spec=False,
        meta_data=meta_data,
        ap=ap,
        tp=c.characters if "characters" in c.keys() else None,
        add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
        batch_group_size=0,
        min_seq_len=c.min_seq_len,
        max_seq_len=c.max_seq_len,
        phoneme_cache_path=c.phoneme_cache_path,
        use_phonemes=c.use_phonemes,
        phoneme_language=c.phoneme_language,
        enable_eos_bos=c.enable_eos_bos_chars,
        use_noise_augment=False,
        verbose=verbose,
        speaker_mapping=speaker_mapping
        if c.use_speaker_embedding and c.use_external_speaker_embedding_file
        else None,
    )

    if c.use_phonemes and c.compute_input_seq_cache:
        # precompute phonemes to have a better estimate of sequence lengths.
        dataset.compute_input_seq(c.num_loader_workers)
    dataset.sort_items()

    loader = DataLoader(
        dataset,
        batch_size=c.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        sampler=None,
        num_workers=c.num_loader_workers,
        pin_memory=False,
    )
    return loader


def set_filename(wav_path, out_path):
    wav_file = os.path.basename(wav_path)
    file_name = wav_file.split('.')[0]
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav_gl"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "wav"), exist_ok=True)
    wavq_path = os.path.join(out_path, "quant", file_name)
    mel_path = os.path.join(out_path, "mel", file_name)
    wav_gl_path = os.path.join(out_path, "wav_gl", file_name + '.wav')
    wav_path = os.path.join(out_path, "wav", file_name + '.wav')
    return file_name, wavq_path, mel_path, wav_gl_path, wav_path


def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4]
    mel_lengths = data[5]
    item_idx = data[7]
    attn_mask = data[9]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if c.use_speaker_embedding:
        if c.use_external_speaker_embedding_file:
            speaker_embeddings = data[8]
            speaker_ids = None
        else:
            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_ids = torch.LongTensor(speaker_ids)
            speaker_embeddings = None
    else:
        speaker_embeddings = None
        speaker_ids = None

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
        if speaker_embeddings is not None:
            speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)

        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)
    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_ids,
        speaker_embeddings,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )


@torch.no_grad()
def inference(model_name, model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask=None, speaker_ids=None, speaker_embeddings=None):
    if model_name == "glow_tts":
        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
        speaker_c = None
        if speaker_ids is not None:
            speaker_c = speaker_ids
        elif speaker_embeddings is not None:
            speaker_c = speaker_embeddings

        model_output, *_ = model.inference_with_MAS(
            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
        )
        model_output = model_output.transpose(1, 2).detach().cpu().numpy()

    elif "tacotron" in model_name:
        _, postnet_outputs, *_ = model(
            text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
        # normalize tacotron output
        if model_name == "tacotron":
            mel_specs = []
            postnet_outputs = postnet_outputs.data.cpu().numpy()
            for b in range(postnet_outputs.shape[0]):
                postnet_output = postnet_outputs[b]
                mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T))
            model_output = torch.stack(mel_specs).cpu().numpy()

        elif model_name == "tacotron2":
            model_output = postnet_outputs.detach().cpu().numpy()
    return model_output


def extract_spectrograms(data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metadata_name="metadata.txt"):
    model.eval()
    export_metadata = []
    for _, data in tqdm(enumerate(data_loader), total=len(data_loader)):

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids,
            speaker_embeddings,
            _,
            _,
            attn_mask,
            item_idx,
        ) = format_data(data)

        model_output = inference(c.model.lower(), model, ap, text_input, text_lengths, mel_input, mel_lengths, attn_mask, speaker_ids, speaker_embeddings)

        for idx in range(text_input.shape[0]):
            wav_file_path = item_idx[idx]
            wav = ap.load_wav(wav_file_path)
            _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path)

            # quantize and save wav
            if quantized_wav:
                wavq = ap.quantize(wav)
                np.save(wavq_path, wavq)

            # save TTS mel
            mel = model_output[idx]
            mel_length = mel_lengths[idx]
            mel = mel[:mel_length, :].T
            np.save(mel_path, mel)

            export_metadata.append([wav_file_path, mel_path])
            if save_audio:
                ap.save_wav(wav, wav_path)

            if debug:
                print("Audio for debug saved at:", wav_gl_path)
                wav = ap.inv_melspectrogram(mel)
                ap.save_wav(wav, wav_gl_path)

    with open(os.path.join(output_path, metadata_name), "w") as f:
        for data in export_metadata:
            f.write(f"{data[0]}|{data[1] + '.npy'}\n")


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data, symbols, phonemes, model_characters, speaker_mapping

    # Audio processor
    ap = AudioProcessor(**c.audio)
    if "characters" in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # set model characters
    model_characters = phonemes if c.use_phonemes else symbols
    num_chars = len(model_characters)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, None)

    # setup model
    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)

    # restore model
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])

    if use_cuda:
        model.cuda()

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)
    # set r
    r = 1 if c.model.lower() == "glow_tts" else model.decoder.r
    own_loader = setup_loader(ap, r, verbose=True)

    extract_spectrograms(own_loader, model, ap, args.output_path, quantized_wav=args.quantized, save_audio=args.save_audio, debug=args.debug, metadata_name="metadata.txt")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_path',
        type=str,
        help='Path to config file for training.',
        required=True)
    parser.add_argument(
        '--checkpoint_path',
        type=str,
        help='Model file to be restored.',
        required=True)
    parser.add_argument(
        '--output_path',
        type=str,
        help='Path to save mel specs.',
        required=True)
    parser.add_argument('--debug',
                        default=False,
                        action='store_true',
                        help='Save audio files for debug')
    parser.add_argument('--save_audio',
                        default=False,
                        action='store_true',
                        help='Save audio files')
    parser.add_argument('--quantized',
                        action='store_true',
                        help='Save quantized audio files')
    args = parser.parse_args()

    c = load_config(args.config_path)

    main(args)
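
For orientation, a sketch of the output layout the script above produces under --output_path, based on the directories created in set_filename() and the metadata file written at the end of extract_spectrograms():

    output_path/
        mel/           teacher-forced mel spectrograms, one .npy per utterance
        quant/         quantized wav arrays (.npy), written only with --quantized
        wav/           source wavs re-saved by the AudioProcessor, only with --save_audio
        wav_gl/        audio reconstructed from the extracted mels via ap.inv_melspectrogram, only with --debug
        metadata.txt   one line per utterance: <source_wav_path>|<mel_path>.npy
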
@@ -117,7 +117,7 @@ def format_data(data):
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
-        linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron"] else None
+        linear_input = linear_input.cuda(non_blocking=True) if c.model.lower() in ["tacotron"] else None
        stop_targets = stop_targets.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)

@@ -5,6 +5,7 @@
import os
import sys
import time
+import itertools
import traceback
from inspect import signature

@@ -495,7 +496,11 @@ def main(args):  # pylint: disable=redefined-outer-name
    optimizer_gen = getattr(torch.optim, c.optimizer)
    optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
    optimizer_disc = getattr(torch.optim, c.optimizer)
-    optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)
+    if c.discriminator_model == 'hifigan_discriminator':
+        optimizer_disc = optimizer_disc(itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()), lr=c.lr_disc, **c.optimizer_params)
+    else:
+        optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)

    # schedulers
    scheduler_gen = None

@@ -179,6 +179,81 @@ class GlowTTS(nn.Module):
        attn = attn.squeeze(1).permute(0, 2, 1)
        return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

    @torch.no_grad()
    def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
        """
        It's similar to the teacher forcing in Tacotron.
        It was proposed in: https://arxiv.org/abs/2104.05557
        Shapes:
            x: [B, T]
            x_lengths: B
            y: [B, C, T]
            y_lengths: B
            g: [B, C] or B
        """
        y_max_length = y.size(2)
        # norm speaker embeddings
        if g is not None:
            if self.external_speaker_embedding_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

        # embedding pass
        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
        # drop residual frames wrt num_squeeze and set y_lengths.
        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
        # create masks
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
        # find the alignment path between z and encoder output
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
        logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
        logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
        attn = attn.squeeze(1).permute(0, 2, 1)

        # get predicted aligned distribution
        z = y_mean * y_mask

        # reverse the decoder and predict using the aligned distribution
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)

        return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

    @torch.no_grad()
    def decoder_inference(self, y, y_lengths=None, g=None):
        """
        Shapes:
            y: [B, C, T]
            y_lengths: B
            g: [B, C] or B
        """
        y_max_length = y.size(2)
        # norm speaker embeddings
        if g is not None:
            if self.external_speaker_embedding_dim:
                g = F.normalize(g).unsqueeze(-1)
            else:
                g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)

        # decoder pass
        z, logdet = self.decoder(y, y_mask, g=g, reverse=False)

        # reverse decoder and predict
        y, logdet = self.decoder(z, y_mask, g=g, reverse=True)

        return y, logdet

    @torch.no_grad()
    def inference(self, x, x_lengths, g=None):
        if g is not None:

@@ -28,9 +28,10 @@ def load_speaker_mapping(out_path):

def save_speaker_mapping(out_path, speaker_mapping):
    """Saves speaker mapping if not yet present."""
-    speakers_json_path = make_speakers_json_path(out_path)
-    with open(speakers_json_path, "w") as f:
-        json.dump(speaker_mapping, f, indent=4)
+    if out_path is not None:
+        speakers_json_path = make_speakers_json_path(out_path)
+        with open(speakers_json_path, "w") as f:
+            json.dump(speaker_mapping, f, indent=4)


def get_speakers(items):

@@ -0,0 +1,66 @@
import os
import unittest

import torch

from tests import get_tests_input_path
from tests import get_tests_output_path, run_cli

from TTS.tts.utils.generic_utils import setup_model
from TTS.utils.io import load_config
from TTS.tts.utils.text.symbols import phonemes, symbols

torch.manual_seed(1)


# pylint: disable=protected-access
class TestExtractTTSSpectrograms(unittest.TestCase):
    @staticmethod
    def test_GlowTTS():
        # set paths
        config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
        # load config
        c = load_config(config_path)
        # create model
        num_chars = len(phonemes if c.use_phonemes else symbols)
        model = setup_model(num_chars, 1, c, speaker_embedding_dim=None)
        # save model
        torch.save({"model": model.state_dict()}, checkpoint_path)
        # run test
        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
        run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

    @staticmethod
    def test_Tacotron2():
        # set paths
        config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
        # load config
        c = load_config(config_path)
        # create model
        num_chars = len(phonemes if c.use_phonemes else symbols)
        model = setup_model(num_chars, 1, c, speaker_embedding_dim=None)
        # save model
        torch.save({"model": model.state_dict()}, checkpoint_path)
        # run test
        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
        run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

    @staticmethod
    def test_Tacotron():
        # set paths
        config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
        checkpoint_path = os.path.join(get_tests_output_path(), 'checkpoint_test.pth.tar')
        output_path = os.path.join(get_tests_output_path(), 'output_extract_tts_spectrograms/')
        # load config
        c = load_config(config_path)
        # create model
        num_chars = len(phonemes if c.use_phonemes else symbols)
        model = setup_model(num_chars, 1, c, speaker_embedding_dim=None)
        # save model
        torch.save({"model": model.state_dict()}, checkpoint_path)
        # run test
        run_cli(f'CUDA_VISIBLE_DEVICES="" python TTS/bin/extract_tts_spectrograms.py --config_path "{config_path}" --checkpoint_path "{checkpoint_path}" --output_path "{output_path}"')
        run_cli(f'rm -rf "{output_path}" "{checkpoint_path}"')

@@ -129,3 +129,58 @@ class GlowTTSTrainTest(unittest.TestCase):
                count, param.shape, param, param_ref
            )
            count += 1


class GlowTTSInferenceTest(unittest.TestCase):
    @staticmethod
    def test_inference():
        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
        input_lengths = torch.randint(100, 129, (8,)).long().to(device)
        input_lengths[-1] = 128
        mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

        # create model
        model = GlowTTS(
            num_chars=32,
            hidden_channels_enc=48,
            hidden_channels_dec=48,
            hidden_channels_dp=32,
            out_channels=80,
            encoder_type="rel_pos_transformer",
            encoder_params={
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 6,
                "num_heads": 2,
                "hidden_channels_ffn": 16,  # 4 times the hidden_channels
                "input_length": None,
            },
            use_encoder_prenet=True,
            num_flow_blocks_dec=12,
            kernel_size_dec=5,
            dilation_rate=1,
            num_block_layers=4,
            dropout_p_dec=0.0,
            num_speakers=0,
            c_in_channels=0,
            num_splits=4,
            num_squeeze=1,
            sigmoid_scale=False,
            mean_only=False,
        ).to(device)

        model.eval()
        print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))

        # inference encoder and decoder with MAS
        y, *_ = model.inference_with_MAS(
            input_dummy, input_lengths, mel_spec, mel_lengths, None
        )

        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)

        assert (y_dec.shape == y.shape), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
            y.shape, y_dec.shape
        )

@@ -37,6 +37,7 @@ class TacotronTrainTest(unittest.TestCase):
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[-1] = mel_spec.size(1)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_ids = torch.randint(0, 5, (8,)).long().to(device)

@@ -96,6 +97,7 @@ class MultiSpeakeTacotronTrainTest(unittest.TestCase):
        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
        linear_spec = torch.rand(8, 30, c.audio["fft_size"]).to(device)
        mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
+        mel_lengths[-1] = mel_spec.size(1)
        stop_targets = torch.zeros(8, 30, 1).float().to(device)
        speaker_embeddings = torch.rand(8, 55).to(device)