Drop fairseq for Hubert

This commit is contained in:
Eren G??lge 2023-06-26 19:27:48 +02:00
parent c03768bb53
commit 17ac188958
5 changed files with 16 additions and 33 deletions

View File

@ -23,7 +23,7 @@ colormap = (
[0, 0, 0], [0, 0, 0],
[183, 183, 183], [183, 183, 183],
], ],
dtype=np.float, dtype=float,
) )
/ 255 / 255
) )

View File

@ -1,5 +1,5 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Dict from typing import Dict
from TTS.tts.configs.shared_configs import BaseTTSConfig from TTS.tts.configs.shared_configs import BaseTTSConfig
@ -46,11 +46,11 @@ class BarkConfig(BaseTTSConfig):
""" """
model: str = "bark" model: str = "bark"
audio: BarkAudioConfig = BarkAudioConfig() audio: BarkAudioConfig = field(default_factory=BarkAudioConfig)
num_chars: int = 0 num_chars: int = 0
semantic_config: GPTConfig = GPTConfig() semantic_config: GPTConfig = field(default_factory=GPTConfig)
fine_config: FineGPTConfig = FineGPTConfig() fine_config: FineGPTConfig = field(default_factory=FineGPTConfig)
coarse_config: GPTConfig = GPTConfig() coarse_config: GPTConfig = field(default_factory=GPTConfig)
CONTEXT_WINDOW_SIZE: int = 1024 CONTEXT_WINDOW_SIZE: int = 1024
SEMANTIC_RATE_HZ: float = 49.9 SEMANTIC_RATE_HZ: float = 49.9
SEMANTIC_VOCAB_SIZE: int = 10_000 SEMANTIC_VOCAB_SIZE: int = 10_000

View File

@ -10,11 +10,11 @@ License: MIT
import logging import logging
from pathlib import Path from pathlib import Path
import fairseq
import torch import torch
from einops import pack, unpack from einops import pack, unpack
from torch import nn from torch import nn
from torchaudio.functional import resample from torchaudio.functional import resample
from transformers import HubertModel
logging.root.setLevel(logging.ERROR) logging.root.setLevel(logging.ERROR)
@ -49,22 +49,11 @@ class CustomHubert(nn.Module):
self.target_sample_hz = target_sample_hz self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer self.output_layer = output_layer
if device is not None: if device is not None:
self.to(device) self.to(device)
self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
model_path = Path(checkpoint_path)
assert model_path.exists(), f"path {checkpoint_path} does not exist"
checkpoint = torch.load(checkpoint_path)
load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
if device is not None: if device is not None:
model[0].to(device) self.model.to(device)
self.model = model[0]
self.model.eval() self.model.eval()
@property @property
@ -81,19 +70,13 @@ class CustomHubert(nn.Module):
if exists(self.seq_len_multiple_of): if exists(self.seq_len_multiple_of):
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
embed = self.model( outputs = self.model.forward(
wav_input, wav_input,
features_only=True, output_hidden_states=True,
mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
output_layer=self.output_layer,
) )
embed = outputs["hidden_states"][self.output_layer]
embed, packed_shape = pack([embed["x"]], "* d") embed, packed_shape = pack([embed], "* d")
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)
# codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long()
if flatten: if flatten:
return codebook_indices return codebook_indices

View File

@ -130,7 +130,7 @@ def generate_voice(
# generate semantic tokens # generate semantic tokens
# Load the HuBERT model # Load the HuBERT model
hubert_manager = HubertManager() hubert_manager = HubertManager()
hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"]) # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]) hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device) hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)

View File

@ -207,7 +207,7 @@ def maximum_path_numpy(value, mask, max_neg_val=None):
device = value.device device = value.device
dtype = value.dtype dtype = value.dtype
value = value.cpu().detach().numpy() value = value.cpu().detach().numpy()
mask = mask.cpu().detach().numpy().astype(np.bool) mask = mask.cpu().detach().numpy().astype(bool)
b, t_x, t_y = value.shape b, t_x, t_y = value.shape
direction = np.zeros(value.shape, dtype=np.int64) direction = np.zeros(value.shape, dtype=np.int64)