Merge pull request #22 from idiap/bark

fix(bark): add missing argument for load_voice()
This commit is contained in:
Enno Hermann 2024-05-16 15:21:33 +01:00 committed by GitHub
commit d73c9ccba3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 7 deletions

View File

@@ -2,10 +2,11 @@ import logging
import os
import re
from glob import glob
from typing import Dict, List
from typing import Dict, List, Optional, Tuple
import librosa
import numpy as np
import numpy.typing as npt
import torch
import torchaudio
import tqdm
@@ -48,7 +49,7 @@ def get_voices(extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-d
return voices
def load_npz(npz_file):
def load_npz(npz_file: str) -> Tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.int64]]:
x_history = np.load(npz_file)
semantic = x_history["semantic_prompt"]
coarse = x_history["coarse_prompt"]
@@ -56,7 +57,11 @@ def load_npz(npz_file):
return semantic, coarse, fine
def load_voice(model, voice: str, extra_voice_dirs: List[str] = []): # pylint: disable=dangerous-default-value
def load_voice(
model, voice: str, extra_voice_dirs: List[str] = []
) -> Tuple[
Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]], Optional[npt.NDArray[np.int64]]
]: # pylint: disable=dangerous-default-value
if voice == "random":
return None, None, None
@@ -107,11 +112,10 @@ def generate_voice(
model,
output_path,
):
"""Generate a new voice from a given audio and text prompt.
"""Generate a new voice from a given audio.
Args:
audio (np.ndarray): The audio to use as a base for the new voice.
text (str): Transcription of the audio you are cloning.
model (BarkModel): The BarkModel to use for generating the new voice.
output_path (str): The path to save the generated voice to.
"""

View File

@@ -164,7 +164,7 @@ class Bark(BaseTTS):
return audio_arr, [x_semantic, c, f]
def generate_voice(self, audio, speaker_id, voice_dir):
"""Generate a voice from the given audio and text.
"""Generate a voice from the given audio.
Args:
audio (str): Path to the audio file.
@@ -174,7 +174,7 @@ class Bark(BaseTTS):
if voice_dir is not None:
voice_dirs = [voice_dir]
try:
_ = load_voice(speaker_id, voice_dirs)
_ = load_voice(self, speaker_id, voice_dirs)
except (KeyError, FileNotFoundError):
output_path = os.path.join(voice_dir, speaker_id + ".npz")
os.makedirs(voice_dir, exist_ok=True)