From 5255e089e6b7502c9ffb9f0a4ad462796c3f918f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 30 Aug 2021 13:08:05 +0000 Subject: [PATCH] Fix #767 --- TTS/utils/io.py | 2 +- TTS/utils/manage.py | 60 +++++++++++++++++++++++++--------- TTS/vocoder/models/wavegrad.py | 2 +- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 4d75e7b0..dd4ffd60 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -3,7 +3,7 @@ import json import os import pickle as pickle_tts import shutil -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, Union import fsspec import torch diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 1d61d392..4a45fb2d 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -110,7 +110,7 @@ class ModelManager(object): os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") output_stats_path = os.path.join(output_path, "scale_stats.npy") - output_speakers_path = os.path.join(output_path, "speakers.json") + # download files to the output path if self._check_dict_key(model_item, "github_rls_url"): # download from github release @@ -122,22 +122,52 @@ class ModelManager(object): if self._check_dict_key(model_item, "stats_file"): self._download_gdrive_file(model_item["stats_file"], output_stats_path) - # update the scale_path.npy file path in the model config.json - if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path): - # set scale stats path in config.json - config_path = output_config_path - config = load_config(config_path) - config.audio.stats_path = output_stats_path - config.save_json(config_path) - # update the speakers.json file path in the model config.json to the current path - if os.path.exists(output_speakers_path): - # set scale stats path in config.json - config_path = output_config_path - config = load_config(config_path) - config.d_vector_file = 
output_speakers_path - config.save_json(config_path) + # update paths in the config.json + self._update_paths(output_path, output_config_path) return output_model_path, output_config_path, model_item + def _update_paths(self, output_path: str, config_path: str) -> None: + """Update paths for certain files in config.json after download. + + Args: + output_path (str): local path the model is downloaded to. + config_path (str): local config.json path. + """ + output_stats_path = os.path.join(output_path, "scale_stats.npy") + output_d_vector_file_path = os.path.join(output_path, "speakers.json") + output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json") + + # update the scale_stats.npy file path in the model config.json + self._update_path("audio.stats_path", output_stats_path, config_path) + + # update the speakers.json file path in the model config.json to the current path + self._update_path("d_vector_file", output_d_vector_file_path, config_path) + self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path) + + # update the speaker_ids.json file path in the model config.json to the current path + self._update_path("speakers_file", output_speaker_ids_file_path, config_path) + self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path) + + @staticmethod + def _update_path(field_name, new_path, config_path): + """Update the path in the model config.json for the current environment after download""" + if os.path.exists(new_path): + config = load_config(config_path) + field_names = field_name.split(".") + if len(field_names) > 1: + # field name points to a sub-level field + sub_conf = config + for fd in field_names[:-1]: + if fd in sub_conf: + sub_conf = sub_conf[fd] + else: + return + sub_conf[field_names[-1]] = new_path + else: + # field name points to a top-level field + config[field_name] = new_path + config.save_json(config_path) + def _download_gdrive_file(self, gdrive_idx, output): 
"""Download files from GDrive using their file ids""" gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 2a76baa5..8d95a063 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -12,9 +12,9 @@ from torch.utils.data.distributed import DistributedSampler from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec from TTS.utils.trainer_utils import get_optimizer, get_scheduler -from TTS.vocoder.base_vocoder import BaseVocoder from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock +from TTS.vocoder.models.base_vocoder import BaseVocoder from TTS.vocoder.utils.generic_utils import plot_results