SpeedySpeech (SS) refactoring

This commit is contained in:
erogol 2021-01-05 14:31:11 +01:00
parent e82d31b6ac
commit 7586fbc4de
4 changed files with 25 additions and 13 deletions

View File

@ -1,5 +1,6 @@
from torch import nn from torch import nn
from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
from TTS.tts.layers.generic.wavenet import WNBlocks
from TTS.tts.layers.glow_tts.transformer import Transformer from TTS.tts.layers.glow_tts.transformer import Transformer
@ -19,7 +20,7 @@ class Decoder(nn.Module):
Default decoder_params... Default decoder_params...
for 'transformer' for 'transformer'
encoder_params={ decoder_params={
'hidden_channels_ffn': 128, 'hidden_channels_ffn': 128,
'num_heads': 2, 'num_heads': 2,
"kernel_size": 3, "kernel_size": 3,
@ -30,12 +31,22 @@ class Decoder(nn.Module):
}, },
for 'residual_conv_bn' for 'residual_conv_bn'
encoder_params = { decoder_params = {
"kernel_size": 4, "kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1], "dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2, "num_conv_blocks": 2,
"num_res_blocks": 17 "num_res_blocks": 17
} }
for 'wavenet'
decoder_params = {
"num_blocks": 12,
"hidden_channels":192,
"kernel_size": 5,
"dilation_rate": 1,
"num_layers": 4,
"dropout_p": 0.05
}
""" """
# pylint: disable=dangerous-default-value # pylint: disable=dangerous-default-value
def __init__( def __init__(
@ -60,6 +71,8 @@ class Decoder(nn.Module):
elif decoder_type == 'residual_conv_bn': elif decoder_type == 'residual_conv_bn':
self.decoder = ResidualConvBNBlock(self.hidden_channels, self.decoder = ResidualConvBNBlock(self.hidden_channels,
**decoder_params) **decoder_params)
elif decoder_type == 'wavenet':
self.decoder = WNBlocks(in_channels=self.in_channels, hidden_channels=self.hidden_channels, **decoder_params)
else: else:
raise ValueError(f'[!] Unknown decoder type - {decoder_type}') raise ValueError(f'[!] Unknown decoder type - {decoder_type}')

View File

@ -24,7 +24,7 @@ class SpeedySpeech(nn.Module):
"num_res_blocks": 13 "num_res_blocks": 13
}, },
decoder_type='residual_conv_bn', decoder_type='residual_conv_bn',
decoder_residual_conv_bn_params={ decoder_params={
"kernel_size": 4, "kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1], "dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2, "num_conv_blocks": 2,
@ -41,7 +41,7 @@ class SpeedySpeech(nn.Module):
if positional_encoding: if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels) self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels, self.decoder = Decoder(out_channels, hidden_channels,
decoder_type, decoder_residual_conv_bn_params) decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels) self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
if num_speakers > 1 and not external_c: if num_speakers > 1 and not external_c:

View File

@ -117,22 +117,20 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
num_speakers=num_speakers, num_speakers=num_speakers,
c_in_channels=0, c_in_channels=0,
num_splits=4, num_splits=4,
num_sqz=2, num_squeeze=2,
sigmoid_scale=False, sigmoid_scale=False,
mean_only=True, mean_only=True,
hidden_channels_enc=192, use_encoder_prenet=c["use_encoder_prenet"],
hidden_channels_dec=192,
use_encoder_prenet=True,
external_speaker_embedding_dim=speaker_embedding_dim) external_speaker_embedding_dim=speaker_embedding_dim)
elif c.model.lower() == "speedy_speech": elif c.model.lower() == "speedy_speech":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio['num_mels'], out_channels=c.audio['num_mels'],
hidden_channels=128, hidden_channels=c['hidden_channels'],
positional_encoding=c['positional_encoding'], positional_encoding=c['positional_encoding'],
encoder_type=c['encoder_type'], encoder_type=c['encoder_type'],
encoder_params=c['encoder_params'], encoder_params=c['encoder_params'],
decoder_type=c['decoder_type'], decoder_type=c['decoder_type'],
decoder_residual_conv_bn_params=c['decoder_residual_conv_bn_params'], decoder_params=c['decoder_params'],
c_in_channels=0) c_in_channels=0)
return model return model

View File

@ -57,7 +57,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
in the config file. in the config file.
""" """
# copy config.json # copy config.json
copy_config_path = OUT_PATH copy_config_path = os.path.join(out_path, 'config.json')
config_lines = open(config_file, "r").readlines() config_lines = open(config_file, "r").readlines()
# add extra information fields # add extra information fields
for key, value in new_fields.items(): for key, value in new_fields.items():
@ -70,5 +70,6 @@ def copy_model_files(c, config_file, out_path, new_fields):
config_out_file.writelines(config_lines) config_out_file.writelines(config_lines)
config_out_file.close() config_out_file.close()
# copy model stats file if available # copy model stats file if available
if c.audio['stats_path'] is not None:
copy_stats_path = os.path.join(out_path, 'scale_stats.npy') copy_stats_path = os.path.join(out_path, 'scale_stats.npy')
copyfile(c.audio['stats_path'], copy_stats_path) copyfile(c.audio['stats_path'], copy_stats_path)