mirror of https://github.com/coqui-ai/TTS.git
Do not check the sample rate when loading the stats file for normalization, so that interpolation can be used with a vocoder trained at a different sample rate
This commit is contained in:
parent
3660c57f1e
commit
c008003506
|
@ -129,3 +129,4 @@ TODO.txt
|
|||
.vscode/*
|
||||
data/*
|
||||
notebooks/data/*
|
||||
TTS/tts/layers/glow_tts/monotonic_align/core.c
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -69,19 +69,18 @@ class GlowTts(nn.Module):
|
|||
self.length_scale=1.
|
||||
|
||||
self.encoder = Encoder(num_chars,
|
||||
out_channels=out_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
filter_channels=filter_channels,
|
||||
filter_channels_dp=filter_channels_dp,
|
||||
encoder_type=encoder_type,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers_enc,
|
||||
kernel_size=kernel_size,
|
||||
dropout_p=dropout_p,
|
||||
mean_only=mean_only,
|
||||
use_prenet=use_encoder_prenet,
|
||||
c_in_channels=c_in_channels)
|
||||
|
||||
out_channels=out_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
filter_channels=filter_channels,
|
||||
filter_channels_dp=filter_channels_dp,
|
||||
encoder_type=encoder_type,
|
||||
num_heads=num_heads,
|
||||
num_layers=num_layers_enc,
|
||||
kernel_size=kernel_size,
|
||||
dropout_p=dropout_p,
|
||||
mean_only=mean_only,
|
||||
use_prenet=use_encoder_prenet,
|
||||
c_in_channels=c_in_channels)
|
||||
|
||||
self.decoder = Decoder(out_channels,
|
||||
hidden_channels_dec or hidden_channels,
|
||||
|
|
|
@ -174,8 +174,9 @@ class AudioProcessor(object):
|
|||
for key in stats_config.keys():
|
||||
if key in skip_parameters:
|
||||
continue
|
||||
assert stats_config[key] == self.__dict__[key],\
|
||||
f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
|
||||
if key != 'sample_rate':
|
||||
assert stats_config[key] == self.__dict__[key],\
|
||||
f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
|
||||
return mel_mean, mel_std, linear_mean, linear_std, stats_config
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
|
@ -322,6 +323,7 @@ class AudioProcessor(object):
|
|||
def load_wav(self, filename, sr=None):
|
||||
if sr is None:
|
||||
x, sr = sf.read(filename)
|
||||
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
|
||||
else:
|
||||
x, sr = librosa.load(filename, sr=sr)
|
||||
if self.do_trim_silence:
|
||||
|
@ -329,7 +331,6 @@ class AudioProcessor(object):
|
|||
x = self.trim_silence(x)
|
||||
except ValueError:
|
||||
print(f' [!] File cannot be trimmed for silence - {filename}')
|
||||
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
|
||||
if self.do_sound_norm:
|
||||
x = self.sound_norm(x)
|
||||
return x
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
import torch
|
||||
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
|
||||
|
||||
class FullbandMelganGenerator(MelganGenerator):
    """MelGAN generator variant with its own default hyper-parameters.

    The constructor forwards every argument unchanged to
    ``MelganGenerator``; the only behavior defined here is ``inference``.
    """

    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 proj_kernel=7,
                 base_channels=512,
                 upsample_factors=(2, 8, 2, 2),
                 res_kernel=3,
                 num_res_blocks=4):
        """Build the generator by delegating to ``MelganGenerator.__init__``.

        All parameters are passed through by keyword with the same names;
        their meaning is defined by the parent class.
        """
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         proj_kernel=proj_kernel,
                         base_channels=base_channels,
                         upsample_factors=upsample_factors,
                         res_kernel=res_kernel,
                         num_res_blocks=num_res_blocks)

    @torch.no_grad()
    def inference(self, cond_features):
        """Run the generator layers on padded conditioning features.

        Args:
            cond_features: conditioning tensor fed to ``self.layers``
                (presumably shaped (batch, channels, frames) -- TODO confirm
                against the caller).

        Returns:
            The output of ``self.layers`` applied to the padded input.
        """
        # Move the input to the device that holds the weights of layers[1].
        target_device = self.layers[1].weight.device
        features = cond_features.to(target_device)
        # Replicate-pad the last dimension on both sides by inference_padding.
        pad_widths = (self.inference_padding, self.inference_padding)
        padded = torch.nn.functional.pad(features, pad_widths, 'replicate')
        return self.layers(padded)
|
|
@ -67,6 +67,15 @@ def setup_generator(c):
|
|||
upsample_factors=c.generator_model_params['upsample_factors'],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
||||
if c.generator_model in 'fullband_melgan_generator':
|
||||
model = MyModel(
|
||||
in_channels=c.audio['num_mels'],
|
||||
out_channels=1,
|
||||
proj_kernel=7,
|
||||
base_channels=512,
|
||||
upsample_factors=c.generator_model_params['upsample_factors'],
|
||||
res_kernel=3,
|
||||
num_res_blocks=c.generator_model_params['num_res_blocks'])
|
||||
if c.generator_model in 'parallel_wavegan_generator':
|
||||
model = MyModel(
|
||||
in_channels=1,
|
||||
|
|
Loading…
Reference in New Issue