mirror of https://github.com/coqui-ai/TTS.git
refactor(vocoder): remove duplicate function
This commit is contained in:
parent
6ecf47312c
commit
0f69d31f70
|
@ -12,6 +12,13 @@ from TTS.vocoder.layers.upsample import ConvUpsample
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||||
|
assert layers % stacks == 0
|
||||||
|
layers_per_cycle = layers // stacks
|
||||||
|
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||||
|
return (kernel_size - 1) * sum(dilations) + 1
|
||||||
|
|
||||||
|
|
||||||
class ParallelWaveganGenerator(torch.nn.Module):
|
class ParallelWaveganGenerator(torch.nn.Module):
|
||||||
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
|
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
|
||||||
It is similar to WaveNet with no causal convolution.
|
It is similar to WaveNet with no causal convolution.
|
||||||
|
@ -144,16 +151,9 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
||||||
|
|
||||||
self.apply(_apply_weight_norm)
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
|
||||||
assert layers % stacks == 0
|
|
||||||
layers_per_cycle = layers // stacks
|
|
||||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
|
||||||
return (kernel_size - 1) * sum(dilations) + 1
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def receptive_field_size(self):
|
def receptive_field_size(self):
|
||||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(
|
||||||
self, config, checkpoint_path, eval=False, cache=False
|
self, config, checkpoint_path, eval=False, cache=False
|
||||||
|
|
|
@ -7,6 +7,7 @@ import torch.nn.functional as F
|
||||||
from torch.nn.utils import parametrize
|
from torch.nn.utils import parametrize
|
||||||
|
|
||||||
from TTS.vocoder.layers.lvc_block import LVCBlock
|
from TTS.vocoder.layers.lvc_block import LVCBlock
|
||||||
|
from TTS.vocoder.models.parallel_wavegan_generator import _get_receptive_field_size
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -133,17 +134,10 @@ class UnivnetGenerator(torch.nn.Module):
|
||||||
|
|
||||||
self.apply(_apply_weight_norm)
|
self.apply(_apply_weight_norm)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
|
||||||
assert layers % stacks == 0
|
|
||||||
layers_per_cycle = layers // stacks
|
|
||||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
|
||||||
return (kernel_size - 1) * sum(dilations) + 1
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def receptive_field_size(self):
|
def receptive_field_size(self):
|
||||||
"""Return receptive field size."""
|
"""Return receptive field size."""
|
||||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, c):
|
def inference(self, c):
|
||||||
|
|
Loading…
Reference in New Issue