This commit is contained in:
Eren Gölge 2023-07-08 11:40:44 +02:00 committed by GitHub
parent b5cd644132
commit 672ec3b35e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 4 deletions

View File

@ -121,7 +121,7 @@ class HubertTokenizer(nn.Module):
data_from_model.output_size, data_from_model.output_size,
data_from_model.version, data_from_model.version,
) )
model.load_state_dict(torch.load(path)) model.load_state_dict(torch.load(path, map_location=map_location))
if map_location: if map_location:
model = model.to(map_location) model = model.to(map_location)
return model return model

View File

@ -136,9 +136,7 @@ def generate_voice(
hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device) hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
# Load the CustomTokenizer model # Load the CustomTokenizer model
tokenizer = HubertTokenizer.load_from_checkpoint(model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]).to( tokenizer = HubertTokenizer.load_from_checkpoint(model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device)
model.device
) # Automatically uses
# semantic_tokens = model.text_to_semantic( # semantic_tokens = model.text_to_semantic(
# text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7 # text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7
# ) # not 100% # ) # not 100%