SS refactoring

This commit is contained in:
erogol 2021-01-05 14:31:11 +01:00
parent e82d31b6ac
commit 7586fbc4de
4 changed files with 25 additions and 13 deletions

View File

@ -1,5 +1,6 @@
from torch import nn
from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
from TTS.tts.layers.generic.wavenet import WNBlocks
from TTS.tts.layers.glow_tts.transformer import Transformer
@ -19,7 +20,7 @@ class Decoder(nn.Module):
Default decoder_params...
for 'transformer'
encoder_params={
decoder_params={
'hidden_channels_ffn': 128,
'num_heads': 2,
"kernel_size": 3,
@ -30,12 +31,22 @@ class Decoder(nn.Module):
},
for 'residual_conv_bn'
encoder_params = {
decoder_params = {
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
}
for 'wavenet'
decoder_params = {
"num_blocks": 12,
"hidden_channels":192,
"kernel_size": 5,
"dilation_rate": 1,
"num_layers": 4,
"dropout_p": 0.05
}
"""
# pylint: disable=dangerous-default-value
def __init__(
@ -60,6 +71,8 @@ class Decoder(nn.Module):
elif decoder_type == 'residual_conv_bn':
self.decoder = ResidualConvBNBlock(self.hidden_channels,
**decoder_params)
elif decoder_type == 'wavenet':
self.decoder = WNBlocks(in_channels=self.in_channels, hidden_channels=self.hidden_channels, **decoder_params)
else:
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')

View File

@ -24,7 +24,7 @@ class SpeedySpeech(nn.Module):
"num_res_blocks": 13
},
decoder_type='residual_conv_bn',
decoder_residual_conv_bn_params={
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
@ -41,7 +41,7 @@ class SpeedySpeech(nn.Module):
if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels,
decoder_type, decoder_residual_conv_bn_params)
decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
if num_speakers > 1 and not external_c:

View File

@ -117,22 +117,20 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
num_speakers=num_speakers,
c_in_channels=0,
num_splits=4,
num_sqz=2,
num_squeeze=2,
sigmoid_scale=False,
mean_only=True,
hidden_channels_enc=192,
hidden_channels_dec=192,
use_encoder_prenet=True,
use_encoder_prenet=c["use_encoder_prenet"],
external_speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "speedy_speech":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio['num_mels'],
hidden_channels=128,
hidden_channels=c['hidden_channels'],
positional_encoding=c['positional_encoding'],
encoder_type=c['encoder_type'],
encoder_params=c['encoder_params'],
decoder_type=c['decoder_type'],
decoder_residual_conv_bn_params=c['decoder_residual_conv_bn_params'],
decoder_params=c['decoder_params'],
c_in_channels=0)
return model

View File

@ -57,7 +57,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
in the config file.
"""
# copy config.json
copy_config_path = OUT_PATH
copy_config_path = os.path.join(out_path, 'config.json')
config_lines = open(config_file, "r").readlines()
# add extra information fields
for key, value in new_fields.items():
@ -70,5 +70,6 @@ def copy_model_files(c, config_file, out_path, new_fields):
config_out_file.writelines(config_lines)
config_out_file.close()
# copy model stats file if available
copy_stats_path = os.path.join(out_path, 'scale_stats.npy')
copyfile(c.audio['stats_path'], copy_stats_path)
if c.audio['stats_path'] is not None:
copy_stats_path = os.path.join(out_path, 'scale_stats.npy')
copyfile(c.audio['stats_path'], copy_stats_path)