mirror of https://github.com/coqui-ai/TTS.git
Fix VITS model SDP
This commit is contained in:
parent
625ab614f0
commit
3ab8cef99e
|
@ -119,8 +119,8 @@ class VitsConfig(BaseTTSConfig):
|
||||||
compute_linear_spec: bool = True
|
compute_linear_spec: bool = True
|
||||||
|
|
||||||
# overrides
|
# overrides
|
||||||
min_seq_len: int = 13
|
min_seq_len: int = 32
|
||||||
max_seq_len: int = 500
|
max_seq_len: int = 1000
|
||||||
r: int = 1 # DO NOT CHANGE
|
r: int = 1 # DO NOT CHANGE
|
||||||
add_blank: bool = True
|
add_blank: bool = True
|
||||||
|
|
||||||
|
|
|
@ -228,7 +228,7 @@ class StochasticDurationPredictor(nn.Module):
|
||||||
h = self.post_pre(dr)
|
h = self.post_pre(dr)
|
||||||
h = self.post_convs(h, x_mask)
|
h = self.post_convs(h, x_mask)
|
||||||
h = self.post_proj(h) * x_mask
|
h = self.post_proj(h) * x_mask
|
||||||
noise = torch.rand(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
noise = torch.randn(dr.size(0), 2, dr.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||||
z_q = noise
|
z_q = noise
|
||||||
|
|
||||||
# posterior encoder
|
# posterior encoder
|
||||||
|
|
|
@ -194,7 +194,7 @@ class BaseTTS(BaseModel):
|
||||||
if hasattr(self, "make_symbols"):
|
if hasattr(self, "make_symbols"):
|
||||||
custom_symbols = self.make_symbols(self.config)
|
custom_symbols = self.make_symbols(self.config)
|
||||||
|
|
||||||
# init dataloader
|
# init dataset
|
||||||
dataset = TTSDataset(
|
dataset = TTSDataset(
|
||||||
outputs_per_step=config.r if "r" in config else 1,
|
outputs_per_step=config.r if "r" in config else 1,
|
||||||
text_cleaner=config.text_cleaner,
|
text_cleaner=config.text_cleaner,
|
||||||
|
@ -220,13 +220,15 @@ class BaseTTS(BaseModel):
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pre-compute phonemes
|
||||||
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
|
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
|
||||||
if hasattr(self, "eval_data_items") and is_eval:
|
if hasattr(self, "eval_data_items") and is_eval:
|
||||||
dataset.items = self.eval_data_items
|
dataset.items = self.eval_data_items
|
||||||
elif hasattr(self, "train_data_items") and not is_eval:
|
elif hasattr(self, "train_data_items") and not is_eval:
|
||||||
dataset.items = self.train_data_items
|
dataset.items = self.train_data_items
|
||||||
else:
|
else:
|
||||||
# precompute phonemes to have a better estimate of sequence lengths.
|
# precompute phonemes for precise estimate of sequence lengths.
|
||||||
|
# otherwise `dataset.sort_items()` uses raw text lengths
|
||||||
dataset.compute_input_seq(config.num_loader_workers)
|
dataset.compute_input_seq(config.num_loader_workers)
|
||||||
|
|
||||||
# TODO: find a more efficient solution
|
# TODO: find a more efficient solution
|
||||||
|
@ -240,9 +242,13 @@ class BaseTTS(BaseModel):
|
||||||
if num_gpus > 1:
|
if num_gpus > 1:
|
||||||
dist.barrier()
|
dist.barrier()
|
||||||
|
|
||||||
|
# sort input sequences from short to long
|
||||||
dataset.sort_items()
|
dataset.sort_items()
|
||||||
|
|
||||||
|
# sampler for DDP
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
|
||||||
|
# init dataloader
|
||||||
loader = DataLoader(
|
loader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||||
|
|
|
@ -177,7 +177,7 @@ class VitsArgs(Coqpit):
|
||||||
num_layers_text_encoder: int = 6
|
num_layers_text_encoder: int = 6
|
||||||
kernel_size_text_encoder: int = 3
|
kernel_size_text_encoder: int = 3
|
||||||
dropout_p_text_encoder: int = 0.1
|
dropout_p_text_encoder: int = 0.1
|
||||||
dropout_p_duration_predictor: int = 0.1
|
dropout_p_duration_predictor: int = 0.5
|
||||||
kernel_size_posterior_encoder: int = 5
|
kernel_size_posterior_encoder: int = 5
|
||||||
dilation_rate_posterior_encoder: int = 1
|
dilation_rate_posterior_encoder: int = 1
|
||||||
num_layers_posterior_encoder: int = 16
|
num_layers_posterior_encoder: int = 16
|
||||||
|
|
Loading…
Reference in New Issue