Fix loader setup in `base_tts`

Eren Gölge 2021-09-06 14:25:45 +00:00
parent 76c4929ab2
commit 2b59da802c
3 changed files with 13 additions and 5 deletions

@@ -70,15 +70,17 @@ class FFTransformerBlock(nn.Module):
 class FFTDurationPredictor:
-    def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None):
+    def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None):  # pylint: disable=unused-argument
         self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
         self.proj = nn.Linear(in_channels, 1)

-    def forward(self, x, mask=None, g=None):
+    def forward(self, x, mask=None, g=None):  # pylint: disable=unused-argument
         """
         Shapes:
             - x: :math:`[B, C, T]`
             - mask: :math:`[B, 1, T]`

         TODO: Handle the cond input
         """
         x = self.fft(x, mask=mask)
         x = self.proj(x)

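The `# pylint: disable=unused-argument` markers keep the linter quiet while preserving a shared call signature: every duration predictor variant accepts a conditioning tensor `g`, even one that does not use it yet (see the TODO in the docstring above). A minimal sketch of that pattern, with a hypothetical `DurationPredictorStub` standing in for the real block:

import torch
import torch.nn as nn

class DurationPredictorStub(nn.Module):
    """Hypothetical stand-in sharing FFTDurationPredictor's call signature."""

    def __init__(self, in_channels):
        super().__init__()
        self.proj = nn.Linear(in_channels, 1)

    def forward(self, x, mask=None, g=None):  # pylint: disable=unused-argument
        # x: [B, C, T] -> move C last so the Linear projects the channel dim
        o = self.proj(x.transpose(1, 2)).transpose(1, 2)  # [B, 1, T]
        if mask is not None:
            o = o * mask  # zero out padded frames
        return o

x = torch.randn(2, 80, 50)   # [B, C, T]
mask = torch.ones(2, 1, 50)  # [B, 1, T]
print(DurationPredictorStub(80)(x, mask=mask, g=None).shape)  # -> [2, 1, 50]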

@@ -707,7 +707,8 @@ class FastPitchLoss(nn.Module):
         self.aligner_loss_alpha = c.aligner_loss_alpha
         self.binary_alignment_loss_alpha = c.binary_align_loss_alpha

-    def _binary_alignment_loss(self, alignment_hard, alignment_soft):
+    @staticmethod
+    def _binary_alignment_loss(alignment_hard, alignment_soft):
         """Binary loss that forces soft alignments to match the hard alignments as
         explained in `https://arxiv.org/pdf/2108.10447.pdf`.
         """

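For context, a sketch of what a binary alignment loss of this kind computes, following the cited paper (arXiv:2108.10447): the negative log of the soft attention probabilities at the positions picked by the hard alignment, normalized by the number of selected positions. The function name and the exact clamping value here are illustrative assumptions, not necessarily the repo's code:

import torch

def binary_alignment_loss(alignment_hard, alignment_soft):
    # alignment_hard: [B, T_en, T_de] 0/1 matrix (e.g. from monotonic alignment search)
    # alignment_soft: [B, T_en, T_de] attention probabilities
    selected = alignment_soft[alignment_hard == 1]
    # clamp avoids log(0) where the soft attention is (near) zero
    log_probs = torch.log(torch.clamp(selected, min=1e-12))
    return -log_probs.sum() / alignment_hard.sum()

# toy check: identical hard and soft alignments give zero loss
hard = torch.eye(4).unsqueeze(0)
print(binary_alignment_loss(hard, hard))  # ~ 0

Since a computation like this touches no instance state, converting the method to a `@staticmethod`, as this hunk does, is a natural fit.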

@@ -252,13 +252,18 @@ class BaseTTS(BaseModel):
             # compute pitch frames and write to files.
             if config.compute_f0 and rank in [None, 0]:
                 if not os.path.exists(config.f0_cache_path):
-                    dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
-                dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
+                    dataset.pitch_extractor.compute_pitch(
+                        ap, config.get("f0_cache_path", None), config.num_loader_workers
+                    )
+
+            # halt DDP processes for the main process to finish computing the F0 cache
+            if num_gpus > 1:
+                dist.barrier()
+
+            # load pitch stats computed above by all the workers
+            if config.compute_f0:
+                dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))

             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
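
The hunk above applies a standard DDP initialization pattern: only rank 0 builds the expensive F0 cache, every other worker blocks at `dist.barrier()` until the files exist, and only then do all ranks load the pitch stats. Moving `load_pitch_stats` out of the rank-0 branch is what guarantees non-zero ranks never read a half-written cache. A self-contained sketch of the pattern (the cache path and helper functions are hypothetical, not from the repo):

import json
import os
import torch.distributed as dist

CACHE_PATH = "/tmp/f0_cache.json"  # hypothetical cache location

def build_cache(path):
    # stand-in for the expensive per-utterance pitch computation
    with open(path, "w") as f:
        json.dump({"mean": 120.0, "std": 35.0}, f)

def load_cache(path):
    with open(path) as f:
        return json.load(f)

def setup_f0_cache(rank, num_gpus):
    # only the main process (or a single-process run) writes the cache
    if rank in [None, 0] and not os.path.exists(CACHE_PATH):
        build_cache(CACHE_PATH)
    # all ranks block here until rank 0 has finished writing
    if num_gpus > 1:
        dist.barrier()
    # the cache now exists, so every rank can safely read it
    return load_cache(CACHE_PATH)

print(setup_f0_cache(rank=None, num_gpus=1))  # single-process smoke test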