Loading only one decoder and removing lazy loading

WeberJulian 2023-10-04 07:31:21 -03:00
parent 2ecf84a2c6
commit 0d36dcfd81
1 changed file with 96 additions and 108 deletions

@@ -198,13 +198,12 @@ class XttsArgs(Coqpit):
     Args:
         gpt_batch_size (int): The size of the auto-regressive batch.
         enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
-        lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
         kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
         gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
         clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
         decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
         num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
-        vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
+        use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.

     For GPT model:
         ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -234,12 +233,12 @@ class XttsArgs(Coqpit):
     gpt_batch_size: int = 1
     enable_redaction: bool = False
-    lazy_load: bool = True
     kv_cache: bool = True
     gpt_checkpoint: str = None
     clvp_checkpoint: str = None
     decoder_checkpoint: str = None
     num_chars: int = 255
+    use_hifigan: bool = True

     # XTTS GPT Encoder params
     tokenizer_file: str = ""
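
For anyone picking up this change, here is a minimal usage sketch of toggling the decoder path. It is hypothetical: it assumes XttsConfig wires model_args to XttsArgs the same way as the other Coqpit configs in the repo.

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

# Default after this commit: only the HiFi-GAN decoder is instantiated.
config = XttsConfig()
config.model_args.use_hifigan = True
model = Xtts(config)

# Opt into the diffusion + UnivNet path instead; HifiDecoder is then
# never built, so its weights cost neither RAM nor VRAM.
config = XttsConfig()
config.model_args.use_hifigan = False
model = Xtts(config)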
@@ -297,7 +296,6 @@ class Xtts(BaseTTS):
     def __init__(self, config: Coqpit):
         super().__init__(config, ap=None, tokenizer=None)
-        self.lazy_load = self.args.lazy_load
         self.mel_stats_path = None
         self.config = config
         self.gpt_checkpoint = self.args.gpt_checkpoint
@@ -307,7 +305,6 @@ class Xtts(BaseTTS):
         self.tokenizer = VoiceBpeTokenizer()
         self.gpt = None
-        self.diffusion_decoder = None
         self.init_models()

         self.register_buffer("mel_stats", torch.ones(80))
@@ -334,50 +331,38 @@ class Xtts(BaseTTS):
             stop_audio_token=self.args.gpt_stop_audio_token,
         )

-        self.hifigan_decoder = HifiDecoder(
-            input_sample_rate=self.args.input_sample_rate,
-            output_sample_rate=self.args.output_sample_rate,
-            output_hop_length=self.args.output_hop_length,
-            ar_mel_length_compression=self.args.ar_mel_length_compression,
-            decoder_input_dim=self.args.decoder_input_dim,
-            d_vector_dim=self.args.d_vector_dim,
-            cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
-        )
-        self.diffusion_decoder = DiffusionTts(
-            model_channels=self.args.diff_model_channels,
-            num_layers=self.args.diff_num_layers,
-            in_channels=self.args.diff_in_channels,
-            out_channels=self.args.diff_out_channels,
-            in_latent_channels=self.args.diff_in_latent_channels,
-            in_tokens=self.args.diff_in_tokens,
-            dropout=self.args.diff_dropout,
-            use_fp16=self.args.diff_use_fp16,
-            num_heads=self.args.diff_num_heads,
-            layer_drop=self.args.diff_layer_drop,
-            unconditioned_percentage=self.args.diff_unconditioned_percentage,
-        )
-        self.vocoder = UnivNetGenerator()
+        if self.args.use_hifigan:
+            self.hifigan_decoder = HifiDecoder(
+                input_sample_rate=self.args.input_sample_rate,
+                output_sample_rate=self.args.output_sample_rate,
+                output_hop_length=self.args.output_hop_length,
+                ar_mel_length_compression=self.args.ar_mel_length_compression,
+                decoder_input_dim=self.args.decoder_input_dim,
+                d_vector_dim=self.args.d_vector_dim,
+                cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
+            )
+        else:
+            self.diffusion_decoder = DiffusionTts(
+                model_channels=self.args.diff_model_channels,
+                num_layers=self.args.diff_num_layers,
+                in_channels=self.args.diff_in_channels,
+                out_channels=self.args.diff_out_channels,
+                in_latent_channels=self.args.diff_in_latent_channels,
+                in_tokens=self.args.diff_in_tokens,
+                dropout=self.args.diff_dropout,
+                use_fp16=self.args.diff_use_fp16,
+                num_heads=self.args.diff_num_heads,
+                layer_drop=self.args.diff_layer_drop,
+                unconditioned_percentage=self.args.diff_unconditioned_percentage,
+            )
+            self.vocoder = UnivNetGenerator()

     @property
     def device(self):
         return next(self.parameters()).device

-    @contextmanager
-    def lazy_load_model(self, model):
-        """Context to load a model on demand.
-
-        Args:
-            model (nn.Module): The model to be loaded.
-        """
-        if self.lazy_load:
-            yield model
-        else:
-            m = model.to(self.device)
-            yield m
-            m = model.cpu()
-
     def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
         """Compute the conditioning latents for the GPT model from the given audio.
@@ -411,8 +396,7 @@ class Xtts(BaseTTS):
             )
             diffusion_conds.append(cond_mel)
         diffusion_conds = torch.stack(diffusion_conds, dim=1)
-        with self.lazy_load_model(self.diffusion_decoder) as diffusion:
-            diffusion_latent = diffusion.get_conditioning(diffusion_conds)
+        diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
         return diffusion_latent

     def get_speaker_embedding(
@@ -430,10 +414,14 @@ class Xtts(BaseTTS):
         audio_path,
         gpt_cond_len=3,
     ):
+        speaker_embedding = None
+        diffusion_cond_latents = None
+        if self.args.use_hifigan:
+            speaker_embedding = self.get_speaker_embedding(audio_path)
+        else:
+            diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
         gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len)  # [1, 1024, T]
-        diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
-        speaker_embedding = self.get_speaker_embedding(audio_path)
-        return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device), speaker_embedding
+        return gpt_cond_latents, diffusion_cond_latents, speaker_embedding

     def synthesize(self, text, config, speaker_wav, language, **kwargs):
         """Synthesize speech with the given input text.
@@ -500,7 +488,6 @@ class Xtts(BaseTTS):
         cond_free_k=2,
         diffusion_temperature=1.0,
         decoder_sampler="ddim",
-        use_hifigan=True,
         **hf_generate_kwargs,
     ):
         """
@@ -579,7 +566,6 @@ class Xtts(BaseTTS):
             cond_free_k=cond_free_k,
             diffusion_temperature=diffusion_temperature,
             decoder_sampler=decoder_sampler,
-            use_hifigan=use_hifigan,
             **hf_generate_kwargs,
         )
@@ -614,7 +600,7 @@ class Xtts(BaseTTS):
             text_tokens.shape[-1] < self.args.gpt_max_text_tokens
         ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

-        if not use_hifigan:
+        if not self.args.use_hifigan:
             diffuser = load_discrete_vocoder_diffuser(
                 desired_diffusion_steps=decoder_iterations,
                 cond_free=cond_free,
@@ -623,60 +609,55 @@ class Xtts(BaseTTS):
             )

         with torch.no_grad():
-            self.gpt = self.gpt.to(self.device)
-            with self.lazy_load_model(self.gpt) as gpt:
-                gpt_codes = gpt.generate(
-                    cond_latents=gpt_cond_latent,
-                    text_inputs=text_tokens,
-                    input_tokens=None,
-                    do_sample=do_sample,
-                    top_p=top_p,
-                    top_k=top_k,
-                    temperature=temperature,
-                    num_return_sequences=self.gpt_batch_size,
-                    length_penalty=length_penalty,
-                    repetition_penalty=repetition_penalty,
-                    output_attentions=False,
-                    **hf_generate_kwargs,
-                )
-                expected_output_len = torch.tensor(
-                    [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
-                )
-                text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
-                gpt_latents = gpt(
-                    text_tokens,
-                    text_len,
-                    gpt_codes,
-                    expected_output_len,
-                    cond_latents=gpt_cond_latent,
-                    return_attentions=False,
-                    return_latent=True,
-                )
-                silence_token = 83
-                ctokens = 0
-                for k in range(gpt_codes.shape[-1]):
-                    if gpt_codes[0, k] == silence_token:
-                        ctokens += 1
-                    else:
-                        ctokens = 0
-                    if ctokens > 8:
-                        gpt_latents = gpt_latents[:, :k]
-                        break
+            gpt_codes = self.gpt.generate(
+                cond_latents=gpt_cond_latent,
+                text_inputs=text_tokens,
+                input_tokens=None,
+                do_sample=do_sample,
+                top_p=top_p,
+                top_k=top_k,
+                temperature=temperature,
+                num_return_sequences=self.gpt_batch_size,
+                length_penalty=length_penalty,
+                repetition_penalty=repetition_penalty,
+                output_attentions=False,
+                **hf_generate_kwargs,
+            )
+            expected_output_len = torch.tensor(
+                [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
+            )
+            text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
+            gpt_latents = self.gpt(
+                text_tokens,
+                text_len,
+                gpt_codes,
+                expected_output_len,
+                cond_latents=gpt_cond_latent,
+                return_attentions=False,
+                return_latent=True,
+            )
+            silence_token = 83
+            ctokens = 0
+            for k in range(gpt_codes.shape[-1]):
+                if gpt_codes[0, k] == silence_token:
+                    ctokens += 1
+                else:
+                    ctokens = 0
+                if ctokens > 8:
+                    gpt_latents = gpt_latents[:, :k]
+                    break

-            if use_hifigan:
-                with self.lazy_load_model(self.hifigan_decoder) as hifigan_decoder:
-                    wav = hifigan_decoder(gpt_latents, g=speaker_embedding)
+            if self.args.use_hifigan:
+                wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
             else:
-                with self.lazy_load_model(self.diffusion_decoder) as diffusion:
-                    mel = do_spectrogram_diffusion(
-                        diffusion,
-                        diffuser,
-                        gpt_latents,
-                        diffusion_conditioning,
-                        temperature=diffusion_temperature,
-                    )
-                with self.lazy_load_model(self.vocoder) as vocoder:
-                    wav = vocoder.inference(mel)
+                mel = do_spectrogram_diffusion(
+                    self.diffusion_decoder,
+                    diffuser,
+                    gpt_latents,
+                    diffusion_conditioning,
+                    temperature=diffusion_temperature,
+                )
+                wav = self.vocoder.inference(mel)

         return {"wav": wav.cpu().numpy().squeeze()}
@@ -713,6 +694,7 @@ class Xtts(BaseTTS):
             # Decoder inference
             **hf_generate_kwargs,
         ):
+        assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream."
         text = f"[{language}]{text.strip().lower()}"
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
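
Because streaming now hard-requires the HiFi-GAN decoder, callers can guard on the config instead of tripping the assertion. A sketch (the inference/inference_stream argument order here is an assumption based on this diff, and `sink` is a placeholder for the caller's audio output):

if model.args.use_hifigan:
    # Streaming path: chunks are yielded as they are vocoded.
    for chunk in model.inference_stream(text, language, gpt_cond_latent, speaker_embedding):
        sink.write(chunk)
else:
    # Diffusion + UnivNet can only synthesize the full utterance at once.
    out = model.inference(text, language, gpt_cond_latent, speaker_embedding, diffusion_conditioning)
    sink.write(out["wav"])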
@@ -781,7 +763,7 @@ class Xtts(BaseTTS):
         vocab_path=None,
         eval=False,
         strict=True,
-        use_deepspeed=False
+        use_deepspeed=False,
     ):
         """
         Loads a checkpoint from disk and initializes the model's state and tokenizer.
@@ -807,14 +789,20 @@ class Xtts(BaseTTS):
         self.init_models()

         if eval:
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
-        self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict)
+        checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
+        ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
+        for key in list(checkpoint.keys()):
+            if key.split(".")[0] in ignore_keys:
+                del checkpoint[key]
+        self.load_state_dict(checkpoint, strict=strict)

         if eval:
+            if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
+            if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
+            if hasattr(self, "vocoder"): self.vocoder.eval()
             self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
             self.gpt.eval()
-            self.diffusion_decoder.eval()
-            self.vocoder.eval()
-            self.hifigan_decoder.eval()

     def train_step(self):
         raise NotImplementedError("XTTS Training is not implemented")
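
The new loading path filters the checkpoint so strict=True still succeeds now that only one decoder's parameters exist on the module. The same idea as a standalone helper (a sketch mirroring the hunk above; torch.load stands in for the repo's load_fsspec):

import torch

def filter_checkpoint(state: dict, ignore_modules: list) -> dict:
    """Drop parameters whose top-level module name is in `ignore_modules`."""
    return {k: v for k, v in state.items() if k.split(".")[0] not in ignore_modules}

state = torch.load("model.pth", map_location="cpu")["model"]
# With use_hifigan=True, the diffusion decoder and UnivNet weights are
# dropped so the smaller model still loads with strict=True.
state = filter_checkpoint(state, ["diffusion_decoder", "vocoder"])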