mirror of https://github.com/coqui-ai/TTS.git
SS refactoring
This commit is contained in:
parent
e82d31b6ac
commit
7586fbc4de
|
@ -1,5 +1,6 @@
|
|||
from torch import nn
|
||||
from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
|
||||
from TTS.tts.layers.generic.wavenet import WNBlocks
|
||||
from TTS.tts.layers.glow_tts.transformer import Transformer
|
||||
|
||||
|
||||
|
@ -19,7 +20,7 @@ class Decoder(nn.Module):
|
|||
Default decoder_params...
|
||||
|
||||
for 'transformer'
|
||||
encoder_params={
|
||||
decoder_params={
|
||||
'hidden_channels_ffn': 128,
|
||||
'num_heads': 2,
|
||||
"kernel_size": 3,
|
||||
|
@ -30,12 +31,22 @@ class Decoder(nn.Module):
|
|||
},
|
||||
|
||||
for 'residual_conv_bn'
|
||||
encoder_params = {
|
||||
decoder_params = {
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
"num_res_blocks": 17
|
||||
}
|
||||
|
||||
for 'wavenet'
|
||||
decoder_params = {
|
||||
"num_blocks": 12,
|
||||
"hidden_channels":192,
|
||||
"kernel_size": 5,
|
||||
"dilation_rate": 1,
|
||||
"num_layers": 4,
|
||||
"dropout_p": 0.05
|
||||
}
|
||||
"""
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
|
@ -60,6 +71,8 @@ class Decoder(nn.Module):
|
|||
elif decoder_type == 'residual_conv_bn':
|
||||
self.decoder = ResidualConvBNBlock(self.hidden_channels,
|
||||
**decoder_params)
|
||||
elif decoder_type == 'wavenet':
|
||||
self.decoder = WNBlocks(in_channels=self.in_channels, hidden_channels=self.hidden_channels, **decoder_params)
|
||||
else:
|
||||
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ class SpeedySpeech(nn.Module):
|
|||
"num_res_blocks": 13
|
||||
},
|
||||
decoder_type='residual_conv_bn',
|
||||
decoder_residual_conv_bn_params={
|
||||
decoder_params={
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
|
@ -41,7 +41,7 @@ class SpeedySpeech(nn.Module):
|
|||
if positional_encoding:
|
||||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels,
|
||||
decoder_type, decoder_residual_conv_bn_params)
|
||||
decoder_type, decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
|
||||
|
||||
if num_speakers > 1 and not external_c:
|
||||
|
|
|
@ -117,22 +117,20 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
|||
num_speakers=num_speakers,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_sqz=2,
|
||||
num_squeeze=2,
|
||||
sigmoid_scale=False,
|
||||
mean_only=True,
|
||||
hidden_channels_enc=192,
|
||||
hidden_channels_dec=192,
|
||||
use_encoder_prenet=True,
|
||||
use_encoder_prenet=c["use_encoder_prenet"],
|
||||
external_speaker_embedding_dim=speaker_embedding_dim)
|
||||
elif c.model.lower() == "speedy_speech":
|
||||
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
out_channels=c.audio['num_mels'],
|
||||
hidden_channels=128,
|
||||
hidden_channels=c['hidden_channels'],
|
||||
positional_encoding=c['positional_encoding'],
|
||||
encoder_type=c['encoder_type'],
|
||||
encoder_params=c['encoder_params'],
|
||||
decoder_type=c['decoder_type'],
|
||||
decoder_residual_conv_bn_params=c['decoder_residual_conv_bn_params'],
|
||||
decoder_params=c['decoder_params'],
|
||||
c_in_channels=0)
|
||||
return model
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
|
|||
in the config file.
|
||||
"""
|
||||
# copy config.json
|
||||
copy_config_path = OUT_PATH
|
||||
copy_config_path = os.path.join(out_path, 'config.json')
|
||||
config_lines = open(config_file, "r").readlines()
|
||||
# add extra information fields
|
||||
for key, value in new_fields.items():
|
||||
|
@ -70,5 +70,6 @@ def copy_model_files(c, config_file, out_path, new_fields):
|
|||
config_out_file.writelines(config_lines)
|
||||
config_out_file.close()
|
||||
# copy model stats file if available
|
||||
copy_stats_path = os.path.join(out_path, 'scale_stats.npy')
|
||||
copyfile(c.audio['stats_path'], copy_stats_path)
|
||||
if c.audio['stats_path'] is not None:
|
||||
copy_stats_path = os.path.join(out_path, 'scale_stats.npy')
|
||||
copyfile(c.audio['stats_path'], copy_stats_path)
|
||||
|
|
Loading…
Reference in New Issue