diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 12f0bbb0..2fe9bcc4 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -70,15 +70,17 @@ class FFTransformerBlock(nn.Module): class FFTDurationPredictor: - def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): + def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None): # pylint: disable=unused-argument self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p) self.proj = nn.Linear(in_channels, 1) - def forward(self, x, mask=None, g=None): + def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument """ Shapes: - x: :math:`[B, C, T]` - mask: :math:`[B, 1, T]` + + TODO: Handle the cond input """ x = self.fft(x, mask=mask) x = self.proj(x) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 100b8fb3..a2fd7635 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -707,7 +707,8 @@ class FastPitchLoss(nn.Module): self.aligner_loss_alpha = c.aligner_loss_alpha self.binary_alignment_loss_alpha = c.binary_align_loss_alpha - def _binary_alignment_loss(self, alignment_hard, alignment_soft): + @staticmethod + def _binary_alignment_loss(alignment_hard, alignment_soft): """Binary loss that forces soft alignments to match the hard alignments as explained in `https://arxiv.org/pdf/2108.10447.pdf`. """ diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 653143cd..06c7cb2b 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -252,13 +252,18 @@ class BaseTTS(BaseModel): # compute pitch frames and write to files. if config.compute_f0 and rank in [None, 0]: if not os.path.exists(config.f0_cache_path): - dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers) - dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + dataset.pitch_extractor.compute_pitch( + ap, config.get("f0_cache_path", None), config.num_loader_workers + ) # halt DDP processes for the main process to finish computing the F0 cache if num_gpus > 1: dist.barrier() + # load pitch stats computed above by all the workers + if config.compute_f0: + dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None