mirror of https://github.com/coqui-ai/TTS.git
XTTS: add inference_stream_text (slightly friendlier for text-streaming)
This commit is contained in:
parent  dbf1a08a0d
commit  57a47d26bb
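The new text-streaming mode drives `inference_stream` without the full text up front: calling it with `text=None` makes the generator yield `None` whenever it has finished the text it was given and is waiting for more, which it then reads from `self._stream_text_holder`. Three wrappers manage that state: `inference_stream_text` starts the generator, `inference_add_text` hands it the next piece of text, and `inference_finalize_text` drains and resets it. A minimal sketch of the intended call pattern follows (the `model`, `gpt_cond_latent` and `speaker_embedding` objects are the usual XTTS inference inputs; `consume_audio` is a hypothetical placeholder for whatever plays or buffers the chunks):

```python
# Minimal sketch of the call pattern added by this commit (see the diff and docs below).
chunks = model.inference_stream_text("en", gpt_cond_latent, speaker_embedding)

for text in ["First piece of text.", "Second piece of text."]:
    model.inference_add_text(text)      # hand the next piece of text to the generator
    for chunk in chunks:
        if chunk is None:               # sentinel: the current text is fully synthesized
            break
        consume_audio(chunk)            # hypothetical consumer of the audio tensor

model.inference_finalize_text()         # drain and reset the generator
```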
@@ -209,6 +209,8 @@ class Xtts(BaseTTS):
         self.decoder_checkpoint = self.args.decoder_checkpoint  # TODO: check if this is even needed
         self.models_dir = config.model_dir
         self.gpt_batch_size = self.args.gpt_batch_size
+        self._stream_text_holder = []
+        self._stream_generator = None
 
         self.tokenizer = VoiceBpeTokenizer()
         self.gpt = None
@@ -632,64 +634,140 @@ class Xtts(BaseTTS):
         length_scale = 1.0 / max(speed, 0.05)
         gpt_cond_latent = gpt_cond_latent.to(self.device)
         speaker_embedding = speaker_embedding.to(self.device)
-        if enable_text_splitting:
-            text = split_sentence(text, language, self.tokenizer.char_limits[language])
-        else:
-            text = [text]
-
-        for sent in text:
-            sent = sent.strip().lower()
-            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
-
-            assert (
-                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
-            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
-
-            fake_inputs = self.gpt.compute_embeddings(
-                gpt_cond_latent.to(self.device),
-                text_tokens,
-            )
-            gpt_generator = self.gpt.get_generator(
-                fake_inputs=fake_inputs,
-                top_k=top_k,
-                top_p=top_p,
-                temperature=temperature,
-                do_sample=do_sample,
-                num_beams=1,
-                num_return_sequences=1,
-                length_penalty=float(length_penalty),
-                repetition_penalty=float(repetition_penalty),
-                output_attentions=False,
-                output_hidden_states=True,
-                **hf_generate_kwargs,
-            )
-
-            last_tokens = []
-            all_latents = []
-            wav_gen_prev = None
-            wav_overlap = None
-            is_end = False
-
-            while not is_end:
-                try:
-                    x, latent = next(gpt_generator)
-                    last_tokens += [x]
-                    all_latents += [latent]
-                except StopIteration:
-                    is_end = True
-
-                if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
-                    gpt_latents = torch.cat(all_latents, dim=0)[None, :]
-                    if length_scale != 1.0:
-                        gpt_latents = F.interpolate(
-                            gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
-                        ).transpose(1, 2)
-                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
-                    wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
-                        wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
-                    )
-                    last_tokens = []
-                    yield wav_chunk
+        text_streaming = (text is None)
+
+        while True:
+            if text_streaming:
+                yield None
+                if len(self._stream_text_holder) == 0:
+                    return
+                text, enable_text_splitting = self._stream_text_holder
+
+            if enable_text_splitting:
+                text = split_sentence(text, language, self.tokenizer.char_limits[language])
+            else:
+                text = [text]
+
+            for sent in text:
+                sent = sent.strip().lower()
+                text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)
+
+                assert (
+                    text_tokens.shape[-1] < self.args.gpt_max_text_tokens
+                ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
+
+                fake_inputs = self.gpt.compute_embeddings(
+                    gpt_cond_latent.to(self.device),
+                    text_tokens,
+                )
+                gpt_generator = self.gpt.get_generator(
+                    fake_inputs=fake_inputs,
+                    top_k=top_k,
+                    top_p=top_p,
+                    temperature=temperature,
+                    do_sample=do_sample,
+                    num_beams=1,
+                    num_return_sequences=1,
+                    length_penalty=float(length_penalty),
+                    repetition_penalty=float(repetition_penalty),
+                    output_attentions=False,
+                    output_hidden_states=True,
+                    **hf_generate_kwargs,
+                )
+
+                last_tokens = []
+                all_latents = []
+                wav_gen_prev = None
+                wav_overlap = None
+                is_end = False
+
+                while not is_end:
+                    try:
+                        x, latent = next(gpt_generator)
+                        last_tokens += [x]
+                        all_latents += [latent]
+                    except StopIteration:
+                        is_end = True
+
+                    if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
+                        gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                        if length_scale != 1.0:
+                            gpt_latents = F.interpolate(
+                                gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
+                            ).transpose(1, 2)
+                        wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
+                        wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
+                            wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
+                        )
+                        last_tokens = []
+                        yield wav_chunk
+
+            if not text_streaming:
+                return
+
+    def inference_stream_text(
+        self,
+        language,
+        gpt_cond_latent,
+        speaker_embedding,
+        # Streaming
+        stream_chunk_size=20,
+        overlap_wav_len=1024,
+        # GPT inference
+        temperature=0.75,
+        length_penalty=1.0,
+        repetition_penalty=10.0,
+        top_k=50,
+        top_p=0.85,
+        do_sample=True,
+        speed=1.0,
+        **hf_generate_kwargs,
+    ):
+        if self._stream_generator is not None:
+            raise Exception('Inference text-streaming already in progress. '
+                            'Did you forget to call inference_finalize_text?')
+
+        # Arguments `text` and `enable_text_splitting` given through holder
+        self._stream_text_holder = [None, None]
+        self._stream_generator = self.inference_stream(
+            None,
+            language,
+            gpt_cond_latent,
+            speaker_embedding,
+            stream_chunk_size=stream_chunk_size,
+            overlap_wav_len=overlap_wav_len,
+            temperature=temperature,
+            length_penalty=length_penalty,
+            repetition_penalty=repetition_penalty,
+            top_k=top_k,
+            top_p=top_p,
+            do_sample=do_sample,
+            speed=speed,
+            **hf_generate_kwargs,
+        )
+
+        # Start the generator and return it
+        _ = next(self._stream_generator)
+        return self._stream_generator
+
+    def inference_add_text(self, text: str, enable_text_splitting=False):
+        if self._stream_generator is None:
+            raise Exception('Inference text-streaming not started. '
+                            'Please call inference_stream_text first')
+        self._stream_text_holder[0] = text
+        self._stream_text_holder[1] = enable_text_splitting
+
+    def inference_finalize_text(self):
+        if self._stream_generator is None:
+            raise Exception('Inference text-streaming was not started '
+                            '(start with inference_stream_text)')
+        # Finalize and reset the generator
+        self._stream_text_holder.clear()
+        try:
+            _ = next(self._stream_generator)
+        except StopIteration:
+            pass
+        self._stream_generator = None
 
     def forward(self):
         raise NotImplementedError(
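The three methods above compose naturally into a single generator for callers that prefer one loop over the start/add/finalize protocol. A minimal sketch of such a wrapper (a hypothetical helper, not part of this commit; it only assumes the method names added in the hunk above):

```python
def stream_tts_from_texts(model, texts, language, gpt_cond_latent, speaker_embedding, **kwargs):
    """Yield XTTS audio chunks for an iterable of text pieces via the text-streaming API."""
    inf_gen = model.inference_stream_text(language, gpt_cond_latent, speaker_embedding, **kwargs)
    try:
        for text in texts:
            model.inference_add_text(text, enable_text_splitting=True)
            for chunk in inf_gen:
                if chunk is None:  # sentinel: the current text has been fully synthesized
                    break
                yield chunk
    finally:
        model.inference_finalize_text()  # always drain and reset the generator state
```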
@@ -220,7 +220,7 @@ torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
 ```
 
 
-##### Streaming manually
+##### Streaming inference
 
 Here the goal is to stream the audio as it is being generated. This is useful for real-time applications.
 Streaming inference is typically slower than regular inference, but it allows to get a first chunk of audio faster.
@@ -253,16 +253,50 @@ chunks = model.inference_stream(
     speaker_embedding
 )
 
-wav_chuncks = []
+wav_chunks = []
 for i, chunk in enumerate(chunks):
     if i == 0:
         print(f"Time to first chunck: {time.time() - t0}")
     print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
-    wav_chuncks.append(chunk)
-wav = torch.cat(wav_chuncks, dim=0)
+    wav_chunks.append(chunk)
+wav = torch.cat(wav_chunks, dim=0)
 torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
 ```
 
+If you also need to do text-streaming, you can use `inference_stream_text`, like so:
+
+```python
+# ...same setup as before
+
+def text_streaming_generator():
+    yield "It took me quite a long time to develop a voice and now that I have it I am not going to be silent."
+    yield "Having discovered not just one, but many voices, I will champion each."
+
+print("Inference with text streaming...")
+
+text_gen = text_streaming_generator()
+inf_gen = model.inference_stream_text(
+    "en",
+    gpt_cond_latent,
+    speaker_embedding
+)
+
+wav_chunks = []
+for text in text_gen:
+    # Add text progressively
+    model.inference_add_text(text, enable_text_splitting=True)
+    for chunk in inf_gen:
+        if chunk is None:
+            break  # all chunks generated for the current text
+        print(f"Received chunk {len(wav_chunks)} of audio length {chunk.shape[-1]}")
+        wav_chunks.append(chunk)
+
+# Call finalize to discard the inference generator
+model.inference_finalize_text()
+
+wav = torch.cat(wav_chunks, dim=0)
+torchaudio.save("xtts_streaming_text.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
+```
 
 ### Training
 