Spaces:
Running
Running
import os | |
import re | |
from fairseq import checkpoint_utils | |
def get_index_path_from_model(sid):
    """Return the path of the ``.index`` file that belongs to a model file.

    ``sid`` is reduced to a base model name:

    1. A trailing ``.pth`` or ``.onnx`` extension is removed.
    2. Any directory part is dropped.
    3. A suffix of the form ``_e<digits>_s<digits>`` is stripped.

    The directory tree named by the ``index_root`` environment variable is
    then walked bottom-up (``topdown=False``). The result is the first file
    that meets all of these conditions:

    - its name ends in ``.index``;
    - its name does not contain ``"trained"``;
    - its full path contains the base name.

    Note: the match is a plain substring test on the whole path, so a
    directory component containing the base name also counts as a match.

    Args:
        sid: Model file name or path, e.g. ``"logs/foo_e100_s2000.pth"``.

    Returns:
        The matching index path, or ``""`` when ``index_root`` is unset/empty
        or no index file matches.
    """
    # Anchor both alternatives at the end of the string. Without the group,
    # r'\.pth|\.onnx$' would also delete '.pth' from the middle of a name.
    stripped = re.sub(r"\.(?:pth|onnx)$", "", sid)
    model_name = os.path.basename(stripped)  # name only, not the directory

    # Strip an "_e<digits>_s<digits>" suffix (presumably epoch/step counters
    # of intermediate checkpoints) so they map to the base model's index.
    if re.match(r".+_e\d+_s\d+$", model_name):
        model_name = model_name.rsplit("_", 2)[0]

    index_root = os.getenv("index_root")
    if not index_root:
        # os.walk(None) raises TypeError; an unset root simply means no index.
        return ""

    candidates = (
        os.path.join(root, name)
        for root, _, files in os.walk(index_root, topdown=False)
        for name in files
        if name.endswith(".index") and "trained" not in name
    )
    # Lazy: the directory walk stops as soon as the first match is found.
    return next((path for path in candidates if model_name in path), "")
def load_hubert(config, model_path="assets/hubert/hubert_base.pt"):
    """Load the HuBERT model from a fairseq checkpoint, ready for inference.

    Args:
        config: Object exposing ``device`` (passed to ``.to()``) and
            ``is_half`` (truthy -> cast weights with ``.half()``, otherwise
            ``.float()``).
        model_path: Checkpoint file to load. Defaults to the bundled
            ``assets/hubert/hubert_base.pt``, so existing callers are unchanged.

    Returns:
        The first model of the loaded ensemble, moved to ``config.device``,
        cast to the requested precision and switched to eval mode.
    """
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        [model_path],
        suffix="",
    )
    hubert_model = models[0].to(config.device)
    hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
    return hubert_model.eval()