Fix VITS model SDP (stochastic duration predictor)

This commit is contained in:
Eren Gölge 2021-08-18 14:55:46 +00:00
parent 625ab614f0
commit 3ab8cef99e
4 changed files with 12 additions and 6 deletions

View File

@ -119,8 +119,8 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec: bool = True
# overrides
min_seq_len: int = 13
max_seq_len: int = 500
min_seq_len: int = 32
max_seq_len: int = 1000
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

View File

@ -228,7 +228,7 @@ class StochasticDurationPredictor(nn.Module):
h = self.post_pre(dr)
h = self.post_convs(h, x_mask)
h = self.post_proj(h) * x_mask
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = noise
# posterior encoder

View File

@ -194,7 +194,7 @@ class BaseTTS(BaseModel):
if hasattr(self, "make_symbols"):
custom_symbols = self.make_symbols(self.config)
# init dataloader
# init dataset
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
@ -220,13 +220,15 @@ class BaseTTS(BaseModel):
else None,
)
# pre-compute phonemes
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
if hasattr(self, "eval_data_items") and is_eval:
dataset.items = self.eval_data_items
elif hasattr(self, "train_data_items") and not is_eval:
dataset.items = self.train_data_items
else:
# precompute phonemes to have a better estimate of sequence lengths.
# precompute phonemes for precise estimate of sequence lengths.
# otherwise `dataset.sort_items()` uses raw text lengths
dataset.compute_input_seq(config.num_loader_workers)
# TODO: find a more efficient solution
@ -240,9 +242,13 @@ class BaseTTS(BaseModel):
if num_gpus > 1:
dist.barrier()
# sort input sequences from short to long
dataset.sort_items()
# sampler for DDP
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
# init dataloader
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,

View File

@ -177,7 +177,7 @@ class VitsArgs(Coqpit):
num_layers_text_encoder: int = 6
kernel_size_text_encoder: int = 3
dropout_p_text_encoder: int = 0.1
dropout_p_duration_predictor: int = 0.1
dropout_p_duration_predictor: int = 0.5
kernel_size_posterior_encoder: int = 5
dilation_rate_posterior_encoder: int = 1
num_layers_posterior_encoder: int = 16