Fix loader setup in `base_tts`

Eren Gölge 2021-09-06 14:25:45 +00:00
parent 76c4929ab2
commit 2b59da802c
3 changed files with 13 additions and 5 deletions


@@ -70,15 +70,17 @@ class FFTransformerBlock(nn.Module):
 class FFTDurationPredictor:
-    def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None):
+    def __init__(self, in_channels, hidden_channels, num_heads, num_layers, dropout_p=0.1, cond_channels=None):  # pylint: disable=unused-argument
         self.fft = FFTransformerBlock(in_channels, num_heads, hidden_channels, num_layers, dropout_p)
         self.proj = nn.Linear(in_channels, 1)
 
-    def forward(self, x, mask=None, g=None):
+    def forward(self, x, mask=None, g=None):  # pylint: disable=unused-argument
         """
         Shapes:
             - x: :math:`[B, C, T]`
             - mask: :math:`[B, 1, T]`
+
+        TODO: Handle the cond input
         """
         x = self.fft(x, mask=mask)
         x = self.proj(x)
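The only functional change here is the added TODO; the two `# pylint: disable=unused-argument` pragmas acknowledge that `cond_channels` and `g` are accepted but not yet used, keeping a uniform call interface across predictors. A minimal, hypothetical stub (names are illustrative, not from the repo) showing the same pattern:

from torch import nn


class DurationPredictorStub(nn.Module):
    """Hypothetical stub: `cond_channels` and `g` are kept for interface
    compatibility but unused, so inline pylint pragmas silence
    `unused-argument` until the conditioning TODO lands."""

    def __init__(self, in_channels, cond_channels=None):  # pylint: disable=unused-argument
        super().__init__()
        self.proj = nn.Linear(in_channels, 1)

    def forward(self, x, g=None):  # pylint: disable=unused-argument
        # x: [B, T, C] -> one predicted duration per time step, [B, T, 1]
        return self.proj(x)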


@@ -707,7 +707,8 @@ class FastPitchLoss(nn.Module):
         self.aligner_loss_alpha = c.aligner_loss_alpha
         self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
 
-    def _binary_alignment_loss(self, alignment_hard, alignment_soft):
+    @staticmethod
+    def _binary_alignment_loss(alignment_hard, alignment_soft):
         """Binary loss that forces soft alignments to match the hard alignments as
         explained in `https://arxiv.org/pdf/2108.10447.pdf`.
         """


@@ -252,13 +252,18 @@ class BaseTTS(BaseModel):
             # compute pitch frames and write to files.
             if config.compute_f0 and rank in [None, 0]:
                 if not os.path.exists(config.f0_cache_path):
-                    dataset.pitch_extractor.compute_pitch(config.get("f0_cache_path", None), config.num_loader_workers)
-                dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
+                    dataset.pitch_extractor.compute_pitch(
+                        ap, config.get("f0_cache_path", None), config.num_loader_workers
+                    )
 
             # halt DDP processes for the main process to finish computing the F0 cache
             if num_gpus > 1:
                 dist.barrier()
 
+            # load pitch stats computed above by all the workers
+            if config.compute_f0:
+                dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
+
             # sampler for DDP
             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
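This hunk is the loader fix named in the commit title. Previously, `load_pitch_stats` ran only on the main process (inside the `rank in [None, 0]` branch) and `compute_pitch` was called without the `ap` (audio processor) argument; now rank 0 computes the F0 cache, every rank waits at `dist.barrier()`, and only then do all ranks load the pitch stats. A minimal sketch of the pattern, with `compute_fn`/`load_fn` as hypothetical stand-ins for `compute_pitch`/`load_pitch_stats`:

import os

import torch.distributed as dist


def ensure_shared_cache(cache_path, rank, num_gpus, compute_fn, load_fn):
    # only the main process builds the cache (rank is None when not distributed)
    if rank in [None, 0] and not os.path.exists(cache_path):
        compute_fn(cache_path)
    # all other ranks block here until the main process finishes writing
    if num_gpus > 1:
        dist.barrier()
    # now every rank, not just rank 0, reads the finished cache
    return load_fn(cache_path)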