This commit is contained in:
Eren Gölge 2023-07-08 11:40:44 +02:00 committed by GitHub
parent b5cd644132
commit 672ec3b35e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 4 deletions

View File

@ -121,7 +121,7 @@ class HubertTokenizer(nn.Module):
data_from_model.output_size,
data_from_model.version,
)
model.load_state_dict(torch.load(path))
model.load_state_dict(torch.load(path, map_location=map_location))
if map_location:
model = model.to(map_location)
return model

View File

@ -136,9 +136,7 @@ def generate_voice(
hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
# Load the CustomTokenizer model
tokenizer = HubertTokenizer.load_from_checkpoint(model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]).to(
model.device
) # Automatically uses
tokenizer = HubertTokenizer.load_from_checkpoint(model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device)
# semantic_tokens = model.text_to_semantic(
# text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7
# ) # not 100%