mirror of https://github.com/coqui-ai/TTS.git
Implement FreeVC (#2451)
* Update .gitignore
* Draft FreeVC implementation
* Tests and relevant updates
* Update API tests
* Add missings
* Update requirements
* :(
* Lazy handle for vc
* Update docs for voice conversion
* Make style

This commit is contained in:
parent 090cadf270, commit d309f50e53
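For orientation, a minimal usage sketch of the voice-conversion API this commit introduces. Class, method, and model names are taken from the diff below; the TTS model name and the file paths are illustrative placeholders, and the exact behaviour should be treated as an assumption rather than documented API.

# Hypothetical sketch; names come from the diff below, paths/model names are placeholders.
from TTS.api import TTS

# TTS followed by voice conversion ("fake voice cloning"): synthesize text, then convert
# the result toward the voice in `speaker_wav`. The API lazily downloads
# "voice_conversion_models/multilingual/vctk/freevc24" when no VC model is loaded yet.
api = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC")  # placeholder TTS model
api.tts_with_vc_to_file(
    text="Hello from FreeVC.",
    speaker_wav="target_speaker.wav",  # reference recording of the target voice
    file_path="cloned.wav",
)

# Plain voice conversion between two existing recordings.
vc = TTS()
vc.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
wav = vc.voice_conversion("source.wav", "target.wav")  # returns the converted waveform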
@@ -137,7 +137,7 @@ VCTK-Corpus-removed-silence/*
# ignore training logs
trainer_*_log.txt

-# files used internally fro dev, test etc.
+# files used internally for dev, test etc.
tests/outputs/*
tests/train_outputs/*
TODO.txt

@@ -168,3 +168,4 @@ internal/*
wandb
depot/*
coqui_recipes/*
+local_scripts/*

@@ -802,5 +802,18 @@
                }
            }
        }
+        },
+        "voice_conversion_models":{
+            "multilingual":{
+                "vctk":{
+                    "freevc24":{
+                        "github_rls_url": "https://coqui.gateway.scarf.sh/v0.13.0_models/voice_conversion_models--multilingual--vctk--freevc24.zip",
+                        "description": "FreeVC model trained on VCTK dataset from https://github.com/OlaWod/FreeVC",
+                        "author": "Jing-Yi Li @OlaWod",
+                        "license": "MIT",
+                        "commit": null
+                    }
+                }
+            }
        }
    }

89  TTS/api.py
@@ -1,5 +1,7 @@
+import tempfile
from pathlib import Path

+from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer

@@ -49,11 +51,14 @@ class TTS:
            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
        """
        self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)

        self.synthesizer = None
+        self.voice_converter = None

        if model_name:
-            self.load_model_by_name(model_name, gpu)
+            self.load_tts_model_by_name(model_name, gpu)
        if model_path:
-            self.load_model_by_path(
+            self.load_tts_model_by_path(
                model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
            )

@@ -96,12 +101,22 @@ class TTS:

    def download_model_by_name(self, model_name: str):
        model_path, config_path, model_item = self.manager.download_model(model_name)
-        if model_item["default_vocoder"] is None:
+        if model_item.get("default_vocoder") is None:
            return model_path, config_path, None, None
        vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
        return model_path, config_path, vocoder_path, vocoder_config_path

-    def load_model_by_name(self, model_name: str, gpu: bool = False):
+    def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
+        """Load one of the voice conversion models by name.
+
+        Args:
+            model_name (str): Model name to load. You can list models by ```tts.models```.
+            gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
+        """
+        model_path, config_path, _, _ = self.download_model_by_name(model_name)
+        self.voice_converter = Synthesizer(vc_checkpoint=model_path, vc_config=config_path, use_cuda=gpu)
+
+    def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
        """Load one of 🐸TTS models by name.

        Args:

@@ -127,7 +142,7 @@ class TTS:
            use_cuda=gpu,
        )

-    def load_model_by_path(
+    def load_tts_model_by_path(
        self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
    ):
        """Load a model from a path.

@@ -219,3 +234,67 @@ class TTS:
        """
        wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav)
        self.synthesizer.save_wav(wav=wav, path=file_path)

+    def voice_conversion(
+        self,
+        sourve_wav: str,
+        target_wav: str,
+    ):
+        """Voice conversion with FreeVC. Convert source wav to target speaker.
+
+        Args:
+            source_wav (str):
+                Path to the source wav file.
+            target_wav (str):
+                Path to the target wav file.
+        """
+        wav = self.synthesizer.voice_conversion(source_wav=sourve_wav, target_wav=target_wav)
+        return wav
+
+    def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None):
+        """Convert text to speech with voice conversion.
+
+        It combines tts with voice conversion to fake voice cloning.
+
+        - Convert text to speech with tts.
+        - Convert the output wav to target speaker with voice conversion.
+
+        Args:
+            text (str):
+                Input text to synthesize.
+            language (str, optional):
+                Language code for multi-lingual models. You can check whether loaded model is multi-lingual
+                `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
+            speaker_wav (str, optional):
+                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
+                Defaults to None.
+        """
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
+            # Lazy code... save it to a temp file to resample it while reading it for VC
+            self.tts_to_file(text=text, speaker=None, language=language, file_path=fp.name)
+        if self.voice_converter is None:
+            self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
+        wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
+        return wav
+
+    def tts_with_vc_to_file(
+        self, text: str, language: str = None, speaker_wav: str = None, file_path: str = "output.wav"
+    ):
+        """Convert text to speech with voice conversion and save to file.
+
+        Check `tts_with_vc` for more details.
+
+        Args:
+            text (str):
+                Input text to synthesize.
+            language (str, optional):
+                Language code for multi-lingual models. You can check whether loaded model is multi-lingual
+                `tts.is_multi_lingual` and list available languages by `tts.languages`. Defaults to None.
+            speaker_wav (str, optional):
+                Path to a reference wav file to use for voice cloning with supporting models like YourTTS.
+                Defaults to None.
+            file_path (str, optional):
+                Output file path. Defaults to "output.wav".
+        """
+        wav = self.tts_with_vc(text=text, language=language, speaker_wav=speaker_wav)
+        save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)

@@ -100,6 +100,12 @@ If you don't specify any models, then it uses LJSpeech based English model.
    ```
    $ tts --text "Text for TTS" --out_path output/path/speech.wav --model_path path/to/config.json --config_path path/to/model.pth --speakers_file_path path/to/speaker.json --speaker_idx <speaker_id>
    ```
+
+### Voice Conversion Models
+
+    ```
+    $ tts --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --source_wav <path/to/speaker/wav> --target_wav <path/to/reference/wav>
+    ```
"""
# We remove Markdown code formatting programmatically here to allow us to copy-and-paste from main README to keep
# documentation in sync more easily.

@@ -245,6 +251,20 @@ If you don't specify any models, then it uses LJSpeech based English model.
        default=True,
    )
+
+    # voice conversion args
+    parser.add_argument(
+        "--source_wav",
+        type=str,
+        default=None,
+        help="Original audio file to convert in the voice of the target_wav",
+    )
+    parser.add_argument(
+        "--target_wav",
+        type=str,
+        default=None,
+        help="Target audio file to convert in the voice of the source_wav",
+    )

    args = parser.parse_args()

    # print the description if either text or list_models is not set

@@ -256,6 +276,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
        args.reference_wav,
        args.model_info_by_idx,
        args.model_info_by_name,
+        args.source_wav,
+        args.target_wav,
    ]
    if not any(check_args):
        parser.parse_args(["-h"])

@@ -264,21 +286,23 @@ If you don't specify any models, then it uses LJSpeech based English model.
    path = Path(__file__).parent / "../.models.json"
    manager = ModelManager(path, progress_bar=args.progress_bar)

-    model_path = None
-    config_path = None
+    tts_path = None
+    tts_config_path = None
    speakers_file_path = None
    language_ids_file_path = None
    vocoder_path = None
    vocoder_config_path = None
    encoder_path = None
    encoder_config_path = None
+    vc_path = None
+    vc_config_path = None

    # CASE1 #list : list pre-trained TTS models
    if args.list_models:
        manager.list_models()
        sys.exit()

-    # CASE2 #info : model info of pre-trained TTS models
+    # CASE2 #info : model info for pre-trained TTS models
    if args.model_info_by_idx:
        model_query = args.model_info_by_idx
        manager.model_info_by_idx(model_query)

@@ -292,15 +316,27 @@ If you don't specify any models, then it uses LJSpeech based English model.
    # CASE3: load pre-trained model paths
    if args.model_name is not None and not args.model_path:
        model_path, config_path, model_item = manager.download_model(args.model_name)
-        args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
+
+        # tts model
+        if model_item["model_type"] == "tts_models":
+            tts_path = model_path
+            tts_config_path = config_path
+            if "default_vocoder" in model_item:
+                args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
+
+        # voice conversion model
+        if model_item["model_type"] == "voice_conversion_models":
+            vc_path = model_path
+            vc_config_path = config_path

+    # load vocoder
    if args.vocoder_name is not None and not args.vocoder_path:
        vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)

    # CASE4: set custom model paths
    if args.model_path is not None:
-        model_path = args.model_path
-        config_path = args.config_path
+        tts_path = args.model_path
+        tts_config_path = args.config_path
        speakers_file_path = args.speakers_file_path
        language_ids_file_path = args.language_ids_file_path

@@ -314,14 +350,16 @@ If you don't specify any models, then it uses LJSpeech based English model.

    # load models
    synthesizer = Synthesizer(
-        model_path,
-        config_path,
+        tts_path,
+        tts_config_path,
        speakers_file_path,
        language_ids_file_path,
        vocoder_path,
        vocoder_config_path,
        encoder_path,
        encoder_config_path,
+        vc_path,
+        vc_config_path,
        args.use_cuda,
    )

@@ -354,16 +392,22 @@ If you don't specify any models, then it uses LJSpeech based English model.
        print(" > Text: {}".format(args.text))

    # kick it
-    wav = synthesizer.tts(
-        args.text,
-        args.speaker_idx,
-        args.language_idx,
-        args.speaker_wav,
-        reference_wav=args.reference_wav,
-        style_wav=args.capacitron_style_wav,
-        style_text=args.capacitron_style_text,
-        reference_speaker_name=args.reference_speaker_idx,
-    )
+    if tts_path is not None:
+        wav = synthesizer.tts(
+            args.text,
+            args.speaker_idx,
+            args.language_idx,
+            args.speaker_wav,
+            reference_wav=args.reference_wav,
+            style_wav=args.capacitron_style_wav,
+            style_text=args.capacitron_style_text,
+            reference_speaker_name=args.reference_speaker_idx,
+        )
+    elif vc_path is not None:
+        wav = synthesizer.voice_conversion(
+            source_wav=args.source_wav,
+            target_wav=args.target_wav,
+        )

    # save the results
    print(" > Saving output to {}".format(args.out_path))

@@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit:
    """
    config_class = None
    config_name = model_name + "_config"
-    paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs"]
+    paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder.configs", "TTS.vc.configs"]
    for path in paths:
        try:
            config_class = find_module(path, config_name)

@@ -27,6 +27,8 @@ class BaseTTS(BaseTrainerModel):
    It defines common `tts` specific functions on top of `Model` implementation.
    """

+    MODEL_TYPE = "tts"
+
    def __init__(
        self,
        config: Coqpit,

@@ -85,6 +85,7 @@ def to_camel(text):
    text = text.capitalize()
    text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
    text = text.replace("Tts", "TTS")
+    text = text.replace("vc", "VC")
    return text

@@ -185,6 +185,13 @@ class ModelManager(object):
        """
        return self._list_for_model_type("vocoder_models")

+    def list_vc_models(self):
+        """Print all the voice conversion models and return a list of model names
+
+        Format is `language/dataset/model`
+        """
+        return self._list_for_model_type("voice_conversion_models")
+
    def list_langs(self):
        """Print all the available languages"""
        print(" Name format: type/language")

@@ -234,6 +241,7 @@ class ModelManager(object):
        model_type, lang, dataset, model = model_name.split("/")
        model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
        model_item = self.models_dict[model_type][lang][dataset][model]
+        model_item["model_type"] = model_type
        # set the model specific output path
        output_path = os.path.join(self.output_prefix, model_full_name)
        if os.path.exists(output_path):

@@ -12,6 +12,8 @@ from TTS.tts.models import setup_model as setup_tts_model
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
+from TTS.utils.audio.numpy_transforms import save_wav
+from TTS.vc.models import setup_model as setup_vc_model
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input

@@ -19,14 +21,16 @@ from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
class Synthesizer(object):
    def __init__(
        self,
-        tts_checkpoint: str,
-        tts_config_path: str,
+        tts_checkpoint: str = "",
+        tts_config_path: str = "",
        tts_speakers_file: str = "",
        tts_languages_file: str = "",
        vocoder_checkpoint: str = "",
        vocoder_config: str = "",
        encoder_checkpoint: str = "",
        encoder_config: str = "",
+        vc_checkpoint: str = "",
+        vc_config: str = "",
        use_cuda: bool = False,
    ) -> None:
        """General 🐸 TTS interface for inference. It takes a tts and a vocoder

@@ -41,12 +45,14 @@ class Synthesizer(object):
        TODO: set the segmenter based on the source language

        Args:
-            tts_checkpoint (str): path to the tts model file.
-            tts_config_path (str): path to the tts config file.
+            tts_checkpoint (str, optional): path to the tts model file.
+            tts_config_path (str, optional): path to the tts config file.
            vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
            vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
            encoder_checkpoint (str, optional): path to the speaker encoder model file. Defaults to `""`,
            encoder_config (str, optional): path to the speaker encoder config file. Defaults to `""`,
+            vc_checkpoint (str, optional): path to the voice conversion model file. Defaults to `""`,
+            vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
            use_cuda (bool, optional): enable/disable cuda. Defaults to False.
        """
        self.tts_checkpoint = tts_checkpoint

@@ -57,10 +63,13 @@ class Synthesizer(object):
        self.vocoder_config = vocoder_config
        self.encoder_checkpoint = encoder_checkpoint
        self.encoder_config = encoder_config
+        self.vc_checkpoint = vc_checkpoint
+        self.vc_config = vc_config
        self.use_cuda = use_cuda

        self.tts_model = None
        self.vocoder_model = None
+        self.vc_model = None
        self.speaker_manager = None
        self.tts_speakers = {}
        self.language_manager = None

@@ -72,12 +81,19 @@ class Synthesizer(object):

        if self.use_cuda:
            assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
-        self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
-        self.output_sample_rate = self.tts_config.audio["sample_rate"]
+
+        if tts_checkpoint:
+            self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
+            self.output_sample_rate = self.tts_config.audio["sample_rate"]

        if vocoder_checkpoint:
            self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
            self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
+
+        if vc_checkpoint:
+            self._load_vc(vc_checkpoint, vc_config, use_cuda)
+            self.output_sample_rate = self.vc_config.audio["output_sample_rate"]

    @staticmethod
    def _get_segmenter(lang: str):
        """get the sentence segmenter for the given language.

@@ -90,6 +106,26 @@ class Synthesizer(object):
        """
        return pysbd.Segmenter(language=lang, clean=True)

+    def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> None:
+        """Load the voice conversion model.
+
+        1. Load the model config.
+        2. Init the model from the config.
+        3. Load the model weights.
+        4. Move the model to the GPU if CUDA is enabled.
+
+        Args:
+            vc_checkpoint (str): path to the model checkpoint.
+            tts_config_path (str): path to the model config file.
+            use_cuda (bool): enable/disable CUDA use.
+        """
+        # pylint: disable=global-statement
+        self.vc_config = load_config(vc_config_path)
+        self.vc_model = setup_vc_model(config=self.vc_config)
+        self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
+        if use_cuda:
+            self.vc_model.cuda()
+
    def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
        """Load the TTS model.

@@ -168,7 +204,11 @@ class Synthesizer(object):
            path (str): output path to save the waveform.
        """
        wav = np.array(wav)
-        self.tts_model.ap.save_wav(wav, path, self.output_sample_rate)
+        save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate)
+
+    def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
+        output_wav = self.vc_model.voice_conversion(source_wav, target_wav)
+        return output_wav

    def tts(
        self,

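As a quick illustration of the Synthesizer changes above, a hedged sketch of running voice conversion directly through the Synthesizer; the checkpoint and config paths are placeholders, and the call pattern is inferred from the diff rather than from documentation.

# Sketch only: paths are placeholders; argument names follow the new Synthesizer signature above.
from TTS.utils.synthesizer import Synthesizer

synth = Synthesizer(
    vc_checkpoint="path/to/freevc24/model.pth",  # placeholder checkpoint path
    vc_config="path/to/freevc24/config.json",    # placeholder config path
    use_cuda=False,
)
wav = synth.voice_conversion(source_wav="source.wav", target_wav="target.wav")
synth.save_wav(wav, "converted.wav")  # written at the output sample rate taken from vc_config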
@@ -0,0 +1,5 @@
from dataclasses import dataclass, field
from typing import List

from TTS.vc.configs.shared_configs import BaseVCConfig
from TTS.vc.models.freevc import FreeVCArgs, FreeVCAudioConfig, FreeVCConfig

@@ -0,0 +1,155 @@
from dataclasses import asdict, dataclass, field
from typing import Dict, List

from coqpit import Coqpit, check_argument

from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig


@dataclass
class BaseVCConfig(BaseTrainingConfig):
    """Shared parameters among all the tts models.

    Args:

        audio (BaseAudioConfig):
            Audio processor config object instance.

        batch_group_size (int):
            Size of the batch groups used for bucketing. By default, the dataloader orders samples by the sequence
            length for a more efficient and stable training. If `batch_group_size > 1` then it performs bucketing to
            prevent using the same batches for each epoch.

        loss_masking (bool):
            enable / disable masking loss values against padded segments of samples in a batch.

        min_text_len (int):
            Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 0.

        max_text_len (int):
            Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf").

        min_audio_len (int):
            Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 0.

        max_audio_len (int):
            Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the
            dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an
            OOM error in training. Defaults to float("inf").

        compute_f0 (int):
            (Not in use yet).

        compute_energy (int):
            (Not in use yet).

        compute_linear_spec (bool):
            If True data loader computes and returns linear spectrograms alongside the other data.

        precompute_num_workers (int):
            Number of workers to precompute features. Defaults to 0.

        use_noise_augment (bool):
            Augment the input audio with random noise.

        start_by_longest (bool):
            If True, the data loader will start loading the longest batch first. It is useful for checking OOM issues.
            Defaults to False.

        shuffle (bool):
            If True, the data loader will shuffle the dataset when there is not sampler defined. Defaults to True.

        drop_last (bool):
            If True, the data loader will drop the last batch if it is not complete. It helps to prevent
            issues that emerge from the partial batch statistics. Defaults to True.

        add_blank (bool):
            Add blank characters between each other two characters. It improves performance for some models at expense
            of slower run-time due to the longer input sequence.

        datasets (List[BaseDatasetConfig]):
            List of datasets used for training. If multiple datasets are provided, they are merged and used together
            for training.

        optimizer (str):
            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
            Defaults to ``.

        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`

        lr_scheduler (str):
            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
            `TTS.utils.training`. Defaults to ``.

        lr_scheduler_params (dict):
            Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.

        test_sentences (List[str]):
            List of sentences to be used at testing. Defaults to '[]'

        eval_split_max_size (int):
            Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled).

        eval_split_size (float):
            If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set.
            If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%).

        use_speaker_weighted_sampler (bool):
            Enable / Disable the batch balancer by speaker. Defaults to ```False```.

        speaker_weighted_sampler_alpha (float):
            Number that control the influence of the speaker sampler weights. Defaults to ```1.0```.

        use_language_weighted_sampler (bool):
            Enable / Disable the batch balancer by language. Defaults to ```False```.

        language_weighted_sampler_alpha (float):
            Number that control the influence of the language sampler weights. Defaults to ```1.0```.

        use_length_weighted_sampler (bool):
            Enable / Disable the batch balancer by audio length. If enabled the dataset will be divided
            into 10 buckets considering the min and max audio of the dataset. The sampler weights will be
            computed forcing to have the same quantity of data for each bucket in each training batch. Defaults to ```False```.

        length_weighted_sampler_alpha (float):
            Number that control the influence of the length sampler weights. Defaults to ```1.0```.
    """

    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
    # training params
    batch_group_size: int = 0
    loss_masking: bool = None
    # dataloading
    min_audio_len: int = 1
    max_audio_len: int = float("inf")
    min_text_len: int = 1
    max_text_len: int = float("inf")
    compute_f0: bool = False
    compute_energy: bool = False
    compute_linear_spec: bool = False
    precompute_num_workers: int = 0
    use_noise_augment: bool = False
    start_by_longest: bool = False
    shuffle: bool = False
    drop_last: bool = False
    # dataset
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # optimizer
    optimizer: str = "radam"
    optimizer_params: dict = None
    # scheduler
    lr_scheduler: str = None
    lr_scheduler_params: dict = field(default_factory=lambda: {})
    # testing
    test_sentences: List[str] = field(default_factory=lambda: [])
    # evaluation
    eval_split_max_size: int = None
    eval_split_size: float = 0.01
    # weighted samplers
    use_speaker_weighted_sampler: bool = False
    speaker_weighted_sampler_alpha: float = 1.0
    use_language_weighted_sampler: bool = False
    language_weighted_sampler_alpha: float = 1.0
    use_length_weighted_sampler: bool = False
    length_weighted_sampler_alpha: float = 1.0

@@ -0,0 +1,17 @@
import importlib
import re
from typing import Dict, List, Union


def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC":
    print(" > Using model: {}".format(config.model))
    # fetch the right model implementation.
    if "model" in config and config["model"].lower() == "freevc":
        MyModel = importlib.import_module("TTS.vc.models.freevc").FreeVC
    model = MyModel.init_from_config(config, samples)
    return model

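A small, hypothetical sketch of how the `setup_model` factory above would be used; it assumes `FreeVCConfig` (re-exported by the new `TTS/vc/configs/__init__.py` shown earlier) defaults its `model` field to "freevc" and that `FreeVC.init_from_config` accepts the config as-is.

# Sketch: build a FreeVC model through the factory above (assumptions noted in the lead-in).
from TTS.vc.configs import FreeVCConfig
from TTS.vc.models import setup_model

config = FreeVCConfig()      # assumed: default config describes the "freevc" model
model = setup_model(config)  # dispatches on config["model"].lower() == "freevc"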
@@ -0,0 +1,429 @@
import os
import random
from typing import Dict, List, Tuple, Union

import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper

from TTS.model import BaseTrainerModel
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import LanguageManager, get_language_balancer_weights
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_balancer_weights
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

# pylint: skip-file


class BaseVC(BaseTrainerModel):
    """Base `vc` class. Every new `vc` model must inherit this.

    It defines common `vc` specific functions on top of `Model` implementation.
    """

    MODEL_TYPE = "vc"

    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor",
        speaker_manager: SpeakerManager = None,
        language_manager: LanguageManager = None,
    ):
        super().__init__()
        self.config = config
        self.ap = ap
        self.speaker_manager = speaker_manager
        self.language_manager = language_manager
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit):
        """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).

        `ModelArgs` has all the fields reuqired to initialize the model architecture.

        `ModelConfig` has all the fields required for training, inference and containes `ModelArgs`.

        If the config is for training with a name like "*Config", then the model args are embeded in the
        config.model_args

        If the config is for the model with a name like "*Args", then we assign the directly.
        """
        # don't use isintance not to import recursively
        if "Config" in config.__class__.__name__:
            self.config = config
            self.args = config.model_args
        elif "Args" in config.__class__.__name__:
            self.args = config
        else:
            raise ValueError("config must be either a *Config or *Args")

    def init_multispeaker(self, config: Coqpit, data: List = None):
        """Initialize a speaker embedding layer if needen and define expected embedding channel size for defining
        `in_channels` size of the connected layers.

        This implementation yields 3 possible outcomes:

        1. If `config.use_speaker_embedding` and `config.use_d_vector_file are False, do nothing.
        2. If `config.use_d_vector_file` is True, set expected embedding channel size to `config.d_vector_dim` or 512.
        3. If `config.use_speaker_embedding`, initialize a speaker embedding layer with channel size of
            `config.d_vector_dim` or 512.

        You can override this function for new models.

        Args:
            config (Coqpit): Model configuration.
        """
        # set number of speakers
        if self.speaker_manager is not None:
            self.num_speakers = self.speaker_manager.num_speakers
        elif hasattr(config, "num_speakers"):
            self.num_speakers = config.num_speakers

        # set ultimate speaker embedding size
        if config.use_speaker_embedding or config.use_d_vector_file:
            self.embedded_speaker_dim = (
                config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
            )
        # init speaker embedding layer
        if config.use_speaker_embedding and not config.use_d_vector_file:
            print(" > Init speaker_embedding layer.")
            self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)

    def get_aux_input(self, **kwargs) -> Dict:
        """Prepare and return `aux_input` used by `forward()`"""
        return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}

    def get_aux_input_from_test_sentences(self, sentence_info):
        if hasattr(self.config, "model_args"):
            config = self.config.model_args
        else:
            config = self.config

        # extract speaker and language info
        text, speaker_name, style_wav, language_name = None, None, None, None

        if isinstance(sentence_info, list):
            if len(sentence_info) == 1:
                text = sentence_info[0]
            elif len(sentence_info) == 2:
                text, speaker_name = sentence_info
            elif len(sentence_info) == 3:
                text, speaker_name, style_wav = sentence_info
            elif len(sentence_info) == 4:
                text, speaker_name, style_wav, language_name = sentence_info
        else:
            text = sentence_info

        # get speaker id/d_vector
        speaker_id, d_vector, language_id = None, None, None
        if self.speaker_manager is not None:
            if config.use_d_vector_file:
                if speaker_name is None:
                    d_vector = self.speaker_manager.get_random_embedding()
                else:
                    d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
            elif config.use_speaker_embedding:
                if speaker_name is None:
                    speaker_id = self.speaker_manager.get_random_id()
                else:
                    speaker_id = self.speaker_manager.name_to_id[speaker_name]

        # get language id
        if self.language_manager is not None and config.use_language_embedding and language_name is not None:
            language_id = self.language_manager.name_to_id[language_name]

        return {
            "text": text,
            "speaker_id": speaker_id,
            "style_wav": style_wav,
            "d_vector": d_vector,
            "language_id": language_id,
        }

    def format_batch(self, batch: Dict) -> Dict:
        """Generic batch formatting for `VCDataset`.

        You must override this if you use a custom dataset.

        Args:
            batch (Dict): [description]

        Returns:
            Dict: [description]
        """
        # setup input batch
        text_input = batch["token_id"]
        text_lengths = batch["token_id_lengths"]
        speaker_names = batch["speaker_names"]
        linear_input = batch["linear"]
        mel_input = batch["mel"]
        mel_lengths = batch["mel_lengths"]
        stop_targets = batch["stop_targets"]
        item_idx = batch["item_idxs"]
        d_vectors = batch["d_vectors"]
        speaker_ids = batch["speaker_ids"]
        attn_mask = batch["attns"]
        waveform = batch["waveform"]
        pitch = batch["pitch"]
        energy = batch["energy"]
        language_ids = batch["language_ids"]
        max_text_length = torch.max(text_lengths.float())
        max_spec_length = torch.max(mel_lengths.float())

        # compute durations from attention masks
        durations = None
        if attn_mask is not None:
            durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
            for idx, am in enumerate(attn_mask):
                # compute raw durations
                c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
                # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
                c_idxs, counts = torch.unique(c_idxs, return_counts=True)
                dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
                dur[c_idxs] = counts
                # smooth the durations and set any 0 duration to 1
                # by cutting off from the largest duration indeces.
                extra_frames = dur.sum() - mel_lengths[idx]
                largest_idxs = torch.argsort(-dur)[:extra_frames]
                dur[largest_idxs] -= 1
                assert (
                    dur.sum() == mel_lengths[idx]
                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur

        # set stop targets wrt reduction factor
        stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()

        return {
            "text_input": text_input,
            "text_lengths": text_lengths,
            "speaker_names": speaker_names,
            "mel_input": mel_input,
            "mel_lengths": mel_lengths,
            "linear_input": linear_input,
            "stop_targets": stop_targets,
            "stop_target_lengths": stop_target_lengths,
            "attn_mask": attn_mask,
            "durations": durations,
            "speaker_ids": speaker_ids,
            "d_vectors": d_vectors,
            "max_text_length": float(max_text_length),
            "max_spec_length": float(max_spec_length),
            "item_idx": item_idx,
            "waveform": waveform,
            "pitch": pitch,
            "energy": energy,
            "language_ids": language_ids,
            "audio_unique_names": batch["audio_unique_names"],
        }

    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1):
        weights = None
        data_items = dataset.samples

        if getattr(config, "use_language_weighted_sampler", False):
            alpha = getattr(config, "language_weighted_sampler_alpha", 1.0)
            print(" > Using Language weighted sampler with alpha:", alpha)
            weights = get_language_balancer_weights(data_items) * alpha

        if getattr(config, "use_speaker_weighted_sampler", False):
            alpha = getattr(config, "speaker_weighted_sampler_alpha", 1.0)
            print(" > Using Speaker weighted sampler with alpha:", alpha)
            if weights is not None:
                weights += get_speaker_balancer_weights(data_items) * alpha
            else:
                weights = get_speaker_balancer_weights(data_items) * alpha

        if getattr(config, "use_length_weighted_sampler", False):
            alpha = getattr(config, "length_weighted_sampler_alpha", 1.0)
            print(" > Using Length weighted sampler with alpha:", alpha)
            if weights is not None:
                weights += get_length_balancer_weights(data_items) * alpha
            else:
                weights = get_length_balancer_weights(data_items) * alpha

        if weights is not None:
            sampler = WeightedRandomSampler(weights, len(weights))
        else:
            sampler = None

        # sampler for DDP
        if sampler is None:
            sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        else:  # If a sampler is already defined use this sampler and DDP sampler together
            sampler = DistributedSamplerWrapper(sampler) if num_gpus > 1 else sampler

        return sampler

    def get_data_loader(
        self,
        config: Coqpit,
        assets: Dict,
        is_eval: bool,
        samples: Union[List[Dict], List[List]],
        verbose: bool,
        num_gpus: int,
        rank: int = None,
    ) -> "DataLoader":
        if is_eval and not config.run_eval:
            loader = None
        else:
            # setup multi-speaker attributes
            if self.speaker_manager is not None:
                if hasattr(config, "model_args"):
                    speaker_id_mapping = (
                        self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
                    )
                    d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
                    config.use_d_vector_file = config.model_args.use_d_vector_file
                else:
                    speaker_id_mapping = self.speaker_manager.name_to_id if config.use_speaker_embedding else None
                    d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
            else:
                speaker_id_mapping = None
                d_vector_mapping = None

            # setup multi-lingual attributes
            if self.language_manager is not None:
                language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
            else:
                language_id_mapping = None

            # init dataloader
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
                compute_energy=config.get("compute_energy", False),
                energy_cache_path=config.get("energy_cache_path", None),
                samples=samples,
                ap=self.ap,
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_text_len=config.min_text_len,
                max_text_len=config.max_text_len,
                min_audio_len=config.min_audio_len,
                max_audio_len=config.max_audio_len,
                phoneme_cache_path=config.phoneme_cache_path,
                precompute_num_workers=config.precompute_num_workers,
                use_noise_augment=False if is_eval else config.use_noise_augment,
                verbose=verbose,
                speaker_id_mapping=speaker_id_mapping,
                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
                tokenizer=None,
                start_by_longest=config.start_by_longest,
                language_id_mapping=language_id_mapping,
            )

            # wait all the DDP process to be ready
            if num_gpus > 1:
                dist.barrier()

            # sort input sequences from short to long
            dataset.preprocess_samples()

            # get samplers
            sampler = self.get_sampler(config, dataset, num_gpus)

            loader = DataLoader(
                dataset,
                batch_size=config.eval_batch_size if is_eval else config.batch_size,
                shuffle=config.shuffle if sampler is None else False,  # if there is no other sampler
                collate_fn=dataset.collate_fn,
                drop_last=config.drop_last,  # setting this False might cause issues in AMP training.
                sampler=sampler,
                num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                pin_memory=False,
            )
        return loader

    def _get_test_aux_input(
        self,
    ) -> Dict:
        d_vector = None
        if self.config.use_d_vector_file:
            d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
            d_vector = (random.sample(sorted(d_vector), 1),)

        aux_inputs = {
            "speaker_id": None
            if not self.config.use_speaker_embedding
            else random.sample(sorted(self.speaker_manager.name_to_id.values()), 1),
            "d_vector": d_vector,
            "style_wav": None,  # TODO: handle GST style input
        }
        return aux_inputs

    def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
        """Generic test run for `vc` models used by `Trainer`.

        You can override this for a different behaviour.

        Args:
            assets (dict): A dict of training assets. For `vc` models, it must include `{'audio_processor': ap}`.

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
        test_sentences = self.config.test_sentences
        aux_inputs = self._get_test_aux_input()
        for idx, sen in enumerate(test_sentences):
            if isinstance(sen, list):
                aux_inputs = self.get_aux_input_from_test_sentences(sen)
                sen = aux_inputs["text"]
            outputs_dict = synthesis(
                self,
                sen,
                self.config,
                "cuda" in str(next(self.parameters()).device),
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
                use_griffin_lim=True,
                do_trim_silence=False,
            )
            test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
            test_figures["{}-prediction".format(idx)] = plot_spectrogram(
                outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
            )
            test_figures["{}-alignment".format(idx)] = plot_alignment(
                outputs_dict["outputs"]["alignments"], output_fig=False
            )
        return test_figures, test_audios

    def on_init_start(self, trainer):
        """Save the speaker.pth and language_ids.json at the beginning of the training. Also update both paths."""
        if self.speaker_manager is not None:
            output_path = os.path.join(trainer.output_path, "speakers.pth")
            self.speaker_manager.save_ids_to_file(output_path)
            trainer.config.speakers_file = output_path
            # some models don't have `model_args` set
            if hasattr(trainer.config, "model_args"):
                trainer.config.model_args.speakers_file = output_path
            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
            print(f" > `speakers.pth` is saved to {output_path}.")
            print(" > `speakers_file` is updated in the config.json.")

        if self.language_manager is not None:
            output_path = os.path.join(trainer.output_path, "language_ids.json")
            self.language_manager.save_ids_to_file(output_path)
            trainer.config.language_ids_file = output_path
            if hasattr(trainer.config, "model_args"):
                trainer.config.model_args.language_ids_file = output_path
            trainer.config.save_json(os.path.join(trainer.output_path, "config.json"))
            print(f" > `language_ids.json` is saved to {output_path}.")
            print(" > `language_ids_file` is updated in the config.json.")

@ -0,0 +1,833 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from coqpit import Coqpit
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||||
|
|
||||||
|
import TTS.vc.modules.freevc.commons as commons
|
||||||
|
import TTS.vc.modules.freevc.modules as modules
|
||||||
|
from TTS.tts.utils.speakers import SpeakerManager
|
||||||
|
from TTS.utils.io import load_fsspec, save_checkpoint
|
||||||
|
from TTS.vc.configs.shared_configs import BaseVCConfig
|
||||||
|
from TTS.vc.models.base_vc import BaseVC
|
||||||
|
from TTS.vc.modules.freevc.commons import get_padding, init_weights
|
||||||
|
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
|
||||||
|
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
|
||||||
|
from TTS.vc.modules.freevc.wavlm import get_wavlm
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualCouplingBlock(nn.Module):
|
||||||
|
def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.n_flows = n_flows
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
|
self.flows = nn.ModuleList()
|
||||||
|
for i in range(n_flows):
|
||||||
|
self.flows.append(
|
||||||
|
modules.ResidualCouplingLayer(
|
||||||
|
channels,
|
||||||
|
hidden_channels,
|
||||||
|
kernel_size,
|
||||||
|
dilation_rate,
|
||||||
|
n_layers,
|
||||||
|
gin_channels=gin_channels,
|
||||||
|
mean_only=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.flows.append(modules.Flip())
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None, reverse=False):
|
||||||
|
if not reverse:
|
||||||
|
for flow in self.flows:
|
||||||
|
x, _ = flow(x, x_mask, g=g, reverse=reverse)
|
||||||
|
else:
|
||||||
|
for flow in reversed(self.flows):
|
||||||
|
x = flow(x, x_mask, g=g, reverse=reverse)
|
||||||
|
return x
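As an aside (not part of the diff): the coupling block above is a normalizing flow, so applying it forward and then with `reverse=True` should reconstruct the input up to numerical error. A minimal sanity-check sketch, assuming the class is importable from `TTS.vc.models.freevc` alongside `FreeVC`:

import torch
from TTS.vc.models.freevc import ResidualCouplingBlock

# Shapes are illustrative: 192 latent channels, 100 frames, no speaker conditioning.
flow = ResidualCouplingBlock(channels=192, hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=4)
x = torch.randn(2, 192, 100)
x_mask = torch.ones(2, 1, 100)
z = flow(x, x_mask)                         # training direction
x_rec = flow(z, x_mask, reverse=True)       # inference direction inverts the flow
print(torch.allclose(x, x_rec, atol=1e-4))  # expected: True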
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
|
||||||
|
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||||
|
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||||
|
|
||||||
|
def forward(self, x, x_lengths, g=None):
|
||||||
|
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||||
|
x = self.pre(x) * x_mask
|
||||||
|
x = self.enc(x, x_mask, g=g)
|
||||||
|
stats = self.proj(x) * x_mask
|
||||||
|
m, logs = torch.split(stats, self.out_channels, dim=1)
|
||||||
|
z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
|
||||||
|
return z, m, logs, x_mask
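For orientation (my own sketch, not from the commit): used as the posterior encoder `enc_q` in `FreeVC.__init__` below, this module maps a 641-bin linear spectrogram to a sampled 192-channel latent plus its Gaussian statistics. Import path is assumed as in the docstring examples further down.

import torch
from TTS.vc.models.freevc import Encoder

enc_q = Encoder(641, 192, 192, 5, 1, 16, gin_channels=256)  # same settings as FreeVC.enc_q
spec = torch.randn(1, 641, 100)     # (batch, spec_channels, frames)
lengths = torch.tensor([100])
g = torch.randn(1, 256, 1)          # speaker embedding as global conditioning
z, m, logs, mask = enc_q(spec, lengths, g=g)
print(z.shape, mask.shape)          # torch.Size([1, 192, 100]) torch.Size([1, 1, 100])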
|
||||||
|
|
||||||
|
|
||||||
|
class Generator(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
initial_channel,
|
||||||
|
resblock,
|
||||||
|
resblock_kernel_sizes,
|
||||||
|
resblock_dilation_sizes,
|
||||||
|
upsample_rates,
|
||||||
|
upsample_initial_channel,
|
||||||
|
upsample_kernel_sizes,
|
||||||
|
gin_channels=0,
|
||||||
|
):
|
||||||
|
super(Generator, self).__init__()
|
||||||
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||||
|
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
||||||
|
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
weight_norm(
|
||||||
|
ConvTranspose1d(
|
||||||
|
upsample_initial_channel // (2**i),
|
||||||
|
upsample_initial_channel // (2 ** (i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
|
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock(ch, k, d))
|
||||||
|
|
||||||
|
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||||
|
self.ups.apply(init_weights)
|
||||||
|
|
||||||
|
if gin_channels != 0:
|
||||||
|
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||||
|
|
||||||
|
def forward(self, x, g=None):
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
if g is not None:
|
||||||
|
x = x + self.cond(g)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||||
|
x = self.ups[i](x)
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
x = F.leaky_relu(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
print("Removing weight norm...")
|
||||||
|
for l in self.ups:
|
||||||
|
remove_weight_norm(l)
|
||||||
|
for l in self.resblocks:
|
||||||
|
l.remove_weight_norm()
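A shape note that may help here (illustrative, not part of the diff): each `ConvTranspose1d` above uses stride `u` and padding `(k - u) // 2`, which multiplies the time axis exactly by `u`, so with the default FreeVC rates the decoder turns every latent frame into 10 * 8 * 2 * 2 = 320 waveform samples.

import torch
from TTS.vc.models.freevc import Generator

gen = Generator(
    initial_channel=192,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[10, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4],
    gin_channels=256,
)
z = torch.randn(1, 192, 50)   # 50 latent frames
g = torch.randn(1, 256, 1)    # speaker conditioning
wav = gen(z, g=g)
print(wav.shape)              # torch.Size([1, 1, 16000]) -> 50 frames * 320 samples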
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorP(torch.nn.Module):
|
||||||
|
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
||||||
|
super(DiscriminatorP, self).__init__()
|
||||||
|
self.period = period
|
||||||
|
self.use_spectral_norm = use_spectral_norm
|
||||||
|
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||||
|
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||||
|
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||||
|
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
|
||||||
|
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
# 1d to 2d
|
||||||
|
b, c, t = x.shape
|
||||||
|
if t % self.period != 0: # pad first
|
||||||
|
n_pad = self.period - (t % self.period)
|
||||||
|
x = F.pad(x, (0, n_pad), "reflect")
|
||||||
|
t = t + n_pad
|
||||||
|
x = x.view(b, c, t // self.period, self.period)
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class DiscriminatorS(torch.nn.Module):
|
||||||
|
def __init__(self, use_spectral_norm=False):
|
||||||
|
super(DiscriminatorS, self).__init__()
|
||||||
|
norm_f = spectral_norm if use_spectral_norm else weight_norm
|
||||||
|
self.convs = nn.ModuleList(
|
||||||
|
[
|
||||||
|
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
|
||||||
|
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
|
||||||
|
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
|
||||||
|
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
|
||||||
|
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
|
||||||
|
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
fmap = []
|
||||||
|
|
||||||
|
for l in self.convs:
|
||||||
|
x = l(x)
|
||||||
|
x = F.leaky_relu(x, modules.LRELU_SLOPE)
|
||||||
|
fmap.append(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
fmap.append(x)
|
||||||
|
x = torch.flatten(x, 1, -1)
|
||||||
|
|
||||||
|
return x, fmap
|
||||||
|
|
||||||
|
|
||||||
|
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||||
|
def __init__(self, use_spectral_norm=False):
|
||||||
|
super(MultiPeriodDiscriminator, self).__init__()
|
||||||
|
periods = [2, 3, 5, 7, 11]
|
||||||
|
|
||||||
|
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||||
|
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||||
|
self.discriminators = nn.ModuleList(discs)
|
||||||
|
|
||||||
|
def forward(self, y, y_hat):
|
||||||
|
y_d_rs = []
|
||||||
|
y_d_gs = []
|
||||||
|
fmap_rs = []
|
||||||
|
fmap_gs = []
|
||||||
|
for i, d in enumerate(self.discriminators):
|
||||||
|
y_d_r, fmap_r = d(y)
|
||||||
|
y_d_g, fmap_g = d(y_hat)
|
||||||
|
y_d_rs.append(y_d_r)
|
||||||
|
y_d_gs.append(y_d_g)
|
||||||
|
fmap_rs.append(fmap_r)
|
||||||
|
fmap_gs.append(fmap_g)
|
||||||
|
|
||||||
|
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
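A small usage sketch (mine, with the import path assumed as above): the wrapper runs one scale discriminator plus one period discriminator per period in [2, 3, 5, 7, 11], returning per-discriminator scores and feature maps for both real and generated audio.

import torch
from TTS.vc.models.freevc import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()
y = torch.randn(1, 1, 8960)         # real waveform segment
y_hat = torch.randn(1, 1, 8960)     # generated waveform segment
scores_r, scores_g, feats_r, feats_g = mpd(y, y_hat)
print(len(scores_r), len(feats_r))  # 6 6 -> one scale discriminator + five period discriminators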
|
||||||
|
|
||||||
|
|
||||||
|
class SpeakerEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
|
||||||
|
super(SpeakerEncoder, self).__init__()
|
||||||
|
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
|
||||||
|
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, mels):
|
||||||
|
self.lstm.flatten_parameters()
|
||||||
|
_, (hidden, _) = self.lstm(mels)
|
||||||
|
embeds_raw = self.relu(self.linear(hidden[-1]))
|
||||||
|
return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
|
||||||
|
|
||||||
|
def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
|
||||||
|
mel_slices = []
|
||||||
|
for i in range(0, total_frames - partial_frames, partial_hop):
|
||||||
|
mel_range = torch.arange(i, i + partial_frames)
|
||||||
|
mel_slices.append(mel_range)
|
||||||
|
|
||||||
|
return mel_slices
|
||||||
|
|
||||||
|
def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
|
||||||
|
mel_len = mel.size(1)
|
||||||
|
last_mel = mel[:, -partial_frames:]
|
||||||
|
|
||||||
|
if mel_len > partial_frames:
|
||||||
|
mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
|
||||||
|
mels = list(mel[:, s] for s in mel_slices)
|
||||||
|
mels.append(last_mel)
|
||||||
|
mels = torch.stack(tuple(mels), 0).squeeze(1)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
partial_embeds = self(mels)
|
||||||
|
embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
|
||||||
|
# embed = embed / torch.linalg.norm(embed, 2)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
embed = self(last_mel)
|
||||||
|
|
||||||
|
return embed
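Usage note (mine, hedged): `embed_utterance` expects a batch-first mel tensor of shape (1, n_frames, 80), splits it into 128-frame partials with a 64-frame hop, and averages the L2-normalized partial embeddings.

import torch
from TTS.vc.models.freevc import SpeakerEncoder

enc_spk = SpeakerEncoder(mel_n_channels=80, model_hidden_size=256, model_embedding_size=256)
mel = torch.randn(1, 300, 80)        # (batch, frames, mel_channels)
embed = enc_spk.embed_utterance(mel)
print(embed.shape)                   # torch.Size([1, 256])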
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FreeVCAudioConfig(Coqpit):
|
||||||
|
"""Audio configuration
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_wav_value (float):
|
||||||
|
The maximum value of the waveform.
|
||||||
|
|
||||||
|
input_sample_rate (int):
|
||||||
|
The sampling rate of the input waveform.
|
||||||
|
|
||||||
|
output_sample_rate (int):
|
||||||
|
The sampling rate of the output waveform.
|
||||||
|
|
||||||
|
filter_length (int):
|
||||||
|
The length of the filter.
|
||||||
|
|
||||||
|
hop_length (int):
|
||||||
|
The hop length.
|
||||||
|
|
||||||
|
win_length (int):
|
||||||
|
The window length.
|
||||||
|
|
||||||
|
n_mel_channels (int):
|
||||||
|
The number of mel channels.
|
||||||
|
|
||||||
|
mel_fmin (float):
|
||||||
|
The minimum frequency of the mel filterbank.
|
||||||
|
|
||||||
|
mel_fmax (Optional[float]):
|
||||||
|
The maximum frequency of the mel filterbank.
|
||||||
|
"""
|
||||||
|
|
||||||
|
max_wav_value: float = field(default=32768.0)
|
||||||
|
input_sample_rate: int = field(default=16000)
|
||||||
|
output_sample_rate: int = field(default=24000)
|
||||||
|
filter_length: int = field(default=1280)
|
||||||
|
hop_length: int = field(default=320)
|
||||||
|
win_length: int = field(default=1280)
|
||||||
|
n_mel_channels: int = field(default=80)
|
||||||
|
mel_fmin: float = field(default=0.0)
|
||||||
|
mel_fmax: Optional[float] = field(default=None)
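A quick reading of these defaults (my arithmetic, not stated in the commit): a 320-sample hop at the 16 kHz input rate is a 20 ms frame shift, and a 1280-point FFT yields 1280 // 2 + 1 = 641 linear-spectrogram bins, matching the `spec_channels=641` default in `FreeVCArgs` below.

input_sample_rate = 16000
hop_length = 320
filter_length = 1280

frame_shift_ms = 1000 * hop_length / input_sample_rate  # 20.0 ms per spectrogram frame
n_freq_bins = filter_length // 2 + 1                    # 641 == FreeVCArgs.spec_channels
print(frame_shift_ms, n_freq_bins)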
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FreeVCArgs(Coqpit):
|
||||||
|
"""FreeVC model arguments
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spec_channels (int):
|
||||||
|
The number of channels in the spectrogram.
|
||||||
|
|
||||||
|
inter_channels (int):
|
||||||
|
The number of channels in the intermediate layers.
|
||||||
|
|
||||||
|
hidden_channels (int):
|
||||||
|
The number of channels in the hidden layers.
|
||||||
|
|
||||||
|
filter_channels (int):
|
||||||
|
The number of channels in the filter layers.
|
||||||
|
|
||||||
|
n_heads (int):
|
||||||
|
The number of attention heads.
|
||||||
|
|
||||||
|
n_layers (int):
|
||||||
|
The number of layers.
|
||||||
|
|
||||||
|
kernel_size (int):
|
||||||
|
The size of the kernel.
|
||||||
|
|
||||||
|
p_dropout (float):
|
||||||
|
The dropout probability.
|
||||||
|
|
||||||
|
resblock (str):
|
||||||
|
The type of residual block.
|
||||||
|
|
||||||
|
resblock_kernel_sizes (List[int]):
|
||||||
|
The kernel sizes for the residual blocks.
|
||||||
|
|
||||||
|
resblock_dilation_sizes (List[List[int]]):
|
||||||
|
The dilation sizes for the residual blocks.
|
||||||
|
|
||||||
|
upsample_rates (List[int]):
|
||||||
|
The upsample rates.
|
||||||
|
|
||||||
|
upsample_initial_channel (int):
|
||||||
|
The number of channels in the initial upsample layer.
|
||||||
|
|
||||||
|
upsample_kernel_sizes (List[int]):
|
||||||
|
The kernel sizes for the upsample layers.
|
||||||
|
|
||||||
|
n_layers_q (int):
|
||||||
|
The number of layers in the quantization network.
|
||||||
|
|
||||||
|
use_spectral_norm (bool):
|
||||||
|
Whether to use spectral normalization.
|
||||||
|
|
||||||
|
gin_channels (int):
|
||||||
|
The number of channels in the global conditioning vector.
|
||||||
|
|
||||||
|
ssl_dim (int):
|
||||||
|
The dimension of the self-supervised learning embedding.
|
||||||
|
|
||||||
|
use_spk (bool):
|
||||||
|
Whether to use an external speaker encoder.
|
||||||
|
"""
|
||||||
|
|
||||||
|
spec_channels: int = field(default=641)
|
||||||
|
inter_channels: int = field(default=192)
|
||||||
|
hidden_channels: int = field(default=192)
|
||||||
|
filter_channels: int = field(default=768)
|
||||||
|
n_heads: int = field(default=2)
|
||||||
|
n_layers: int = field(default=6)
|
||||||
|
kernel_size: int = field(default=3)
|
||||||
|
p_dropout: float = field(default=0.1)
|
||||||
|
resblock: str = field(default="1")
|
||||||
|
resblock_kernel_sizes: List[int] = field(default_factory=lambda: [3, 7, 11])
|
||||||
|
resblock_dilation_sizes: List[List[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||||
|
upsample_rates: List[int] = field(default_factory=lambda: [10, 8, 2, 2])
|
||||||
|
upsample_initial_channel: int = field(default=512)
|
||||||
|
upsample_kernel_sizes: List[int] = field(default_factory=lambda: [16, 16, 4, 4])
|
||||||
|
n_layers_q: int = field(default=3)
|
||||||
|
use_spectral_norm: bool = field(default=False)
|
||||||
|
gin_channels: int = field(default=256)
|
||||||
|
ssl_dim: int = field(default=1024)
|
||||||
|
use_spk: bool = field(default=False)
|
||||||
|
num_spks: int = field(default=0)
|
||||||
|
segment_size: int = field(default=8960)
|
||||||
|
|
||||||
|
|
||||||
|
class FreeVC(BaseVC):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Paper::
|
||||||
|
https://arxiv.org/abs/2210.15418#
|
||||||
|
|
||||||
|
Paper Abstract::
|
||||||
|
Voice conversion (VC) can be achieved by first extracting source content information and target speaker
|
||||||
|
information, and then reconstructing waveform with these information. However, current approaches normally
|
||||||
|
either extract dirty content information with speaker information leaked in, or demand a large amount of
|
||||||
|
annotated data for training. Besides, the quality of reconstructed waveform can be degraded by the
|
||||||
|
mismatch between conversion model and vocoder. In this paper, we adopt the end-to-end framework of VITS for
|
||||||
|
high-quality waveform reconstruction, and propose strategies for clean content information extraction without
|
||||||
|
text annotation. We disentangle content information by imposing an information bottleneck to WavLM features,
|
||||||
|
and propose the spectrogram-resize based data augmentation to improve the purity of extracted content
|
||||||
|
information. Experimental results show that the proposed method outperforms the latest VC models trained with
|
||||||
|
annotated data and has greater robustness.
|
||||||
|
|
||||||
|
Original Code::
|
||||||
|
https://github.com/OlaWod/FreeVC
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from TTS.vc.configs.freevc_config import FreeVCConfig
|
||||||
|
>>> from TTS.vc.models.freevc import FreeVC
|
||||||
|
>>> config = FreeVCConfig()
|
||||||
|
>>> model = FreeVC(config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||||
|
super().__init__(config, None, speaker_manager, None)
|
||||||
|
|
||||||
|
self.init_multispeaker(config)
|
||||||
|
|
||||||
|
self.spec_channels = self.args.spec_channels
|
||||||
|
self.inter_channels = self.args.inter_channels
|
||||||
|
self.hidden_channels = self.args.hidden_channels
|
||||||
|
self.filter_channels = self.args.filter_channels
|
||||||
|
self.n_heads = self.args.n_heads
|
||||||
|
self.n_layers = self.args.n_layers
|
||||||
|
self.kernel_size = self.args.kernel_size
|
||||||
|
self.p_dropout = self.args.p_dropout
|
||||||
|
self.resblock = self.args.resblock
|
||||||
|
self.resblock_kernel_sizes = self.args.resblock_kernel_sizes
|
||||||
|
self.resblock_dilation_sizes = self.args.resblock_dilation_sizes
|
||||||
|
self.upsample_rates = self.args.upsample_rates
|
||||||
|
self.upsample_initial_channel = self.args.upsample_initial_channel
|
||||||
|
self.upsample_kernel_sizes = self.args.upsample_kernel_sizes
|
||||||
|
self.segment_size = self.args.segment_size
|
||||||
|
self.gin_channels = self.args.gin_channels
|
||||||
|
self.ssl_dim = self.args.ssl_dim
|
||||||
|
self.use_spk = self.args.use_spk
|
||||||
|
|
||||||
|
self.enc_p = Encoder(self.args.ssl_dim, self.inter_channels, self.hidden_channels, 5, 1, 16)
|
||||||
|
self.dec = Generator(
|
||||||
|
self.inter_channels,
|
||||||
|
self.resblock,
|
||||||
|
self.resblock_kernel_sizes,
|
||||||
|
self.resblock_dilation_sizes,
|
||||||
|
self.upsample_rates,
|
||||||
|
self.upsample_initial_channel,
|
||||||
|
self.upsample_kernel_sizes,
|
||||||
|
gin_channels=self.gin_channels,
|
||||||
|
)
|
||||||
|
self.enc_q = Encoder(
|
||||||
|
self.spec_channels, self.inter_channels, self.hidden_channels, 5, 1, 16, gin_channels=self.gin_channels
|
||||||
|
)
|
||||||
|
self.flow = ResidualCouplingBlock(
|
||||||
|
self.inter_channels, self.hidden_channels, 5, 1, 4, gin_channels=self.gin_channels
|
||||||
|
)
|
||||||
|
if not self.use_spk:
|
||||||
|
self.enc_spk = SpeakerEncoder(model_hidden_size=self.gin_channels, model_embedding_size=self.gin_channels)
|
||||||
|
else:
|
||||||
|
self.load_pretrained_speaker_encoder()
|
||||||
|
|
||||||
|
self.wavlm = get_wavlm()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def load_pretrained_speaker_encoder(self):
|
||||||
|
"""Load pretrained speaker encoder model as mentioned in the paper."""
|
||||||
|
print(" > Loading pretrained speaker encoder model ...")
|
||||||
|
self.enc_spk_ex = SpeakerEncoderEx(
|
||||||
|
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_multispeaker(self, config: Coqpit):
|
||||||
|
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||||
|
or with external `d_vectors` computed from a speaker encoder model.
|
||||||
|
|
||||||
|
You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (Coqpit): Model configuration.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.num_spks = self.args.num_spks
|
||||||
|
if self.speaker_manager:
|
||||||
|
self.num_spks = self.speaker_manager.num_spks
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
c: torch.Tensor,
|
||||||
|
spec: torch.Tensor,
|
||||||
|
g: Optional[torch.Tensor] = None,
|
||||||
|
mel: Optional[torch.Tensor] = None,
|
||||||
|
c_lengths: Optional[torch.Tensor] = None,
|
||||||
|
spec_lengths: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
torch.Tensor,
|
||||||
|
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
|
]:
|
||||||
|
"""
|
||||||
|
Forward pass of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
c: WavLM features. Shape: (batch_size, ssl_dim, c_seq_len).
|
||||||
|
spec: The input spectrogram. Shape: (batch_size, spec_seq_len, spec_dim).
|
||||||
|
g: The speaker embedding. Shape: (batch_size, spk_emb_dim).
|
||||||
|
mel: The input mel-spectrogram for the speaker encoder. Shape: (batch_size, mel_seq_len, mel_dim).
|
||||||
|
c_lengths: The lengths of the WavLM features. Shape: (batch_size,).
|
||||||
|
spec_lengths: The lengths of the spectrogram. Shape: (batch_size,).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
o: The output waveform decoded from the randomly sliced latent frames. Shape: (batch_size, 1, out_seq_len).
|
||||||
|
ids_slice: The slice indices. Shape: (batch_size, num_slices).
|
||||||
|
spec_mask: The spectrogram mask. Shape: (batch_size, spec_seq_len).
|
||||||
|
(z, z_p, m_p, logs_p, m_q, logs_q): A tuple of latent variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If c_lengths is None, set it to the length of the last dimension of c
|
||||||
|
if c_lengths is None:
|
||||||
|
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
|
||||||
|
|
||||||
|
# If spec_lengths is None, set it to the length of the last dimension of spec
|
||||||
|
if spec_lengths is None:
|
||||||
|
spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
|
||||||
|
|
||||||
|
# If use_spk is False, compute g from mel using enc_spk
|
||||||
|
g = None
|
||||||
|
if not self.use_spk:
|
||||||
|
g = self.enc_spk(mel).unsqueeze(-1)
|
||||||
|
|
||||||
|
# Compute m_p, logs_p, z, m_q, logs_q, and spec_mask using enc_p and enc_q
|
||||||
|
_, m_p, logs_p, _ = self.enc_p(c, c_lengths)
|
||||||
|
z, m_q, logs_q, spec_mask = self.enc_q(spec.transpose(1, 2), spec_lengths, g=g)
|
||||||
|
|
||||||
|
# Compute z_p using flow
|
||||||
|
z_p = self.flow(z, spec_mask, g=g)
|
||||||
|
|
||||||
|
# Randomly slice z and compute o using dec
|
||||||
|
z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
|
||||||
|
o = self.dec(z_slice, g=g)
|
||||||
|
|
||||||
|
return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def inference(self, c, g=None, mel=None, c_lengths=None):
|
||||||
|
"""
|
||||||
|
Inference pass of the model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
c (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len).
|
||||||
|
g (torch.Tensor): Speaker embedding tensor. Shape: (batch_size, spk_emb_dim).
|
||||||
|
mel (torch.Tensor): Mel-spectrogram tensor. Shape: (batch_size, mel_seq_len, mel_dim).
|
||||||
|
c_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor.
|
||||||
|
"""
|
||||||
|
if c_lengths is None:
|
||||||
|
c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
|
||||||
|
if not self.use_spk:
|
||||||
|
g = self.enc_spk.embed_utterance(mel)
|
||||||
|
g = g.unsqueeze(-1)
|
||||||
|
z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
|
||||||
|
z = self.flow(z_p, c_mask, g=g, reverse=True)
|
||||||
|
o = self.dec(z * c_mask, g=g)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def extract_wavlm_features(self, y):
|
||||||
|
"""Extract WavLM features from an audio tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y (torch.Tensor): Audio tensor. Shape: (batch_size, audio_seq_len).
|
||||||
|
"""
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
c = self.wavlm.extract_features(y)[0]
|
||||||
|
c = c.transpose(1, 2)
|
||||||
|
return c
|
||||||
|
|
||||||
|
def load_audio(self, wav):
|
||||||
|
"""Read and format the input audio."""
|
||||||
|
if isinstance(wav, str):
|
||||||
|
wav, _ = librosa.load(wav, sr=self.config.audio.input_sample_rate)
|
||||||
|
if isinstance(wav, np.ndarray):
|
||||||
|
wav = torch.from_numpy(wav).to(self.device)
|
||||||
|
if isinstance(wav, torch.Tensor):
|
||||||
|
wav = wav.to(self.device)
|
||||||
|
if isinstance(wav, list):
|
||||||
|
wav = torch.from_numpy(np.array(wav)).to(self.device)
|
||||||
|
return wav.float()
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def voice_conversion(self, src, tgt):
|
||||||
|
"""
|
||||||
|
Voice conversion pass of the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src (str or torch.Tensor): Source utterance.
|
||||||
|
tgt (str or torch.Tensor): Target utterance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
wav_tgt = self.load_audio(tgt).cpu().numpy()
|
||||||
|
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
|
||||||
|
|
||||||
|
if self.config.model_args.use_spk:
|
||||||
|
g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt)
|
||||||
|
g_tgt = torch.from_numpy(g_tgt)[None, :, None].to(self.device)
|
||||||
|
else:
|
||||||
|
wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
|
||||||
|
mel_tgt = mel_spectrogram_torch(
|
||||||
|
wav_tgt,
|
||||||
|
self.config.audio.filter_length,
|
||||||
|
self.config.audio.n_mel_channels,
|
||||||
|
self.config.audio.input_sample_rate,
|
||||||
|
self.config.audio.hop_length,
|
||||||
|
self.config.audio.win_length,
|
||||||
|
self.config.audio.mel_fmin,
|
||||||
|
self.config.audio.mel_fmax,
|
||||||
|
)
|
||||||
|
# src
|
||||||
|
wav_src = self.load_audio(src)
|
||||||
|
c = self.extract_wavlm_features(wav_src[None, :])
|
||||||
|
|
||||||
|
if self.config.model_args.use_spk:
|
||||||
|
audio = self.inference(c, g=g_tgt)
|
||||||
|
else:
|
||||||
|
audio = self.inference(c, mel=mel_tgt.transpose(1, 2))
|
||||||
|
audio = audio[0][0].data.cpu().float().numpy()
|
||||||
|
return audio
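An end-to-end sketch of how this method would be driven (file names and the checkpoint path are placeholders, and the import paths follow the docstring examples above rather than being confirmed elsewhere in this diff):

import soundfile as sf
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.freevc import FreeVC

config = FreeVCConfig()
model = FreeVC.init_from_config(config)
model.load_checkpoint(config, "freevc24.pth", eval=True)  # hypothetical local checkpoint

# src provides the content, tgt provides the target speaker identity.
audio = model.voice_conversion(src="source_speech.wav", tgt="target_speaker.wav")
sf.write("converted.wav", audio, config.audio.output_sample_rate)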
|
||||||
|
|
||||||
|
def eval_step():
|
||||||
|
...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
|
||||||
|
model = FreeVC(config)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False):
|
||||||
|
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
||||||
|
self.load_state_dict(state["model"], strict=strict)
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
|
||||||
|
def train_step():
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FreeVCConfig(BaseVCConfig):
|
||||||
|
"""Defines parameters for FreeVC End2End TTS model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (str):
|
||||||
|
Model name. Do not change unless you know what you are doing.
|
||||||
|
|
||||||
|
model_args (FreeVCArgs):
|
||||||
|
Model architecture arguments. Defaults to `FreeVCArgs()`.
|
||||||
|
|
||||||
|
audio (FreeVCAudioConfig):
|
||||||
|
Audio processing configuration. Defaults to `FreeVCAudioConfig()`.
|
||||||
|
|
||||||
|
grad_clip (List):
|
||||||
|
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.
|
||||||
|
|
||||||
|
lr_gen (float):
|
||||||
|
Initial learning rate for the generator. Defaults to 0.0002.
|
||||||
|
|
||||||
|
lr_disc (float):
|
||||||
|
Initial learning rate for the discriminator. Defaults to 0.0002.
|
||||||
|
|
||||||
|
lr_scheduler_gen (str):
|
||||||
|
Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
|
||||||
|
`ExponentialLR`.
|
||||||
|
|
||||||
|
lr_scheduler_gen_params (dict):
|
||||||
|
Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
|
||||||
|
|
||||||
|
lr_scheduler_disc (str):
|
||||||
|
Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
|
||||||
|
`ExponentialLR`.
|
||||||
|
|
||||||
|
lr_scheduler_disc_params (dict):
|
||||||
|
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.
|
||||||
|
|
||||||
|
scheduler_after_epoch (bool):
|
||||||
|
If true, step the schedulers after each epoch else after each step. Defaults to `False`.
|
||||||
|
|
||||||
|
optimizer (str):
|
||||||
|
Name of the optimizer to use with both the generator and the discriminator networks. One of the
|
||||||
|
`torch.optim.*`. Defaults to `AdamW`.
|
||||||
|
|
||||||
|
kl_loss_alpha (float):
|
||||||
|
Loss weight for KL loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
disc_loss_alpha (float):
|
||||||
|
Loss weight for the discriminator loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
gen_loss_alpha (float):
|
||||||
|
Loss weight for the generator loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
feat_loss_alpha (float):
|
||||||
|
Loss weight for the feature matching loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
mel_loss_alpha (float):
|
||||||
|
Loss weight for the mel loss. Defaults to 45.0.
|
||||||
|
|
||||||
|
return_wav (bool):
|
||||||
|
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
|
||||||
|
|
||||||
|
compute_linear_spec (bool):
|
||||||
|
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
||||||
|
|
||||||
|
use_weighted_sampler (bool):
|
||||||
|
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
|
||||||
|
|
||||||
|
weighted_sampler_attrs (dict):
|
||||||
|
Key returned by the formatter to be used for the weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
|
||||||
|
by overweighting `root_path` by 2.0. Defaults to `{}`.
|
||||||
|
|
||||||
|
weighted_sampler_multipliers (dict):
|
||||||
|
Weight each unique value of a key returned by the formatter for weighted sampling.
|
||||||
|
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
|
||||||
|
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
|
||||||
|
|
||||||
|
r (int):
|
||||||
|
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
||||||
|
|
||||||
|
add_blank (bool):
|
||||||
|
If true, a blank token is added in between every character. Defaults to `True`.
|
||||||
|
|
||||||
|
test_sentences (List[List]):
|
||||||
|
List of sentences with speaker and language information to be used for testing.
|
||||||
|
|
||||||
|
language_ids_file (str):
|
||||||
|
Path to the language ids file.
|
||||||
|
|
||||||
|
use_language_embedding (bool):
|
||||||
|
If true, language embedding is used. Defaults to `False`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
>>> from TTS.vc.configs.freevc_config import FreeVCConfig
|
||||||
|
>>> config = FreeVCConfig()
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: str = "freevc"
|
||||||
|
# model specific params
|
||||||
|
model_args: FreeVCArgs = FreeVCArgs()
|
||||||
|
audio: FreeVCAudioConfig = FreeVCAudioConfig()
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
# TODO with training support
|
||||||
|
|
||||||
|
# loss params
|
||||||
|
# TODO with training support
|
||||||
|
|
||||||
|
# data loader params
|
||||||
|
return_wav: bool = True
|
||||||
|
compute_linear_spec: bool = True
|
||||||
|
|
||||||
|
# sampler params
|
||||||
|
use_weighted_sampler: bool = False # TODO: move it to the base config
|
||||||
|
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
|
||||||
|
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
|
||||||
|
|
||||||
|
# overrides
|
||||||
|
r: int = 1 # DO NOT CHANGE
|
||||||
|
add_blank: bool = True
|
||||||
|
|
||||||
|
# multi-speaker settings
|
||||||
|
# use speaker embedding layer
|
||||||
|
num_speakers: int = 0
|
||||||
|
speakers_file: str = None
|
||||||
|
speaker_embedding_channels: int = 256
|
||||||
|
|
||||||
|
# use d-vectors
|
||||||
|
use_d_vector_file: bool = False
|
||||||
|
d_vector_file: List[str] = None
|
||||||
|
d_vector_dim: int = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
for key, val in self.model_args.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
self[key] = val
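A hedged configuration example (import paths follow the docstring examples above; `FreeVCArgs` is assumed importable from the same module as `FreeVC`): architecture options live in `model_args`, audio settings live in `audio`, and `__post_init__` additionally copies any `model_args` key that also exists on the config itself.

from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.freevc import FreeVCArgs

config = FreeVCConfig(model_args=FreeVCArgs(use_spk=True))  # use the external speaker encoder
print(config.model_args.use_spk)   # True
print(config.audio.hop_length)     # 320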
|
|
@ -0,0 +1,170 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m, mean=0.0, std=0.01):
|
||||||
|
classname = m.__class__.__name__
|
||||||
|
if classname.find("Conv") != -1:
|
||||||
|
m.weight.data.normal_(mean, std)
|
||||||
|
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pad_shape(pad_shape):
|
||||||
|
l = pad_shape[::-1]
|
||||||
|
pad_shape = [item for sublist in l for item in sublist]
|
||||||
|
return pad_shape
|
||||||
|
|
||||||
|
|
||||||
|
def intersperse(lst, item):
|
||||||
|
result = [item] * (len(lst) * 2 + 1)
|
||||||
|
result[1::2] = lst
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
||||||
|
"""KL(P||Q)"""
|
||||||
|
kl = (logs_q - logs_p) - 0.5
|
||||||
|
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
||||||
|
return kl
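For reference (my restatement): this computes the element-wise KL divergence between diagonal Gaussians parameterized by mean and log-std, KL(P||Q) = (logs_q - logs_p) + (exp(2*logs_p) + (m_p - m_q)^2) / (2*exp(2*logs_q)) - 1/2, so it vanishes when the two distributions coincide.

import torch
from TTS.vc.modules.freevc.commons import kl_divergence

m = torch.randn(2, 4)
logs = torch.randn(2, 4)
kl = kl_divergence(m, logs, m, logs)  # identical P and Q
print(torch.allclose(kl, torch.zeros_like(kl), atol=1e-6))  # True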
|
||||||
|
|
||||||
|
|
||||||
|
def rand_gumbel(shape):
|
||||||
|
"""Sample from the Gumbel distribution, protect from overflows."""
|
||||||
|
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
|
||||||
|
return -torch.log(-torch.log(uniform_samples))
|
||||||
|
|
||||||
|
|
||||||
|
def rand_gumbel_like(x):
|
||||||
|
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def slice_segments(x, ids_str, segment_size=4):
|
||||||
|
ret = torch.zeros_like(x[:, :, :segment_size])
|
||||||
|
for i in range(x.size(0)):
|
||||||
|
idx_str = ids_str[i]
|
||||||
|
idx_end = idx_str + segment_size
|
||||||
|
ret[i] = x[i, :, idx_str:idx_end]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||||
|
b, d, t = x.size()
|
||||||
|
if x_lengths is None:
|
||||||
|
x_lengths = t
|
||||||
|
ids_str_max = x_lengths - segment_size + 1
|
||||||
|
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||||
|
ret = slice_segments(x, ids_str, segment_size)
|
||||||
|
return ret, ids_str
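Usage sketch (illustrative, not part of the diff): this is how FreeVC samples fixed-length latent windows for decoding during training; each batch item gets an independent random start index within its valid length.

import torch
import TTS.vc.modules.freevc.commons as commons

x = torch.randn(2, 192, 64)          # (batch, channels, frames)
lengths = torch.tensor([64, 40])
segments, ids_str = commons.rand_slice_segments(x, lengths, segment_size=8)
print(segments.shape, ids_str.shape)  # torch.Size([2, 192, 8]) torch.Size([2])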
|
||||||
|
|
||||||
|
|
||||||
|
def rand_spec_segments(x, x_lengths=None, segment_size=4):
|
||||||
|
b, d, t = x.size()
|
||||||
|
if x_lengths is None:
|
||||||
|
x_lengths = t
|
||||||
|
ids_str_max = x_lengths - segment_size
|
||||||
|
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
|
||||||
|
ret = slice_segments(x, ids_str, segment_size)
|
||||||
|
return ret, ids_str
|
||||||
|
|
||||||
|
|
||||||
|
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
||||||
|
position = torch.arange(length, dtype=torch.float)
|
||||||
|
num_timescales = channels // 2
|
||||||
|
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
|
||||||
|
inv_timescales = min_timescale * torch.exp(
|
||||||
|
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
||||||
|
)
|
||||||
|
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
|
||||||
|
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
|
||||||
|
signal = F.pad(signal, [0, 0, 0, channels % 2])
|
||||||
|
signal = signal.view(1, channels, length)
|
||||||
|
return signal
|
||||||
|
|
||||||
|
|
||||||
|
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
|
||||||
|
b, channels, length = x.size()
|
||||||
|
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||||
|
return x + signal.to(dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
|
||||||
|
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
|
||||||
|
b, channels, length = x.size()
|
||||||
|
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
|
||||||
|
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
|
||||||
|
|
||||||
|
|
||||||
|
def subsequent_mask(length):
|
||||||
|
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
|
||||||
|
n_channels_int = n_channels[0]
|
||||||
|
in_act = input_a + input_b
|
||||||
|
t_act = torch.tanh(in_act[:, :n_channels_int, :])
|
||||||
|
s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
|
||||||
|
acts = t_act * s_act
|
||||||
|
return acts
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pad_shape(pad_shape):
|
||||||
|
l = pad_shape[::-1]
|
||||||
|
pad_shape = [item for sublist in l for item in sublist]
|
||||||
|
return pad_shape
|
||||||
|
|
||||||
|
|
||||||
|
def shift_1d(x):
|
||||||
|
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def sequence_mask(length, max_length=None):
|
||||||
|
if max_length is None:
|
||||||
|
max_length = length.max()
|
||||||
|
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
||||||
|
return x.unsqueeze(0) < length.unsqueeze(1)
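A tiny example (mine) of the padding mask this helper produces:

import torch
from TTS.vc.modules.freevc.commons import sequence_mask

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths, max_length=5).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)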
|
||||||
|
|
||||||
|
|
||||||
|
def generate_path(duration, mask):
|
||||||
|
"""
|
||||||
|
duration: [b, 1, t_x]
|
||||||
|
mask: [b, 1, t_y, t_x]
|
||||||
|
"""
|
||||||
|
device = duration.device
|
||||||
|
|
||||||
|
b, _, t_y, t_x = mask.shape
|
||||||
|
cum_duration = torch.cumsum(duration, -1)
|
||||||
|
|
||||||
|
cum_duration_flat = cum_duration.view(b * t_x)
|
||||||
|
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
|
||||||
|
path = path.view(b, t_x, t_y)
|
||||||
|
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
|
||||||
|
path = path.unsqueeze(1).transpose(2, 3) * mask
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def clip_grad_value_(parameters, clip_value, norm_type=2):
|
||||||
|
if isinstance(parameters, torch.Tensor):
|
||||||
|
parameters = [parameters]
|
||||||
|
parameters = list(filter(lambda p: p.grad is not None, parameters))
|
||||||
|
norm_type = float(norm_type)
|
||||||
|
if clip_value is not None:
|
||||||
|
clip_value = float(clip_value)
|
||||||
|
|
||||||
|
total_norm = 0
|
||||||
|
for p in parameters:
|
||||||
|
param_norm = p.grad.data.norm(norm_type)
|
||||||
|
total_norm += param_norm.item() ** norm_type
|
||||||
|
if clip_value is not None:
|
||||||
|
p.grad.data.clamp_(min=-clip_value, max=clip_value)
|
||||||
|
total_norm = total_norm ** (1.0 / norm_type)
|
||||||
|
return total_norm
|
|
@ -0,0 +1,125 @@
|
||||||
|
import torch
|
||||||
|
import torch.utils.data
|
||||||
|
from librosa.filters import mel as librosa_mel_fn
|
||||||
|
|
||||||
|
MAX_WAV_VALUE = 32768.0
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||||
|
"""
|
||||||
|
PARAMS
|
||||||
|
------
|
||||||
|
C: compression factor
|
||||||
|
"""
|
||||||
|
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||||
|
|
||||||
|
|
||||||
|
def dynamic_range_decompression_torch(x, C=1):
|
||||||
|
"""
|
||||||
|
PARAMS
|
||||||
|
------
|
||||||
|
C: compression factor used to compress
|
||||||
|
"""
|
||||||
|
return torch.exp(x) / C
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_normalize_torch(magnitudes):
|
||||||
|
output = dynamic_range_compression_torch(magnitudes)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_de_normalize_torch(magnitudes):
|
||||||
|
output = dynamic_range_decompression_torch(magnitudes)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
mel_basis = {}
|
||||||
|
hann_window = {}
|
||||||
|
|
||||||
|
|
||||||
|
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||||
|
if torch.min(y) < -1.0:
|
||||||
|
print("min value is ", torch.min(y))
|
||||||
|
if torch.max(y) > 1.0:
|
||||||
|
print("max value is ", torch.max(y))
|
||||||
|
|
||||||
|
global hann_window
|
||||||
|
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||||
|
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||||
|
if wnsize_dtype_device not in hann_window:
|
||||||
|
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||||
|
|
||||||
|
y = torch.nn.functional.pad(
|
||||||
|
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
||||||
|
)
|
||||||
|
y = y.squeeze(1)
|
||||||
|
|
||||||
|
spec = torch.stft(
|
||||||
|
y,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_size,
|
||||||
|
win_length=win_size,
|
||||||
|
window=hann_window[wnsize_dtype_device],
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||||
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
||||||
|
global mel_basis
|
||||||
|
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
||||||
|
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||||
|
if fmax_dtype_device not in mel_basis:
|
||||||
|
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||||
|
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||||
|
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||||
|
spec = spectral_normalize_torch(spec)
|
||||||
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||||
|
if torch.min(y) < -1.0:
|
||||||
|
print("min value is ", torch.min(y))
|
||||||
|
if torch.max(y) > 1.0:
|
||||||
|
print("max value is ", torch.max(y))
|
||||||
|
|
||||||
|
global mel_basis, hann_window
|
||||||
|
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||||
|
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||||
|
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||||
|
if fmax_dtype_device not in mel_basis:
|
||||||
|
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||||
|
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
||||||
|
if wnsize_dtype_device not in hann_window:
|
||||||
|
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||||
|
|
||||||
|
y = torch.nn.functional.pad(
|
||||||
|
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
||||||
|
)
|
||||||
|
y = y.squeeze(1)
|
||||||
|
|
||||||
|
spec = torch.stft(
|
||||||
|
y,
|
||||||
|
n_fft,
|
||||||
|
hop_length=hop_size,
|
||||||
|
win_length=win_size,
|
||||||
|
window=hann_window[wnsize_dtype_device],
|
||||||
|
center=center,
|
||||||
|
pad_mode="reflect",
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||||
|
|
||||||
|
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||||
|
spec = spectral_normalize_torch(spec)
|
||||||
|
|
||||||
|
return spec
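A usage sketch with the FreeVC audio defaults (my example; the frame count follows from the reflection padding of (n_fft - hop_size) // 2 on each side and `center=False`):

import torch
from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch

wav = torch.randn(1, 16000).clamp(-1.0, 1.0)  # one second at 16 kHz, kept in [-1, 1]
mel = mel_spectrogram_torch(
    wav, n_fft=1280, num_mels=80, sampling_rate=16000,
    hop_size=320, win_size=1280, fmin=0.0, fmax=None, center=False,
)
print(mel.shape)  # torch.Size([1, 80, 50])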
|
|
@ -0,0 +1,391 @@
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||||
|
|
||||||
|
import TTS.vc.modules.freevc.commons as commons
|
||||||
|
from TTS.vc.modules.freevc.commons import get_padding, init_weights
|
||||||
|
|
||||||
|
LRELU_SLOPE = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
def __init__(self, channels, eps=1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.gamma = nn.Parameter(torch.ones(channels))
|
||||||
|
self.beta = nn.Parameter(torch.zeros(channels))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.transpose(1, -1)
|
||||||
|
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
||||||
|
return x.transpose(1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvReluNorm(nn.Module):
|
||||||
|
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
assert n_layers > 1, "Number of layers should be larger than 1."
|
||||||
|
|
||||||
|
self.conv_layers = nn.ModuleList()
|
||||||
|
self.norm_layers = nn.ModuleList()
|
||||||
|
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
||||||
|
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||||
|
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
||||||
|
for _ in range(n_layers - 1):
|
||||||
|
self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
||||||
|
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||||
|
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||||
|
self.proj.weight.data.zero_()
|
||||||
|
self.proj.bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, x, x_mask):
|
||||||
|
x_org = x
|
||||||
|
for i in range(self.n_layers):
|
||||||
|
x = self.conv_layers[i](x * x_mask)
|
||||||
|
x = self.norm_layers[i](x)
|
||||||
|
x = self.relu_drop(x)
|
||||||
|
x = x_org + self.proj(x)
|
||||||
|
return x * x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class DDSConv(nn.Module):
|
||||||
|
"""
|
||||||
|
Dilated and Depth-Separable Convolution
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
|
||||||
|
super().__init__()
|
||||||
|
self.channels = channels
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
|
||||||
|
self.drop = nn.Dropout(p_dropout)
|
||||||
|
self.convs_sep = nn.ModuleList()
|
||||||
|
self.convs_1x1 = nn.ModuleList()
|
||||||
|
self.norms_1 = nn.ModuleList()
|
||||||
|
self.norms_2 = nn.ModuleList()
|
||||||
|
for i in range(n_layers):
|
||||||
|
dilation = kernel_size**i
|
||||||
|
padding = (kernel_size * dilation - dilation) // 2
|
||||||
|
self.convs_sep.append(
|
||||||
|
nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
|
||||||
|
)
|
||||||
|
self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
|
||||||
|
self.norms_1.append(LayerNorm(channels))
|
||||||
|
self.norms_2.append(LayerNorm(channels))
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None):
|
||||||
|
if g is not None:
|
||||||
|
x = x + g
|
||||||
|
for i in range(self.n_layers):
|
||||||
|
y = self.convs_sep[i](x * x_mask)
|
||||||
|
y = self.norms_1[i](y)
|
||||||
|
y = F.gelu(y)
|
||||||
|
y = self.convs_1x1[i](y)
|
||||||
|
y = self.norms_2[i](y)
|
||||||
|
y = F.gelu(y)
|
||||||
|
y = self.drop(y)
|
||||||
|
x = x + y
|
||||||
|
return x * x_mask
|
||||||
|
|
||||||
|
|
||||||
|
class WN(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
|
||||||
|
super(WN, self).__init__()
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
self.hidden_channels = hidden_channels
|
||||||
|
self.kernel_size = (kernel_size,)
|
||||||
|
self.dilation_rate = dilation_rate
|
||||||
|
self.n_layers = n_layers
|
||||||
|
self.gin_channels = gin_channels
|
||||||
|
self.p_dropout = p_dropout
|
||||||
|
|
||||||
|
self.in_layers = torch.nn.ModuleList()
|
||||||
|
self.res_skip_layers = torch.nn.ModuleList()
|
||||||
|
self.drop = nn.Dropout(p_dropout)
|
||||||
|
|
||||||
|
if gin_channels != 0:
|
||||||
|
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
|
||||||
|
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
||||||
|
|
||||||
|
for i in range(n_layers):
|
||||||
|
dilation = dilation_rate**i
|
||||||
|
padding = int((kernel_size * dilation - dilation) / 2)
|
||||||
|
in_layer = torch.nn.Conv1d(
|
||||||
|
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
|
||||||
|
)
|
||||||
|
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
|
||||||
|
self.in_layers.append(in_layer)
|
||||||
|
|
||||||
|
# last one is not necessary
|
||||||
|
if i < n_layers - 1:
|
||||||
|
res_skip_channels = 2 * hidden_channels
|
||||||
|
else:
|
||||||
|
res_skip_channels = hidden_channels
|
||||||
|
|
||||||
|
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
|
||||||
|
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
|
||||||
|
self.res_skip_layers.append(res_skip_layer)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask, g=None, **kwargs):
|
||||||
|
output = torch.zeros_like(x)
|
||||||
|
n_channels_tensor = torch.IntTensor([self.hidden_channels])
|
||||||
|
|
||||||
|
if g is not None:
|
||||||
|
g = self.cond_layer(g)
|
||||||
|
|
||||||
|
for i in range(self.n_layers):
|
||||||
|
x_in = self.in_layers[i](x)
|
||||||
|
if g is not None:
|
||||||
|
cond_offset = i * 2 * self.hidden_channels
|
||||||
|
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
||||||
|
else:
|
||||||
|
g_l = torch.zeros_like(x_in)
|
||||||
|
|
||||||
|
acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
|
||||||
|
acts = self.drop(acts)
|
||||||
|
|
||||||
|
res_skip_acts = self.res_skip_layers[i](acts)
|
||||||
|
if i < self.n_layers - 1:
|
||||||
|
res_acts = res_skip_acts[:, : self.hidden_channels, :]
|
||||||
|
x = (x + res_acts) * x_mask
|
||||||
|
output = output + res_skip_acts[:, self.hidden_channels :, :]
|
||||||
|
else:
|
||||||
|
output = output + res_skip_acts
|
||||||
|
return output * x_mask
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
if self.gin_channels != 0:
|
||||||
|
torch.nn.utils.remove_weight_norm(self.cond_layer)
|
||||||
|
for l in self.in_layers:
|
||||||
|
torch.nn.utils.remove_weight_norm(l)
|
||||||
|
for l in self.res_skip_layers:
|
||||||
|
torch.nn.utils.remove_weight_norm(l)
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock1(torch.nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||||
|
super(ResBlock1, self).__init__()
|
||||||
|
self.convs1 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2]),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.convs1.apply(init_weights)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||||
|
),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||||
|
),
|
||||||
|
weight_norm(
|
||||||
|
Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.convs2.apply(init_weights)
|
||||||
|
|
||||||
|
def forward(self, x, x_mask=None):
|
||||||
|
for c1, c2 in zip(self.convs1, self.convs2):
|
||||||
|
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
|
if x_mask is not None:
|
||||||
|
xt = xt * x_mask
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||||
|
if x_mask is not None:
|
||||||
|
xt = xt * x_mask
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
if x_mask is not None:
|
||||||
|
x = x * x_mask
|
||||||
|
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Log(nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x


class Flip(nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        if not reverse:
            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
            return x, logdet
        else:
            return x


class ElementwiseAffine(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x


class ResidualCouplingLayer(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels
        )
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x
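The Log, Flip, ElementwiseAffine and ResidualCouplingLayer modules above are the invertible building blocks of the VITS-style normalizing flow that FreeVC reuses: in the forward direction each returns the transformed tensor plus a log-determinant term, and with reverse=True it undoes its own transform. A minimal round-trip sketch, assuming these classes are importable from TTS.vc.modules.freevc.modules and using illustrative hyperparameters:

# Sketch only: the import path and layer sizes are assumptions for illustration.
import torch
from TTS.vc.modules.freevc.modules import ResidualCouplingLayer

x = torch.randn(2, 192, 50)       # (batch, channels, frames)
x_mask = torch.ones(2, 1, 50)     # all frames valid
layer = ResidualCouplingLayer(192, 256, 5, 1, 4, mean_only=True)

z, logdet = layer(x, x_mask)                # forward: transformed tensor + log-determinant
x_rec = layer(z, x_mask, reverse=True)      # reverse: exact inverse of the coupling
print(torch.allclose(x, x_rec, atol=1e-5))  # expected: True

With mean_only=True the scale term is zero, so the log-determinant is zero and the coupling reduces to a learned, invertible shift of half the channels conditioned on the other half.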
@@ -0,0 +1,65 @@
import struct
from pathlib import Path
from typing import Optional, Union

# import webrtcvad
import librosa
import numpy as np
from scipy.ndimage.morphology import binary_dilation

from TTS.vc.modules.freevc.speaker_encoder.hparams import *

int16_max = (2**15) - 1


def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], source_sr: Optional[int] = None):
    """
    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.

    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
    just .wav), or the waveform as a numpy array of floats.
    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
    preprocessing. After preprocessing, the waveform's sampling rate will match the data
    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
    this argument will be ignored.
    """
    # Load the wav from disk if needed
    if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
        wav, source_sr = librosa.load(fpath_or_wav, sr=None)
    else:
        wav = fpath_or_wav

    # Resample the wav if needed
    if source_sr is not None and source_sr != sampling_rate:
        wav = librosa.resample(wav, source_sr, sampling_rate)

    # Apply the preprocessing: normalize volume and shorten long silences
    wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
    wav = trim_long_silences(wav)

    return wav


def wav_to_mel_spectrogram(wav):
    """
    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
    Note: this is not a log-mel spectrogram.
    """
    frames = librosa.feature.melspectrogram(
        y=wav,
        sr=sampling_rate,
        n_fft=int(sampling_rate * mel_window_length / 1000),
        hop_length=int(sampling_rate * mel_window_step / 1000),
        n_mels=mel_n_channels,
    )
    return frames.astype(np.float32).T


def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
    if increase_only and decrease_only:
        raise ValueError("Both increase only and decrease only are set")
    dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav**2))
    if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
        return wav
    return wav * (10 ** (dBFS_change / 20))
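A short usage sketch for the helpers above, assuming the file is importable as TTS.vc.modules.freevc.speaker_encoder.audio (as the other imports in this commit suggest) and using a placeholder wav path:

# Sketch only: "speaker_reference.wav" is a placeholder path.
from TTS.vc.modules.freevc.speaker_encoder.audio import preprocess_wav, wav_to_mel_spectrogram

wav = preprocess_wav("speaker_reference.wav")  # load, resample to 16 kHz, normalize volume
mel = wav_to_mel_spectrogram(wav)              # (n_frames, 40) float32, 25 ms window / 10 ms hop
print(mel.shape)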
@@ -0,0 +1,31 @@
## Mel-filterbank
mel_window_length = 25  # In milliseconds
mel_window_step = 10  # In milliseconds
mel_n_channels = 40


## Audio
sampling_rate = 16000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160  # 1600 ms


## Voice Activity Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30


## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3
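These constants tie the mel front end to the encoder: at 16 kHz a 25 ms window is 400 samples, a 10 ms hop is 160 samples, and 160 frames per partial utterance cover 1.6 s. A quick sanity check, assuming the module path used elsewhere in this commit:

# Sketch only: derived sizes, no model involved.
from TTS.vc.modules.freevc.speaker_encoder.hparams import (
    mel_window_length,
    mel_window_step,
    partials_n_frames,
    sampling_rate,
)

n_fft = int(sampling_rate * mel_window_length / 1000)          # 400 samples per analysis window
hop = int(sampling_rate * mel_window_step / 1000)              # 160 samples per hop
partial_seconds = partials_n_frames * mel_window_step / 1000   # 1.6 s per partial utterance
print(n_fft, hop, partial_seconds)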
@@ -0,0 +1,175 @@
from pathlib import Path
from time import perf_counter as timer
from typing import List, Union

import numpy as np
import torch
from torch import nn

from TTS.utils.io import load_fsspec
from TTS.vc.modules.freevc.speaker_encoder import audio
from TTS.vc.modules.freevc.speaker_encoder.hparams import *


class SpeakerEncoder(nn.Module):
    def __init__(self, weights_fpath, device: Union[str, torch.device] = None, verbose=True):
        """
        :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
        If None, defaults to cuda if it is available on your machine, otherwise the model will
        run on cpu. Outputs are always returned on the cpu, as numpy arrays.
        """
        super().__init__()

        # Define the network
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

        # Get the target device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        # Load the pretrained speaker encoder weights
        # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
        # if not weights_fpath.exists():
        #     raise Exception("Couldn't find the voice encoder pretrained model at %s." %
        #                     weights_fpath)

        start = timer()
        checkpoint = load_fsspec(weights_fpath, map_location="cpu")

        self.load_state_dict(checkpoint["model_state"], strict=False)
        self.to(device)

        if verbose:
            print("Loaded the voice encoder model on %s in %.2f seconds." % (device.type, timer() - start))

    def forward(self, mels: torch.FloatTensor):
        """
        Computes the embeddings of a batch of utterance spectrograms.
        :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
        (batch_size, n_frames, n_channels)
        :return: the embeddings as a float32 tensor of shape (batch_size, embedding_size).
        Embeddings are positive and L2-normed, thus they lie in the range [0, 1].
        """
        # Pass the input through the LSTM layers and retrieve the final hidden state of the last
        # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
        _, (hidden, _) = self.lstm(mels)
        embeds_raw = self.relu(self.linear(hidden[-1]))
        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

    @staticmethod
    def compute_partial_slices(n_samples: int, rate, min_coverage):
        """
        Computes where to split an utterance waveform and its corresponding mel spectrogram to
        obtain partial utterances of <partials_n_frames> each. Both the waveform and the
        mel spectrogram slices are returned, so as to make each partial utterance waveform
        correspond to its spectrogram.

        The returned ranges may be indexing further than the length of the waveform. It is
        recommended that you pad the waveform with zeros up to wav_slices[-1].stop.

        :param n_samples: the number of samples in the waveform
        :param rate: how many partial utterances should occur per second. Partial utterances must
        cover the span of the entire utterance, thus the rate should not be lower than the inverse
        of the duration of a partial utterance. By default, partial utterances are 1.6s long and
        the minimum rate is thus 0.625.
        :param min_coverage: when reaching the last partial utterance, it may or may not have
        enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
        then the last partial utterance will be considered by zero-padding the audio. Otherwise,
        it will be discarded. If there aren't enough frames for one partial utterance,
        this parameter is ignored so that the function always returns at least one slice.
        :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
        respectively the waveform and the mel spectrogram with these slices to obtain the partial
        utterances.
        """
        assert 0 < min_coverage <= 1

        # Compute how many frames separate two partial utterances
        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
        assert 0 < frame_step, "The rate is too high"
        assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % (
            sampling_rate / (samples_per_frame * partials_n_frames)
        )

        # Compute the slices
        wav_slices, mel_slices = [], []
        steps = max(1, n_frames - partials_n_frames + frame_step + 1)
        for i in range(0, steps, frame_step):
            mel_range = np.array([i, i + partials_n_frames])
            wav_range = mel_range * samples_per_frame
            mel_slices.append(slice(*mel_range))
            wav_slices.append(slice(*wav_range))

        # Evaluate whether extra padding is warranted or not
        last_wav_range = wav_slices[-1]
        coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
        if coverage < min_coverage and len(mel_slices) > 1:
            mel_slices = mel_slices[:-1]
            wav_slices = wav_slices[:-1]

        return wav_slices, mel_slices

    def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
        """
        Computes an embedding for a single utterance. The utterance is divided in partial
        utterances and an embedding is computed for each. The complete utterance embedding is the
        L2-normed average embedding of the partial utterances.

        TODO: independent batched version of this function

        :param wav: a preprocessed utterance waveform as a numpy array of float32
        :param return_partials: if True, the partial embeddings will also be returned along with
        the wav slices corresponding to each partial utterance.
        :param rate: how many partial utterances should occur per second. Partial utterances must
        cover the span of the entire utterance, thus the rate should not be lower than the inverse
        of the duration of a partial utterance. By default, partial utterances are 1.6s long and
        the minimum rate is thus 0.625.
        :param min_coverage: when reaching the last partial utterance, it may or may not have
        enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
        then the last partial utterance will be considered by zero-padding the audio. Otherwise,
        it will be discarded. If there aren't enough frames for one partial utterance,
        this parameter is ignored so that the function always returns at least one slice.
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
        <return_partials> is True, the partial utterances as a numpy array of float32 of shape
        (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
        returned.
        """
        # Compute where to split the utterance into partials and pad the waveform with zeros if
        # the partial utterances cover a larger range.
        wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
        max_wave_length = wav_slices[-1].stop
        if max_wave_length >= len(wav):
            wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

        # Split the utterance into partials and forward them through the model
        mel = audio.wav_to_mel_spectrogram(wav)
        mels = np.array([mel[s] for s in mel_slices])
        with torch.no_grad():
            mels = torch.from_numpy(mels).to(self.device)
            partial_embeds = self(mels).cpu().numpy()

        # Compute the utterance embedding from the partial embeddings
        raw_embed = np.mean(partial_embeds, axis=0)
        embed = raw_embed / np.linalg.norm(raw_embed, 2)

        if return_partials:
            return embed, partial_embeds, wav_slices
        return embed

    def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
        """
        Compute the embedding of a collection of wavs (presumably from the same speaker) by
        averaging their embedding and L2-normalizing it.

        :param wavs: list of wavs as numpy arrays of float32.
        :param kwargs: extra arguments to embed_utterance()
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
        """
        raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
        return raw_embed / np.linalg.norm(raw_embed, 2)
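A minimal sketch of embedding a reference utterance with the encoder above. The checkpoint path is a placeholder and the module path is an assumption based on the imports in this commit:

# Sketch only: "path/to/speaker_encoder.pt" and the import path are placeholders.
from TTS.vc.modules.freevc.speaker_encoder.audio import preprocess_wav
from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder

encoder = SpeakerEncoder("path/to/speaker_encoder.pt", device="cpu")
wav = preprocess_wav("target_speaker.wav")
embed = encoder.embed_utterance(wav)          # (256,) L2-normed numpy vector
print(embed.shape, float((embed**2).sum()))   # squared norm is ~1.0

FreeVC conditions its decoder on this d-vector, which is why the commit ships the encoder alongside WavLM rather than relying on the TTS speaker managers.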
@@ -0,0 +1,35 @@
import os
import urllib.request

import torch

from TTS.utils.generic_utils import get_user_data_dir
from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig

model_uri = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/WavLM-Large.pt"


def get_wavlm(device="cpu"):
    """Download the model and return the model object."""

    output_path = get_user_data_dir("tts")

    output_path = os.path.join(output_path, "wavlm")
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    output_path = os.path.join(output_path, "WavLM-Large.pt")
    if not os.path.exists(output_path):
        print(f" > Downloading WavLM model to {output_path} ...")
        urllib.request.urlretrieve(model_uri, output_path)

    checkpoint = torch.load(output_path, map_location=torch.device(device))
    cfg = WavLMConfig(checkpoint["cfg"])
    wavlm = WavLM(cfg).to(device)
    wavlm.load_state_dict(checkpoint["model"])
    wavlm.eval()
    return wavlm


if __name__ == "__main__":
    wavlm = get_wavlm()
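A sketch of loading WavLM and pulling content features for FreeVC-style conversion. extract_features comes from the upstream WavLM implementation; its exact return layout here is an assumption, and the waveform is random placeholder audio:

# Sketch only: random audio stands in for a real 16 kHz waveform.
import torch
from TTS.vc.modules.freevc.wavlm import get_wavlm

wavlm = get_wavlm(device="cpu")             # downloads WavLM-Large.pt into the user data dir on first call
wav = torch.randn(1, 16000)                 # one second of placeholder 16 kHz audio, shape (batch, samples)
with torch.no_grad():
    feats = wavlm.extract_features(wav)[0]  # assumed to return (batch, n_frames, 1024) content features
print(feats.shape)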
@@ -0,0 +1,99 @@
{
  "_name_or_path": "./wavlm-large/",
  "activation_dropout": 0.0,
  "adapter_kernel_size": 3,
  "adapter_stride": 2,
  "add_adapter": false,
  "apply_spec_augment": true,
  "architectures": [
    "WavLMModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 1,
  "classifier_proj_size": 256,
  "codevector_dim": 768,
  "contrastive_logits_temperature": 0.1,
  "conv_bias": false,
  "conv_dim": [
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "conv_kernel": [
    10,
    3,
    3,
    3,
    3,
    2,
    2
  ],
  "conv_stride": [
    5,
    2,
    2,
    2,
    2,
    2,
    2
  ],
  "ctc_loss_reduction": "sum",
  "ctc_zero_infinity": false,
  "diversity_loss_weight": 0.1,
  "do_stable_layer_norm": true,
  "eos_token_id": 2,
  "feat_extract_activation": "gelu",
  "feat_extract_dropout": 0.0,
  "feat_extract_norm": "layer",
  "feat_proj_dropout": 0.1,
  "feat_quantizer_dropout": 0.0,
  "final_dropout": 0.0,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-05,
  "layerdrop": 0.1,
  "mask_channel_length": 10,
  "mask_channel_min_space": 1,
  "mask_channel_other": 0.0,
  "mask_channel_prob": 0.0,
  "mask_channel_selection": "static",
  "mask_feature_length": 10,
  "mask_feature_min_masks": 0,
  "mask_feature_prob": 0.0,
  "mask_time_length": 10,
  "mask_time_min_masks": 2,
  "mask_time_min_space": 1,
  "mask_time_other": 0.0,
  "mask_time_prob": 0.075,
  "mask_time_selection": "static",
  "max_bucket_distance": 800,
  "model_type": "wavlm",
  "num_adapter_layers": 3,
  "num_attention_heads": 16,
  "num_buckets": 320,
  "num_codevector_groups": 2,
  "num_codevectors_per_group": 320,
  "num_conv_pos_embedding_groups": 16,
  "num_conv_pos_embeddings": 128,
  "num_ctc_classes": 80,
  "num_feat_extract_layers": 7,
  "num_hidden_layers": 24,
  "num_negatives": 100,
  "output_hidden_size": 1024,
  "pad_token_id": 0,
  "proj_codevector_dim": 768,
  "replace_prob": 0.5,
  "tokenizer_class": "Wav2Vec2CTCTokenizer",
  "torch_dtype": "float32",
  "transformers_version": "4.15.0.dev0",
  "use_weighted_layer_sum": false,
  "vocab_size": 32
}
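The JSON above mirrors the Hugging Face WavLM-Large model configuration. A quick check of the key dimensions, with a placeholder on-disk path:

# Sketch only: "config.json" is a placeholder path to the file above.
import json

with open("config.json") as f:
    cfg = json.load(f)
print(cfg["hidden_size"], cfg["num_hidden_layers"], cfg["num_attention_heads"])  # 1024 24 16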
@@ -0,0 +1,768 @@
# --------------------------------------------------------
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
import warnings
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn import Parameter


class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class Fp32GroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class Swish(nn.Module):
    """Swish function"""

    def __init__(self):
        """Construct a Swish activation module."""
        super(Swish, self).__init__()
        self.act = torch.nn.Sigmoid()

    def forward(self, x):
        return x * self.act(x)


class GLU_Linear(nn.Module):
    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super(GLU_Linear, self).__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()

        if bias_in_glu:
            self.linear = nn.Linear(input_dim, output_dim * 2, True)
        else:
            self.linear = nn.Linear(input_dim, output_dim * 2, False)

    def forward(self, x):
        # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
        x = self.linear(x)

        if self.glu_type == "bilinear":
            x = x[:, :, 0 : self.output_dim] * x[:, :, self.output_dim : self.output_dim * 2]
        else:
            x = x[:, :, 0 : self.output_dim] * self.glu_act(x[:, :, self.output_dim : self.output_dim * 2])

        return x


def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))


def gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x.float()).type_as(x)


def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""

    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        warnings.warn("--activation-fn=gelu_fast has been renamed to gelu_accurate")
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "glu":
        return lambda x: x
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention initialized using
       the normal distribution (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)


def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert module.weight.size(1) % block_size == 0, "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert module.in_channels % block_size == 0, "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                    mask.bernoulli_(p)
                    mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])

            # scale weights and apply mask
            mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module


class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        has_relative_attention_bias=False,
        num_buckets=32,
        max_distance=128,
        gru_rel_pos=False,
        rescale_init=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        self.k_proj = quant_noise(nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size)
        self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = quant_noise(nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size)

        self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True)
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert src_len, bsz == value.shape[:2]

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

        if (
            not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and self.q_head_dim == self.head_dim
        ):
            assert key is not None and value is not None
            assert attn_mask is None

            attn_mask_rel_pos = None
            if position_bias is not None:
                attn_mask_rel_pos = position_bias
                if self.gru_rel_pos:
                    query_layer = query.transpose(0, 1)
                    new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
                    query_layer = query_layer.view(*new_x_shape)
                    query_layer = query_layer.permute(0, 2, 1, 3)
                    _B, _H, _L, __ = query_layer.size()

                    gate_a, gate_b = torch.sigmoid(
                        self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
                    ).chunk(2, dim=-1)
                    gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                    attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias

                attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
            k_proj_bias = self.k_proj.bias
            if k_proj_bias is None:
                k_proj_bias = torch.zeros_like(self.q_proj.bias)

            x, attn = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                # self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask_rel_pos,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
            return x, attn, position_bias

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.q_head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.k_head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            if self.gru_rel_pos == 1:
                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(
                    self.grep_linear(query_layer).view(_B, _H, _L, 2, 4).sum(-1, keepdim=False)
                ).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias

            position_bias = position_bias.view(attn_weights.size())

            attn_weights = attn_weights + position_bias

        attn_weights_float = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights
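A minimal sketch of the gated relative-position-bias attention defined above. Shapes follow the documented (time, batch, channel) layout; the embedding size, head count and bucket settings are illustrative assumptions, not values used by FreeVC:

# Sketch only: hyperparameters chosen for illustration.
import torch
from TTS.vc.modules.freevc.wavlm.modules import MultiheadAttention

attn = MultiheadAttention(
    embed_dim=256,
    num_heads=4,
    self_attention=True,
    has_relative_attention_bias=True,
    num_buckets=32,
    max_distance=128,
    gru_rel_pos=True,
)
x = torch.randn(50, 2, 256)        # (time, batch, channel)
out, weights, pos_bias = attn(x, x, x)
print(out.shape, pos_bias.shape)   # (50, 2, 256) and (batch * num_heads, 50, 50)

The bucketed relative-position bias is computed once on the first layer and then passed back in through position_bias, so the per-layer gating (gru_rel_pos) only rescales a shared bias rather than recomputing it.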
@ -0,0 +1,719 @@
|
||||||
|
# --------------------------------------------------------
|
||||||
|
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
||||||
|
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
||||||
|
# Copyright (c) 2021 Microsoft
|
||||||
|
# Licensed under The MIT License [see LICENSE for details]
|
||||||
|
# Based on fairseq code bases
|
||||||
|
# https://github.com/pytorch/fairseq
|
||||||
|
# --------------------------------------------------------
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
|
from TTS.vc.modules.freevc.wavlm.modules import (
|
||||||
|
Fp32GroupNorm,
|
||||||
|
Fp32LayerNorm,
|
||||||
|
GLU_Linear,
|
||||||
|
GradMultiply,
|
||||||
|
MultiheadAttention,
|
||||||
|
SamePad,
|
||||||
|
TransposeLast,
|
||||||
|
get_activation_fn,
|
||||||
|
init_bert_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape.

    Args:
        shape: the shape for which to compute masks.
            Should be of size 2, where the first element is the batch size and the second is the number of timesteps.
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements.
        mask_prob: probability for each token to be chosen as the start of a span to be masked. This will be
            multiplied by the number of timesteps divided by the mask span length to mask approximately this
            percentage of all elements. However, due to overlaps, the actual number will be smaller
            (unless no_overlap is True).
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask_length
        min_masks: minimum number of masked spans.
        no_overlap: if True, uses an alternative recursive algorithm that prevents spans from overlapping.
        min_space: only used if no_overlap is True; how many elements to keep unmasked between spans.
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - keep_length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        mask[i, mask_idc] = True

    return mask


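# A minimal, illustrative sketch of calling compute_mask_indices on its own (assuming the numpy
# and torch imports of this module); handy for checking what fraction of timesteps a given
# mask_prob / mask_length combination actually masks:
#
#     mask = compute_mask_indices(
#         shape=(2, 400),      # batch of 2 sequences, 400 timesteps each
#         padding_mask=None,   # no padded positions to protect
#         mask_prob=0.65,
#         mask_length=10,
#         mask_type="static",
#         min_masks=2,
#     )
#     # mask is a boolean np.ndarray of shape (2, 400); mask.mean() gives the realized masking ratio

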
class WavLMConfig:
    def __init__(self, cfg=None):
        self.extractor_mode: str = "default"  # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
        self.encoder_layers: int = 12  # num encoder layers in the transformer

        self.encoder_embed_dim: int = 768  # encoder embedding dimension
        self.encoder_ffn_embed_dim: int = 3072  # encoder embedding dimension for FFN
        self.encoder_attention_heads: int = 12  # num encoder attention heads
        self.activation_fn: str = "gelu"  # activation function to use

        self.layer_norm_first: bool = False  # apply layernorm first in the transformer
        self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2"  # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
        self.conv_bias: bool = False  # include bias in conv encoder
        self.feature_grad_mult: float = 1.0  # multiply feature extractor var grads by this

        self.normalize: bool = False  # normalize input to have 0 mean and unit variance during training

        # dropouts
        self.dropout: float = 0.1  # dropout probability for the transformer
        self.attention_dropout: float = 0.1  # dropout probability for attention weights
        self.activation_dropout: float = 0.0  # dropout probability after activation in FFN
        self.encoder_layerdrop: float = 0.0  # probability of dropping a transformer layer
        self.dropout_input: float = 0.0  # dropout to apply to the input (after feat extr)
        self.dropout_features: float = 0.0  # dropout to apply to the features (after feat extr)

        # masking
        self.mask_length: int = 10  # mask length
        self.mask_prob: float = 0.65  # probability of replacing a token with mask
        self.mask_selection: str = "static"  # how to choose mask length
        self.mask_other: float = (
            0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
        )
        self.no_mask_overlap: bool = False  # whether to allow masks to overlap
        self.mask_min_space: int = 1  # min space between spans (if no overlap is enabled)

        # channel masking
        self.mask_channel_length: int = 10  # length of the mask for features (channels)
        self.mask_channel_prob: float = 0.0  # probability of replacing a feature with 0
        self.mask_channel_selection: str = "static"  # how to choose mask length for channel masking
        self.mask_channel_other: float = (
            0  # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
        )
        self.no_mask_channel_overlap: bool = False  # whether to allow channel masks to overlap
        self.mask_channel_min_space: int = 1  # min space between spans (if no overlap is enabled)

        # positional embeddings
        self.conv_pos: int = 128  # number of filters for convolutional positional embeddings
        self.conv_pos_groups: int = 16  # number of groups for convolutional positional embedding

        # relative position embedding
        self.relative_position_embedding: bool = False  # apply relative position embedding
        self.num_buckets: int = 320  # number of buckets for relative position embedding
        self.max_distance: int = 1280  # maximum distance for relative position embedding
        self.gru_rel_pos: bool = False  # apply gated relative position embedding

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        self.__dict__.update(cfg)


class WavLM(nn.Module):
    def __init__(
        self,
        cfg: WavLMConfig,
    ) -> None:
        super().__init__()
        logger.info(f"WavLM Config: {cfg.__dict__}")

        self.cfg = cfg
        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

    def apply_mask(self, x, padding_mask):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        # padding_mask = padding_mask.all(-1)
        padding_mask = padding_mask.any(-1)
        return padding_mask

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
        ret_layer_results: bool = False,
    ):
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)

        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask)
        else:
            x = features

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, layer_results = self.encoder(
            x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1
        )

        res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}

        feature = res["features"] if ret_conv else res["x"]
        if ret_layer_results:
            feature = (feature, res["layer_results"])
        return feature, res["padding_mask"]


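# A minimal, illustrative sketch of running the encoder above (the "cfg"/"model" checkpoint keys
# and the checkpoint path are assumptions, not guaranteed by this module):
#
#     ckpt = torch.load("wavlm.pt", map_location="cpu")  # hypothetical checkpoint file
#     cfg = WavLMConfig(ckpt["cfg"])
#     model = WavLM(cfg)
#     model.load_state_dict(ckpt["model"])
#     model.eval()
#     wav = torch.randn(1, 16000)  # one second of 16 kHz audio
#     with torch.no_grad():
#         features, _ = model.extract_features(wav)  # (1, T', cfg.encoder_embed_dim)

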
class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
        conv_type: str = "default",
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert (is_layer_norm and is_group_norm) == False, "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(dim, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.conv_type = conv_type
        if self.conv_type == "default":
            in_d = 1
            self.conv_layers = nn.ModuleList()
            for i, cl in enumerate(conv_layers):
                assert len(cl) == 3, "invalid conv definition: " + str(cl)
                (dim, k, stride) = cl

                self.conv_layers.append(
                    block(
                        in_d,
                        dim,
                        k,
                        stride,
                        is_layer_norm=mode == "layer_norm",
                        is_group_norm=mode == "default" and i == 0,
                        conv_bias=conv_bias,
                    )
                )
                in_d = dim
        elif self.conv_type == "conv2d":
            in_d = 1
            self.conv_layers = nn.ModuleList()
            for i, cl in enumerate(conv_layers):
                assert len(cl) == 3
                (dim, k, stride) = cl

                self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride))
                self.conv_layers.append(torch.nn.ReLU())
                in_d = dim
        elif self.conv_type == "custom":
            in_d = 1
            idim = 80
            self.conv_layers = nn.ModuleList()
            for i, cl in enumerate(conv_layers):
                assert len(cl) == 3
                (dim, k, stride) = cl
                self.conv_layers.append(torch.nn.Conv2d(in_d, dim, k, stride, padding=1))
                self.conv_layers.append(torch.nn.LayerNorm([dim, idim]))
                self.conv_layers.append(torch.nn.ReLU())
                in_d = dim
                if (i + 1) % 2 == 0:
                    self.conv_layers.append(torch.nn.MaxPool2d(2, stride=2, ceil_mode=True))
                    idim = int(math.ceil(idim / 2))
        else:
            pass

    def forward(self, x, mask=None):
        # BxT -> BxCxT
        x = x.unsqueeze(1)
        if self.conv_type == "custom":
            for conv in self.conv_layers:
                if isinstance(conv, nn.LayerNorm):
                    x = x.transpose(1, 2)
                    x = conv(x).transpose(1, 2)
                else:
                    x = conv(x)
            x = x.transpose(2, 3).contiguous()
            x = x.view(x.size(0), -1, x.size(-1))
        else:
            for conv in self.conv_layers:
                x = conv(x)
            if self.conv_type == "conv2d":
                b, c, t, f = x.size()
                x = x.transpose(2, 3).contiguous().view(b, c * f, t)
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    has_relative_attention_bias=(self.relative_position_embedding and i == 0),
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                )
                for i in range(args.encoder_layers)
            ]
        )

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x += x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    self_attn_mask=streaming_mask,
                    pos_bias=pos_bias,
                )
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
    ) -> None:
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        pos_bias=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias

@@ -18,6 +18,8 @@ class BaseVocoder(BaseTrainerModel):
         - 1D tensors `batch x 1`
     """
 
+    MODEL_TYPE = "vocoder"
+
     def __init__(self, config):
         super().__init__()
         self._set_model_args(config)
@@ -36,8 +36,8 @@ Run a tts and a vocoder model from the released model list. Note that not every
 
 ```bash
 tts --text "Text for TTS" \
-    --model_name "<type>/<language>/<dataset>/<model_name>" \
-    --vocoder_name "<type>/<language>/<dataset>/<model_name>" \
+    --model_name "tts_models/<language>/<dataset>/<model_name>" \
+    --vocoder_name "vocoder_models/<language>/<dataset>/<model_name>" \
     --out_path folder/to/save/output.wav
 ```
 
@@ -64,8 +64,17 @@ tts --text "Text for TTS" \
 Run a multi-speaker TTS model from the released models list.
 
 ```bash
-tts --model_name "<type>/<language>/<dataset>/<model_name>" --list_speaker_idxs  # list the possible speaker IDs.
-tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx "<speaker_id>"
+tts --model_name "tts_models/<language>/<dataset>/<model_name>" --list_speaker_idxs  # list the possible speaker IDs.
+tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "tts_models/<language>/<dataset>/<model_name>" --speaker_idx "<speaker_id>"
 ```
+
+Run a released voice conversion model.
+
+```bash
+tts --model_name "voice_conversion_models/<language>/<dataset>/<model_name>" \
+    --source_wav "my/source/speaker/audio.wav" \
+    --target_wav "my/target/speaker/audio.wav" \
+    --out_path folder/to/save/output.wav
+```
 
 **Note:** You can use ```./TTS/bin/synthesize.py``` if you prefer running ```tts``` from the TTS project folder.
@@ -135,4 +144,23 @@ tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_
 tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
 tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="output.wav")
 tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="output.wav")
 ```
+
+Example of voice conversion, converting the speaker of `source_wav` to the speaker of `target_wav`:
+
+```python
+tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False, gpu=True)
+tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
+```
+
+Example of voice cloning by combining a single-speaker TTS model with the voice conversion model. This way, you can
+clone voices by using any model in 🐸TTS.
+
+```python
+tts = TTS("tts_models/de/thorsten/tacotron2-DDC")
+tts.tts_with_vc_to_file(
+    "Wie sage ich auf Italienisch, dass ich dich liebe?",
+    speaker_wav="target/speaker.wav",
+    file_path="output.wav"
+)
+```
@@ -14,6 +14,7 @@ tqdm
 anyascii
 pyyaml
 fsspec>=2021.04.0
+aiohttp
 packaging
 # deps for examples
 flask
@@ -28,7 +28,7 @@ class TTSTest(unittest.TestCase):
 
     def test_multi_speaker_multi_lingual_model(self):
         tts = TTS()
-        tts.load_model_by_name(tts.models[0])  # YourTTS
+        tts.load_tts_model_by_name(tts.models[0])  # YourTTS
         tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path=OUTPUT_PATH)
 
         self.assertTrue(tts.is_multi_speaker)
@@ -38,5 +38,5 @@ class TTSTest(unittest.TestCase):
 
     def test_voice_cloning(self):  # pylint: disable=no-self-use
         tts = TTS()
-        tts.load_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
+        tts.load_tts_model_by_name("tts_models/multilingual/multi-dataset/your_tts")
         tts.tts_to_file("Hello world!", speaker_wav=cloning_test_wav_path, language="en", file_path=OUTPUT_PATH)
@@ -0,0 +1,135 @@
import os
import unittest

import torch

from tests import get_tests_input_path
from TTS.vc.configs.freevc_config import FreeVCConfig
from TTS.vc.models.freevc import FreeVC

# pylint: disable=unused-variable
# pylint: disable=no-self-use

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

c = FreeVCConfig()

WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
BATCH_SIZE = 3


def count_parameters(model):
    r"""Count number of trainable parameters in a network"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class TestFreeVC(unittest.TestCase):
    def _create_inputs(self, config, batch_size=2):
        input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)
        input_lengths = torch.randint(100, 30 * config.audio["hop_length"], (batch_size,)).long().to(device)
        input_lengths[-1] = 30 * config.audio["hop_length"]
        spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device)
        mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device)
        spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
        spec_lengths[-1] = spec.size(2)
        waveform = torch.rand(batch_size, spec.size(2) * config.audio["hop_length"]).to(device)
        return input_dummy, input_lengths, mel, spec, spec_lengths, waveform

    @staticmethod
    def _create_inputs_inference():
        source_wav = torch.rand(16000)
        target_wav = torch.rand(16000)
        return source_wav, target_wav

    @staticmethod
    def _check_parameter_changes(model, model_ref):
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1

    def test_methods(self):
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.load_pretrained_speaker_encoder()
        model.init_multispeaker(config)
        wavlm_feats = model.extract_wavlm_features(torch.rand(1, 16000))
        assert wavlm_feats.shape == (1, 1024, 49), wavlm_feats.shape

    def test_load_audio(self):
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        wav = model.load_audio(WAV_FILE)
        wav2 = model.load_audio(wav)
        assert all(torch.isclose(wav, wav2))

    def _test_forward(self, batch_size):
        # create model
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.train()
        print(" > Num parameters for FreeVC model:%s" % (count_parameters(model)))

        _, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)

        wavlm_vec = model.extract_wavlm_features(waveform)
        wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)

        y = model.forward(wavlm_vec, spec, None, mel, spec_lengths, wavlm_vec_lengths)
        # TODO: assert with training implementation

    def test_forward(self):
        self._test_forward(1)
        self._test_forward(3)

    def _test_inference(self, batch_size):
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.eval()

        _, _, mel, _, _, waveform = self._create_inputs(config, batch_size)

        wavlm_vec = model.extract_wavlm_features(waveform)
        wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)

        output_wav = model.inference(wavlm_vec, None, mel, wavlm_vec_lengths)
        assert (
            output_wav.shape[-1] // config.audio.hop_length == wavlm_vec.shape[-1]
        ), f"{output_wav.shape[-1] // config.audio.hop_length} != {wavlm_vec.shape}"

    def test_inference(self):
        self._test_inference(1)
        self._test_inference(3)

    def test_voice_conversion(self):
        config = FreeVCConfig()
        model = FreeVC(config).to(device)
        model.eval()

        source_wav, target_wav = self._create_inputs_inference()
        output_wav = model.voice_conversion(source_wav, target_wav)
        assert (
            output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0]
        ), f"{output_wav.shape} != {source_wav.shape}"

    def test_train_step(self):
        ...

    def test_train_eval_log(self):
        ...

    def test_test_run(self):
        ...

    def test_load_checkpoint(self):
        ...

    def test_get_criterion(self):
        ...

    def test_init_from_config(self):
        ...
@@ -51,6 +51,13 @@ def run_models(offset=0, step=1):
             # remove downloaded models
             shutil.rmtree(local_download_dir)
             shutil.rmtree(get_user_data_dir("tts"))
+        elif "voice_conversion_models" in model_name:
+            speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
+            reference_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0032.wav")
+            run_cli(
+                f"tts --model_name {model_name} "
+                f'--out_path "{output_path}" --source_wav "{speaker_wav}" --target_wav "{reference_wav}" --progress_bar False'
+            )
         else:
             # only download the model
             manager.download_model(model_name)