mirror of https://github.com/coqui-ai/TTS.git
Drop fairseq for Hubert
This commit is contained in:
parent c03768bb53
commit 17ac188958
@@ -23,7 +23,7 @@ colormap = (
             [0, 0, 0],
             [183, 183, 183],
         ],
-        dtype=np.float,
+        dtype=float,
     )
     / 255
 )
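Context for the hunk above: np.float was only ever an alias for the builtin float; NumPy deprecated it in 1.20 and removed it in 1.24, where it raises AttributeError. The builtin is the drop-in replacement and still produces a float64 array. A minimal sketch of the failure and the fix (assuming NumPy >= 1.24):

    import numpy as np

    # The removed alias now fails at attribute lookup:
    # AttributeError: module 'numpy' has no attribute 'float'
    # bad = np.array([[0, 0, 0]], dtype=np.float)

    # The builtin maps to np.float64, matching the old behaviour:
    good = np.array([[0, 0, 0]], dtype=float) / 255
    print(good.dtype)  # float64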
@@ -1,5 +1,5 @@
 import os
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict
 
 from TTS.tts.configs.shared_configs import BaseTTSConfig
@@ -46,11 +46,11 @@ class BarkConfig(BaseTTSConfig):
     """
 
     model: str = "bark"
-    audio: BarkAudioConfig = BarkAudioConfig()
+    audio: BarkAudioConfig = field(default_factory=BarkAudioConfig)
     num_chars: int = 0
-    semantic_config: GPTConfig = GPTConfig()
-    fine_config: FineGPTConfig = FineGPTConfig()
-    coarse_config: GPTConfig = GPTConfig()
+    semantic_config: GPTConfig = field(default_factory=GPTConfig)
+    fine_config: FineGPTConfig = field(default_factory=FineGPTConfig)
+    coarse_config: GPTConfig = field(default_factory=GPTConfig)
     CONTEXT_WINDOW_SIZE: int = 1024
     SEMANTIC_RATE_HZ: float = 49.9
     SEMANTIC_VOCAB_SIZE: int = 10_000
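Why the defaults change above: starting with Python 3.11, dataclasses reject any unhashable instance (such as another config object) as a class-level field default with "ValueError: mutable default ... is not allowed"; on older versions a single shared instance was silently reused by every config object. field(default_factory=...) constructs a fresh sub-config per instance. A minimal sketch, with SubConfig as a hypothetical stand-in for GPTConfig:

    from dataclasses import dataclass, field

    @dataclass
    class SubConfig:  # hypothetical stand-in for GPTConfig
        num_layers: int = 12

    @dataclass
    class ConfigSketch:
        # sub: SubConfig = SubConfig()  # ValueError on Python 3.11+
        sub: SubConfig = field(default_factory=SubConfig)

    a, b = ConfigSketch(), ConfigSketch()
    assert a.sub is not b.sub  # each config now owns its own sub-config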
@@ -10,11 +10,11 @@ License: MIT
 import logging
 from pathlib import Path
 
-import fairseq
 import torch
 from einops import pack, unpack
 from torch import nn
 from torchaudio.functional import resample
+from transformers import HubertModel
 
 logging.root.setLevel(logging.ERROR)
 
@@ -49,22 +49,11 @@ class CustomHubert(nn.Module):
         self.target_sample_hz = target_sample_hz
         self.seq_len_multiple_of = seq_len_multiple_of
         self.output_layer = output_layer
 
         if device is not None:
             self.to(device)
-
-        model_path = Path(checkpoint_path)
-
-        assert model_path.exists(), f"path {checkpoint_path} does not exist"
-
-        checkpoint = torch.load(checkpoint_path)
-        load_model_input = {checkpoint_path: checkpoint}
-        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
-
+        self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
         if device is not None:
-            model[0].to(device)
-
-        self.model = model[0]
+            self.model.to(device)
         self.model.eval()
 
     @property
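With the fairseq loader gone, the weights now come from the Hugging Face Hub: HubertModel.from_pretrained downloads the checkpoint on first use and caches it locally, so no fairseq .pt file has to exist on disk (the checkpoint_path argument is kept in the signature but no longer drives loading). A minimal loading sketch:

    from transformers import HubertModel

    # Fetches facebook/hubert-base-ls960 from the Hub on the first call
    # and reads from the local cache afterwards.
    model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
    model.eval()  # inference only, as in the constructor above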
@@ -81,19 +70,13 @@ class CustomHubert(nn.Module):
         if exists(self.seq_len_multiple_of):
             wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
 
-        embed = self.model(
+        outputs = self.model.forward(
             wav_input,
-            features_only=True,
-            mask=False,  # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
-            output_layer=self.output_layer,
+            output_hidden_states=True,
         )
-
-        embed, packed_shape = pack([embed["x"]], "* d")
-
-        # codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
-
-        codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)  # .long()
+        embed = outputs["hidden_states"][self.output_layer]
+        embed, packed_shape = pack([embed], "* d")
+        codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)
 
         if flatten:
             return codebook_indices
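The transformers forward pass replaces fairseq's features_only/output_layer arguments: with output_hidden_states=True, the returned hidden_states holds num_layers + 1 tensors, where index 0 is the encoder input (projected conv features) and index i is the output of transformer layer i, so the 1-based output_layer index carries over directly. A small sketch of the shapes involved (assuming the base checkpoint with 12 layers):

    import torch
    from transformers import HubertModel

    model = HubertModel.from_pretrained("facebook/hubert-base-ls960").eval()
    wav = torch.randn(1, 16000)  # one second of 16 kHz audio

    with torch.no_grad():
        outputs = model(wav, output_hidden_states=True)

    print(len(outputs.hidden_states))      # 13: encoder input + 12 layers
    print(outputs.hidden_states[9].shape)  # torch.Size([1, 49, 768])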
@@ -130,7 +130,7 @@ def generate_voice(
     # generate semantic tokens
     # Load the HuBERT model
     hubert_manager = HubertManager()
-    hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
+    # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
     hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
 
     hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
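Call-site effect: the explicit HuBERT download is commented out because nothing is fetched to LOCAL_MODEL_PATHS["hubert"] any more; from_pretrained handles retrieval, and the checkpoint_path still passed to CustomHubert is effectively ignored. Only the tokenizer download step remains.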
@@ -207,7 +207,7 @@ def maximum_path_numpy(value, mask, max_neg_val=None):
     device = value.device
     dtype = value.dtype
     value = value.cpu().detach().numpy()
-    mask = mask.cpu().detach().numpy().astype(np.bool)
+    mask = mask.cpu().detach().numpy().astype(bool)
 
     b, t_x, t_y = value.shape
     direction = np.zeros(value.shape, dtype=np.int64)
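Same alias cleanup as the first hunk: np.bool was removed alongside np.float in NumPy 1.24, and the builtin bool is the documented replacement in dtype positions. A one-line sketch:

    import numpy as np

    mask = np.array([1, 0, 1])
    print(mask.astype(bool))  # [ True False  True]  (dtype bool_, as before)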