mirror of https://github.com/coqui-ai/TTS.git
Loading only one decoder and removing lazy loading
parent 2ecf84a2c6
commit 0d36dcfd81
@@ -198,13 +198,12 @@ class XttsArgs(Coqpit):
     Args:
         gpt_batch_size (int): The size of the auto-regressive batch.
         enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
-        lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
         kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
         gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
         clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
         decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
         num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
-        vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
+        use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.

        For GPT model:
        ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -234,12 +233,12 @@ class XttsArgs(Coqpit):

     gpt_batch_size: int = 1
     enable_redaction: bool = False
-    lazy_load: bool = True
     kv_cache: bool = True
     gpt_checkpoint: str = None
     clvp_checkpoint: str = None
     decoder_checkpoint: str = None
     num_chars: int = 255
+    use_hifigan: bool = True

     # XTTS GPT Encoder params
     tokenizer_file: str = ""
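Note: with `lazy_load` removed, decoder selection is now a single construction-time flag. A minimal sketch of toggling it from user code, assuming the usual `XttsConfig` entry point (the import path and `init_from_config` come from the surrounding repo, not from this diff):

```python
from TTS.tts.configs.xtts_config import XttsConfig  # assumed import path
from TTS.tts.models.xtts import Xtts

config = XttsConfig()
config.model_args.use_hifigan = False  # False -> diffusion + UnivNet decoder instead of HiFi-GAN
model = Xtts.init_from_config(config)  # only the selected decoder stack is instantiated
```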
@@ -297,7 +296,6 @@ class Xtts(BaseTTS):

     def __init__(self, config: Coqpit):
         super().__init__(config, ap=None, tokenizer=None)
-        self.lazy_load = self.args.lazy_load
         self.mel_stats_path = None
         self.config = config
         self.gpt_checkpoint = self.args.gpt_checkpoint
@@ -307,7 +305,6 @@ class Xtts(BaseTTS):

         self.tokenizer = VoiceBpeTokenizer()
         self.gpt = None
-        self.diffusion_decoder = None
         self.init_models()
         self.register_buffer("mel_stats", torch.ones(80))

@@ -334,6 +331,8 @@ class Xtts(BaseTTS):
             stop_audio_token=self.args.gpt_stop_audio_token,
         )

+
+        if self.args.use_hifigan:
             self.hifigan_decoder = HifiDecoder(
                 input_sample_rate=self.args.input_sample_rate,
                 output_sample_rate=self.args.output_sample_rate,
@@ -344,6 +343,7 @@ class Xtts(BaseTTS):
                 cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
             )

+        else:
             self.diffusion_decoder = DiffusionTts(
                 model_channels=self.args.diff_model_channels,
                 num_layers=self.args.diff_num_layers,
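The two hunks above make `init_models` build exactly one decoder stack. The pattern in isolation, with stand-in modules (a sketch; `HifiDecoder`, `DiffusionTts`, and `UnivNetGenerator` are the real classes, the linears below are placeholders):

```python
import torch.nn as nn

class OneDecoderModel(nn.Module):
    """Instantiate only the decoder that will actually be used, saving parameters and VRAM."""

    def __init__(self, use_hifigan: bool = True):
        super().__init__()
        if use_hifigan:
            self.hifigan_decoder = nn.Linear(1024, 1)  # stand-in for HifiDecoder
        else:
            self.diffusion_decoder = nn.Linear(1024, 80)  # stand-in for DiffusionTts
            self.vocoder = nn.Linear(80, 1)  # stand-in for UnivNetGenerator
```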
@@ -357,27 +357,12 @@ class Xtts(BaseTTS):
                 layer_drop=self.args.diff_layer_drop,
                 unconditioned_percentage=self.args.diff_unconditioned_percentage,
             )

             self.vocoder = UnivNetGenerator()

     @property
     def device(self):
         return next(self.parameters()).device

-    @contextmanager
-    def lazy_load_model(self, model):
-        """Context to load a model on demand.
-
-        Args:
-            model (nn.Module): The model to be loaded.
-        """
-        if self.lazy_load:
-            yield model
-        else:
-            m = model.to(self.device)
-            yield m
-            m = model.cpu()
-
     def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.

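For reference, the deleted `lazy_load_model` context manager implemented a move-to-GPU-on-demand pattern. A standalone sketch of that idea (illustrative only; the commit removes it in favor of keeping the single remaining decoder resident):

```python
from contextlib import contextmanager

import torch

@contextmanager
def on_demand(module: torch.nn.Module, device: str = "cuda"):
    """Keep `module` on CPU and move it to `device` only for the duration of the block."""
    module.to(device)
    try:
        yield module
    finally:
        module.cpu()  # release VRAM once the block exits
```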
@@ -411,8 +396,7 @@ class Xtts(BaseTTS):
             )
             diffusion_conds.append(cond_mel)
         diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        with self.lazy_load_model(self.diffusion_decoder) as diffusion:
-            diffusion_latent = diffusion.get_conditioning(diffusion_conds)
+        diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
         return diffusion_latent

     def get_speaker_embedding(
@@ -430,10 +414,14 @@ class Xtts(BaseTTS):
         audio_path,
         gpt_cond_len=3,
     ):
-        gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
-        diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
+        speaker_embedding = None
+        diffusion_cond_latents = None
+        if self.args.use_hifigan:
             speaker_embedding = self.get_speaker_embedding(audio_path)
-        return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device), speaker_embedding
+        else:
+            diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
+        gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
+        return gpt_cond_latents, diffusion_cond_latents, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
         """Synthesize speech with the given input text.
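After this hunk, either the diffusion latents or the speaker embedding comes back as `None`, depending on the configured decoder. A hedged consumption sketch (a `model` built as in the earlier example is assumed):

```python
gpt_latents, diffusion_latents, spk_emb = model.get_conditioning_latents("speaker.wav")
if model.args.use_hifigan:
    assert diffusion_latents is None and spk_emb is not None
else:
    assert spk_emb is None and diffusion_latents is not None
```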
@@ -500,7 +488,6 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
-        use_hifigan=True,
         **hf_generate_kwargs,
     ):
         """
@@ -579,7 +566,6 @@ class Xtts(BaseTTS):
            cond_free_k=cond_free_k,
            diffusion_temperature=diffusion_temperature,
            decoder_sampler=decoder_sampler,
-           use_hifigan=use_hifigan,
            **hf_generate_kwargs,
        )

@@ -614,7 +600,7 @@ class Xtts(BaseTTS):
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

-        if not use_hifigan:
+        if not self.args.use_hifigan:
             diffuser = load_discrete_vocoder_diffuser(
                 desired_diffusion_steps=decoder_iterations,
                 cond_free=cond_free,
@@ -623,9 +609,7 @@ class Xtts(BaseTTS):
             )

         with torch.no_grad():
-            self.gpt = self.gpt.to(self.device)
-            with self.lazy_load_model(self.gpt) as gpt:
-                gpt_codes = gpt.generate(
+            gpt_codes = self.gpt.generate(
                 cond_latents=gpt_cond_latent,
                 text_inputs=text_tokens,
                 input_tokens=None,
@@ -643,7 +627,7 @@ class Xtts(BaseTTS):
                 [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
             )
             text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
-            gpt_latents = gpt(
+            gpt_latents = self.gpt(
                 text_tokens,
                 text_len,
                 gpt_codes,
@@ -663,20 +647,17 @@ class Xtts(BaseTTS):
                 gpt_latents = gpt_latents[:, :k]
                 break

-        if use_hifigan:
-            with self.lazy_load_model(self.hifigan_decoder) as hifigan_decoder:
-                wav = hifigan_decoder(gpt_latents, g=speaker_embedding)
+        if self.args.use_hifigan:
+            wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
         else:
-            with self.lazy_load_model(self.diffusion_decoder) as diffusion:
             mel = do_spectrogram_diffusion(
-                diffusion,
+                self.diffusion_decoder,
                 diffuser,
                 gpt_latents,
                 diffusion_conditioning,
                 temperature=diffusion_temperature,
             )
-            with self.lazy_load_model(self.vocoder) as vocoder:
-                wav = vocoder.inference(mel)
+            wav = self.vocoder.inference(mel)

         return {"wav": wav.cpu().numpy().squeeze()}
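With `use_hifigan` dropped from the `inference()` signature, the decoder path is fixed when the model is built. A call sketch (positional argument order follows the signature in this file; the latents come from `get_conditioning_latents` above):

```python
out = model.inference(
    "Hello world.",     # text
    "en",               # language
    gpt_latents,        # GPT conditioning latents
    spk_emb,            # None on the diffusion path
    diffusion_latents,  # None on the HiFi-GAN path
)
wav = out["wav"]  # numpy array, per the return value above
```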
@@ -713,6 +694,7 @@ class Xtts(BaseTTS):
         # Decoder inference
         **hf_generate_kwargs,
     ):
+        assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
         text = f"[{language}]{text.strip().lower()}"
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

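The new assertion makes the streaming constraint explicit: only the HiFi-GAN path is fast enough to stream. A guard sketch for callers (`play` is a hypothetical audio sink, and the `inference_stream` argument order is assumed from this file):

```python
if model.args.use_hifigan:
    for chunk in model.inference_stream("Hello world.", "en", gpt_latents, spk_emb):
        play(chunk)  # hypothetical sink for streamed audio chunks
else:
    # diffusion path: fall back to one-shot synthesis
    wav = model.inference("Hello world.", "en", gpt_latents, None, diffusion_latents)["wav"]
```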
@@ -781,7 +763,7 @@ class Xtts(BaseTTS):
         vocab_path=None,
         eval=False,
         strict=True,
-        use_deepspeed=False
+        use_deepspeed=False,
     ):
         """
         Loads a checkpoint from disk and initializes the model's state and tokenizer.
@@ -807,14 +789,20 @@ class Xtts(BaseTTS):
         self.init_models()
         if eval:
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
-        self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict)
+
+        checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
+        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
+        for key in list(checkpoint.keys()):
+            if key.split(".")[0] in ignore_keys:
+                del checkpoint[key]
+        self.load_state_dict(checkpoint, strict=strict)

         if eval:
+            if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
+            if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
+            if hasattr(self, "vocoder"): self.vocoder.eval()
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
             self.gpt.eval()
-            self.diffusion_decoder.eval()
-            self.vocoder.eval()
-            self.hifigan_decoder.eval()

     def train_step(self):
         raise NotImplementedError("XTTS Training is not implemented")
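The checkpoint change above loads on CPU and drops the weights belonging to whichever decoder was never instantiated, so `strict` loading still passes. The same filtering pattern in isolation (a minimal sketch assuming a plain `{"model": state_dict}` checkpoint file):

```python
import torch

checkpoint = torch.load("model.pth", map_location="cpu")["model"]
# Keys whose top-level module was not built for this configuration.
ignore_keys = ["diffusion_decoder", "vocoder"]  # the use_hifigan=True case
for key in list(checkpoint.keys()):
    if key.split(".")[0] in ignore_keys:
        del checkpoint[key]
model.load_state_dict(checkpoint, strict=True)
```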