mirror of https://github.com/coqui-ai/TTS.git
Fix VITS model SDP (stochastic duration predictor)
This commit is contained in:
parent
625ab614f0
commit
3ab8cef99e
|
@ -119,8 +119,8 @@ class VitsConfig(BaseTTSConfig):
|
|||
compute_linear_spec: bool = True
|
||||
|
||||
# overrides
|
||||
min_seq_len: int = 13
|
||||
max_seq_len: int = 500
|
||||
min_seq_len: int = 32
|
||||
max_seq_len: int = 1000
|
||||
r: int = 1 # DO NOT CHANGE
|
||||
add_blank: bool = True
|
||||
|
||||
|
|
|
@ -228,7 +228,7 @@ class StochasticDurationPredictor(nn.Module):
|
|||
h = self.post_pre(dr)
|
||||
h = self.post_convs(h, x_mask)
|
||||
h = self.post_proj(h) * x_mask
|
||||
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
z_q = noise
|
||||
|
||||
# posterior encoder
|
||||
|
|
|
@ -194,7 +194,7 @@ class BaseTTS(BaseModel):
|
|||
if hasattr(self, "make_symbols"):
|
||||
custom_symbols = self.make_symbols(self.config)
|
||||
|
||||
# init dataloader
|
||||
# init dataset
|
||||
dataset = TTSDataset(
|
||||
outputs_per_step=config.r if "r" in config else 1,
|
||||
text_cleaner=config.text_cleaner,
|
||||
|
@ -220,13 +220,15 @@ class BaseTTS(BaseModel):
|
|||
else None,
|
||||
)
|
||||
|
||||
# pre-compute phonemes
|
||||
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
|
||||
if hasattr(self, "eval_data_items") and is_eval:
|
||||
dataset.items = self.eval_data_items
|
||||
elif hasattr(self, "train_data_items") and not is_eval:
|
||||
dataset.items = self.train_data_items
|
||||
else:
|
||||
# precompute phonemes to have a better estimate of sequence lengths.
|
||||
# precompute phonemes for a precise estimate of sequence lengths.
|
||||
# otherwise `dataset.sort_items()` uses raw text lengths
|
||||
dataset.compute_input_seq(config.num_loader_workers)
|
||||
|
||||
# TODO: find a more efficient solution
|
||||
|
@ -240,9 +242,13 @@ class BaseTTS(BaseModel):
|
|||
if num_gpus > 1:
|
||||
dist.barrier()
|
||||
|
||||
# sort input sequences from short to long
|
||||
dataset.sort_items()
|
||||
|
||||
# sampler for DDP
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
||||
# init dataloader
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
|
|
|
@ -177,7 +177,7 @@ class VitsArgs(Coqpit):
|
|||
num_layers_text_encoder: int = 6
|
||||
kernel_size_text_encoder: int = 3
|
||||
dropout_p_text_encoder: int = 0.1
|
||||
dropout_p_duration_predictor: int = 0.1
|
||||
dropout_p_duration_predictor: int = 0.5
|
||||
kernel_size_posterior_encoder: int = 5
|
||||
dilation_rate_posterior_encoder: int = 1
|
||||
num_layers_posterior_encoder: int = 16
|
||||
|
|
Loading…
Reference in New Issue