mirror of https://github.com/coqui-ai/TTS.git
Fix #767
This commit is contained in:
parent
c560114324
commit
5255e089e6
|
@ -3,7 +3,7 @@ import json
|
|||
import os
|
||||
import pickle as pickle_tts
|
||||
import shutil
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
|
|
|
@ -110,7 +110,7 @@ class ModelManager(object):
|
|||
os.makedirs(output_path, exist_ok=True)
|
||||
print(f" > Downloading model to {output_path}")
|
||||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
||||
output_speakers_path = os.path.join(output_path, "speakers.json")
|
||||
|
||||
# download files to the output path
|
||||
if self._check_dict_key(model_item, "github_rls_url"):
|
||||
# download from github release
|
||||
|
@ -122,22 +122,52 @@ class ModelManager(object):
|
|||
if self._check_dict_key(model_item, "stats_file"):
|
||||
self._download_gdrive_file(model_item["stats_file"], output_stats_path)
|
||||
|
||||
# update the scale_path.npy file path in the model config.json
|
||||
if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path):
|
||||
# set scale stats path in config.json
|
||||
config_path = output_config_path
|
||||
config = load_config(config_path)
|
||||
config.audio.stats_path = output_stats_path
|
||||
config.save_json(config_path)
|
||||
# update the speakers.json file path in the model config.json to the current path
|
||||
if os.path.exists(output_speakers_path):
|
||||
# set scale stats path in config.json
|
||||
config_path = output_config_path
|
||||
config = load_config(config_path)
|
||||
config.d_vector_file = output_speakers_path
|
||||
config.save_json(config_path)
|
||||
# update paths in the config.json
|
||||
self._update_paths(output_path, output_config_path)
|
||||
return output_model_path, output_config_path, model_item
|
||||
|
||||
def _update_paths(self, output_path: str, config_path: str) -> None:
|
||||
"""Update paths for certain files in config.json after download.
|
||||
|
||||
Args:
|
||||
output_path (str): local path the model is downloaded to.
|
||||
config_path (str): local config.json path.
|
||||
"""
|
||||
output_stats_path = os.path.join(output_path, "scale_stats.npy")
|
||||
output_d_vector_file_path = os.path.join(output_path, "speakers.json")
|
||||
output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
|
||||
|
||||
# update the scale_path.npy file path in the model config.json
|
||||
self._update_path("audio.stats_path", output_stats_path, config_path)
|
||||
|
||||
# update the speakers.json file path in the model config.json to the current path
|
||||
self._update_path("d_vector_file", output_d_vector_file_path, config_path)
|
||||
self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
|
||||
|
||||
# update the speaker_ids.json file path in the model config.json to the current path
|
||||
self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
|
||||
self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
|
||||
|
||||
@staticmethod
def _update_path(field_name, new_path, config_path):
    """Write ``new_path`` into one field of the config stored at ``config_path``.

    Nothing happens unless ``new_path`` exists on disk. ``field_name`` may be
    a dotted name (e.g. ``"audio.stats_path"``); every part but the last is
    looked up level by level in the loaded config. If one of those
    intermediate parts is missing, the function returns without saving.

    Args:
        field_name: config field to set; dots separate nesting levels.
        new_path: path to store in the field; must exist for the update to run.
        config_path: path of the config file to load and save back.
    """
    if not os.path.exists(new_path):
        return

    config = load_config(config_path)

    # The last component is the field to assign; the preceding ones (none for
    # a top-level field) name the nested sections leading to it.
    *parent_keys, leaf_key = field_name.split(".")

    target = config
    for key in parent_keys:
        if key not in target:
            # the config has no such section: leave the file untouched
            return
        target = target[key]

    target[leaf_key] = new_path
    config.save_json(config_path)
|
||||
|
||||
def _download_gdrive_file(self, gdrive_idx, output):
    """Download a file from GDrive using its file id.

    The download URL is ``self.url_prefix`` followed by ``gdrive_idx``; the
    transfer itself is delegated to ``gdown.download`` with ``quiet=False``.

    Args:
        gdrive_idx: GDrive file id appended to ``self.url_prefix``.
        output: destination passed to ``gdown.download`` as ``output``.
    """
    download_url = f"{self.url_prefix}{gdrive_idx}"
    gdown.download(download_url, output=output, quiet=False)
|
||||
|
|
|
@ -12,9 +12,9 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.trainer_utils import get_optimizer, get_scheduler
|
||||
from TTS.vocoder.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.datasets import WaveGradDataset
|
||||
from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue