fix tests

This commit is contained in:
manmay-nakhashi 2023-05-02 00:51:29 +05:30
parent fede89ac0d
commit 49feaf5fa1
3 changed files with 11 additions and 12 deletions

View File

@ -306,6 +306,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
encoder_config_path = None encoder_config_path = None
vc_path = None vc_path = None
vc_config_path = None vc_config_path = None
model_dir = None
# CASE1 #list : list pre-trained TTS models # CASE1 #list : list pre-trained TTS models
if args.list_models: if args.list_models:
@ -336,6 +337,9 @@ If you don't specify any models, then it uses LJSpeech based English model.
if args.model_name is not None and not args.model_path: if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name) model_path, config_path, model_item = manager.download_model(args.model_name)
# tortoise model
if model_path.split("--")[-1] == "tortoise-v2":
model_dir = model_path
# tts model # tts model
if model_item["model_type"] == "tts_models": if model_item["model_type"] == "tts_models":
tts_path = model_path tts_path = model_path
@ -379,6 +383,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
encoder_config_path, encoder_config_path,
vc_path, vc_path,
vc_config_path, vc_config_path,
model_dir,
args.use_cuda, args.use_cuda,
) )

View File

@ -1,7 +1,7 @@
import os import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass
from time import time from time import time
import torch import torch
@ -392,11 +392,6 @@ class Tortoise(BaseTTS):
yield m yield m
m = model.cpu() m = model.cpu()
def speaker_manager(
self,
):
self.speaker_names = os.listdir()
def get_conditioning_latents( def get_conditioning_latents(
self, self,
voice_samples, voice_samples,
@ -804,17 +799,17 @@ class Tortoise(BaseTTS):
} }
return return_dict return return_dict
def forward(): def forward(self):
raise NotImplementedError("Tortoise Training is not implemented") raise NotImplementedError("Tortoise Training is not implemented")
def eval_step(): def eval_step(self):
raise NotImplementedError("Tortoise Training is not implemented") raise NotImplementedError("Tortoise Training is not implemented")
def init_from_config(): def init_from_config(self):
raise NotImplementedError("Tortoise Training is not implemented") raise NotImplementedError("Tortoise Training is not implemented")
def load_checkpoint(): def load_checkpoint(self):
raise NotImplementedError("Tortoise Training is not implemented") raise NotImplementedError("Tortoise Training is not implemented")
def train_step(): def train_step(self):
raise NotImplementedError("Tortoise Training is not implemented") raise NotImplementedError("Tortoise Training is not implemented")

View File

@ -299,7 +299,6 @@ class ModelManager(object):
model_file = None model_file = None
config_file = None config_file = None
for file_name in os.listdir(output_path): for file_name in os.listdir(output_path):
print(file_name)
if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
model_file = os.path.join(output_path, file_name) model_file = os.path.join(output_path, file_name)
elif file_name == "config.json": elif file_name == "config.json":