mirror of https://github.com/coqui-ai/TTS.git
Fix loader setup in `base_tts`
This commit is contained in:
parent
76c4929ab2
commit
2b59da802c
|
@ -70,15 +70,17 @@ class FFTransformerBlock(nn.Module):
|
|||
|
||||
|
||||
class FFTDurationPredictor:
|
||||
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None):
|
||||
def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument
|
||||
self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
|
||||
self.proj = nn.Linear(in_channels, 1)
|
||||
|
||||
def forward(self, x, mask=None, g=None):
|
||||
def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T]`
|
||||
- mask: :math:`[B, 1, T]`
|
||||
|
||||
TODO: Handle the cond input
|
||||
"""
|
||||
x = self.fft(x, mask=mask)
|
||||
x = self.proj(x)
|
||||
|
|
|
@ -707,7 +707,8 @@ class FastPitchLoss(nn.Module):
|
|||
self.aligner_loss_alpha = c.aligner_loss_alpha
|
||||
self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
|
||||
|
||||
def _binary_alignment_loss(self, alignment_hard, alignment_soft):
|
||||
@staticmethod
|
||||
def _binary_alignment_loss(alignment_hard, alignment_soft):
|
||||
"""Binary loss that forces soft alignments to match the hard alignments as
|
||||
explained in `https://arxiv.org/pdf/2108.10447.pdf`.
|
||||
"""
|
||||
|
|
|
@ -252,13 +252,18 @@ class BaseTTS(BaseModel):
|
|||
# compute pitch frames and write to files.
|
||||
if config.compute_f0 and rank in [None, 0]:
|
||||
if not os.path.exists(config.f0_cache_path):
|
||||
dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
|
||||
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
||||
dataset.pitch_extractor.compute_pitch(
|
||||
ap, config.get("f0_cache_path", None), config.num_loader_workers
|
||||
)
|
||||
|
||||
# halt DDP processes for the main process to finish computing the F0 cache
|
||||
if num_gpus > 1:
|
||||
dist.barrier()
|
||||
|
||||
# every worker loads the pitch stats computed above by the main process
|
||||
if config.compute_f0:
|
||||
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
||||
|
||||
# sampler for DDP
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
||||
|
|
Loading…
Reference in New Issue