mirror of https://github.com/coqui-ai/TTS.git
Fix loader setup in `base_tts`
This commit is contained in:
parent
76c4929ab2
commit
2b59da802c
|
@ -70,15 +70,17 @@ class FFTransformerBlock(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class FFTDurationPredictor:
|
class FFTDurationPredictor:
|
||||||
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None):
|
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument
|
||||||
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
|
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
|
||||||
self.proj = nn.Linear(in_channels, 1)
|
self.proj = nn.Linear(in_channels, 1)
|
||||||
|
|
||||||
def forward(self, x, mask=None, g=None):
|
def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
- x: :math:`[B, C, T]`
|
- x: :math:`[B, C, T]`
|
||||||
- mask: :math:`[B, 1, T]`
|
- mask: :math:`[B, 1, T]`
|
||||||
|
|
||||||
|
TODO: Handle the cond input
|
||||||
"""
|
"""
|
||||||
x = self.fft(x, mask=mask)
|
x = self.fft(x, mask=mask)
|
||||||
x = self.proj(x)
|
x = self.proj(x)
|
||||||
|
|
|
@ -707,7 +707,8 @@ class FastPitchLoss(nn.Module):
|
||||||
self.aligner_loss_alpha = c.aligner_loss_alpha
|
self.aligner_loss_alpha = c.aligner_loss_alpha
|
||||||
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
|
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
|
||||||
|
|
||||||
def _binary_alignment_loss(self, alignment_hard, alignment_soft):
|
@staticmethod
|
||||||
|
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||||
"""Binary loss that forces soft alignments to match the hard alignments as
|
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||||
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -252,13 +252,18 @@ class BaseTTS(BaseModel):
|
||||||
# compute pitch frames and write to files.
|
# compute pitch frames and write to files.
|
||||||
if config.compute_f0 and rank in [None, 0]:
|
if config.compute_f0 and rank in [None, 0]:
|
||||||
if not os.path.exists(config.f0_cache_path):
|
if not os.path.exists(config.f0_cache_path):
|
||||||
dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
|
dataset.pitch_extractor.compute_pitch(
|
||||||
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
ap, config.get("f0_cache_path", None), config.num_loader_workers
|
||||||
|
)
|
||||||
|
|
||||||
# halt DDP processes for the main process to finish computing the F0 cache
|
# halt DDP processes for the main process to finish computing the F0 cache
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
# load pitch stats computed above by all the workers
|
||||||
|
if config.compute_f0:
|
||||||
|
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
||||||
|
|
||||||
# sampler for DDP
|
# sampler for DDP
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue