mirror of https://github.com/coqui-ai/TTS.git
parent
b5cd644132
commit
672ec3b35e
|
@ -121,7 +121,7 @@ class HubertTokenizer(nn.Module):
|
||||||
data_from_model.output_size,
|
data_from_model.output_size,
|
||||||
data_from_model.version,
|
data_from_model.version,
|
||||||
)
|
)
|
||||||
model.load_state_dict(torch.load(path))
|
model.load_state_dict(torch.load(path, map_location=map_location))
|
||||||
if map_location:
|
if map_location:
|
||||||
model = model.to(map_location)
|
model = model.to(map_location)
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -136,9 +136,7 @@ def generate_voice(
|
||||||
hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
|
hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
|
||||||
|
|
||||||
# Load the CustomTokenizer model
|
# Load the CustomTokenizer model
|
||||||
tokenizer = HubertTokenizer.load_from_checkpoint(model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]).to(
|
tokenizer = HubertTokenizer.load_from_checkpoint(model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device)
|
||||||
model.device
|
|
||||||
) # Automatically uses
|
|
||||||
# semantic_tokens = model.text_to_semantic(
|
# semantic_tokens = model.text_to_semantic(
|
||||||
# text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7
|
# text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7
|
||||||
# ) # not 100%
|
# ) # not 100%
|
||||||
|
|
Loading…
Reference in New Issue