mirror of https://github.com/coqui-ai/TTS.git
parent
843d1b3d98
commit
715b0a65a0
|
@ -30,6 +30,7 @@ jobs:
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
|
architecture: x64
|
||||||
- name: check OS
|
- name: check OS
|
||||||
run: cat /etc/os-release
|
run: cat /etc/os-release
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
|
|
|
@ -41,7 +41,7 @@ def main():
|
||||||
if args.data_path:
|
if args.data_path:
|
||||||
dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True)
|
dataset_items = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True)
|
||||||
else:
|
else:
|
||||||
dataset_items = load_meta_data(CONFIG.dataset_config)[0] # take only train data
|
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
|
||||||
print(f" > There are {len(dataset_items)} files.")
|
print(f" > There are {len(dataset_items)} files.")
|
||||||
|
|
||||||
mel_sum = 0
|
mel_sum = 0
|
||||||
|
|
|
@ -14,7 +14,7 @@ from TTS.tts.datasets.TTSDataset import MyDataset
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.speakers import parse_speakers
|
from TTS.tts.utils.speakers import parse_speakers
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
from TTS.utils.io import load_config
|
from TTS.config import load_config
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import count_parameters
|
from TTS.utils.generic_utils import count_parameters
|
||||||
|
|
||||||
|
@ -210,7 +210,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
# Audio processor
|
# Audio processor
|
||||||
ap = AudioProcessor(**c.audio)
|
ap = AudioProcessor(**c.audio)
|
||||||
if "characters" in c.keys():
|
if "characters" in c.keys() and c['characters']:
|
||||||
symbols, phonemes = make_symbols(**c.characters)
|
symbols, phonemes = make_symbols(**c.characters)
|
||||||
|
|
||||||
# set model characters
|
# set model characters
|
||||||
|
@ -276,5 +276,4 @@ if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
c = load_config(args.config_path)
|
c = load_config(args.config_path)
|
||||||
|
|
||||||
main(args)
|
main(args)
|
||||||
|
|
|
@ -125,17 +125,15 @@ class ModelManager(object):
|
||||||
# set scale stats path in config.json
|
# set scale stats path in config.json
|
||||||
config_path = output_config_path
|
config_path = output_config_path
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
config["audio"]["stats_path"] = output_stats_path
|
config.audio.stats_path = output_stats_path
|
||||||
with open(config_path, "w") as jf:
|
config.save_json(config_path)
|
||||||
json.dump(config, jf)
|
|
||||||
# update the speakers.json file path in the model config.json to the current path
|
# update the speakers.json file path in the model config.json to the current path
|
||||||
if os.path.exists(output_speakers_path):
|
if os.path.exists(output_speakers_path):
|
||||||
# set scale stats path in config.json
|
# set scale stats path in config.json
|
||||||
config_path = output_config_path
|
config_path = output_config_path
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
config["external_speaker_embedding_file"] = output_speakers_path
|
config.external_speaker_embedding_file = output_speakers_path
|
||||||
with open(config_path, "w") as jf:
|
config.save_json(config_path)
|
||||||
json.dump(config, jf)
|
|
||||||
return output_model_path, output_config_path, model_item
|
return output_model_path, output_config_path, model_item
|
||||||
|
|
||||||
def _download_gdrive_file(self, gdrive_idx, output):
|
def _download_gdrive_file(self, gdrive_idx, output):
|
||||||
|
|
|
@ -9,7 +9,7 @@ from tests import get_tests_output_path, run_cli
|
||||||
|
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
|
|
||||||
from TTS.utils.io import load_config
|
from TTS.config import load_config
|
||||||
from TTS.tts.utils.text.symbols import phonemes, symbols
|
from TTS.tts.utils.text.symbols import phonemes, symbols
|
||||||
|
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
|
|
|
@ -33,7 +33,6 @@ command_train = (
|
||||||
"--coqpit.datasets.0.meta_file_train metadata.csv "
|
"--coqpit.datasets.0.meta_file_train metadata.csv "
|
||||||
"--coqpit.datasets.0.meta_file_val metadata.csv "
|
"--coqpit.datasets.0.meta_file_val metadata.csv "
|
||||||
"--coqpit.datasets.0.path tests/data/ljspeech "
|
"--coqpit.datasets.0.path tests/data/ljspeech "
|
||||||
"--coqpit.datasets.0.meta_file_attn_mask tests/data/ljspeech/metadata_attn_mask.txt"
|
|
||||||
)
|
)
|
||||||
run_cli(command_train)
|
run_cli(command_train)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue