# Spaces: Build error (page status banner captured with the script; not part of the code)
from __future__ import annotations

import argparse
import os

import torch
from safetensors.torch import save_file
# Default input checkpoint. This is a placeholder path: edit it or pass --ckpt.
ckpt = "path_to/pytorch_model.bin"

# Default output file, written in safetensors format; override with --output.
_DEFAULT_OUTPUT = "ckpts/vista.safetensors"

# LoRA adapters that are folded into a base weight, checked in this order:
# (down-projection marker, up-projection marker, base-weight name).
_QKV_ADAPTERS = (
    ("q_adapter_down", "q_adapter_up", "to_q"),
    ("k_adapter_down", "k_adapter_up", "to_k"),
    ("v_adapter_down", "v_adapter_up", "to_v"),
)


def _lora_partner_keys(down_key: str) -> tuple[str, str]:
    """Return ``(up_key, base_key)`` for a LoRA down-projection key.

    Raises:
        ValueError: if ``down_key`` contains none of the known adapter markers.
    """
    for down, up, base in _QKV_ADAPTERS:
        if down in down_key:
            return down_key.replace(down, up), down_key.replace(down, base)
    if "out_adapter_down" in down_key:
        # EMA entries are keyed by the dot-free parameter name (see
        # _combine_ema_weights), so their base weight is "to_out0".
        base = "to_out0" if "model_ema" in down_key else "to_out.0"
        return (
            down_key.replace("out_adapter_down", "out_adapter_up"),
            down_key.replace("out_adapter_down", base),
        )
    raise ValueError(f"Unrecognised LoRA adapter key: {down_key!r}")


def _merge_lora_weights(state: dict[str, torch.Tensor]) -> None:
    """Fold every LoRA adapter pair into its base weight, in place.

    For each key containing ``adapter_down`` the product ``up @ down`` is added
    to the matching base weight and both adapter entries are removed. A state
    dict without adapter entries is left untouched.

    Raises:
        KeyError: if the up-projection or the base weight is missing.
        ValueError: if an adapter key matches no known projection.
    """
    for key in list(state):  # snapshot: entries are deleted while looping
        if "adapter_down" not in key:
            continue
        up_key, base_key = _lora_partner_keys(key)
        delta = state[up_key] @ state[key]
        del state[key]
        del state[up_key]
        state[base_key] = state[base_key] + delta


def _strip_wrapper_prefix(state: dict[str, torch.Tensor]) -> None:
    """Remove the ``_forward_module.`` prefix from the keys, in place.

    Every original entry is deleted. Only keys that contain ``_forward_module``
    and mention neither ``decay`` nor ``num_updates`` are re-inserted under the
    stripped name; all other entries are discarded.
    """
    for key in list(state):  # snapshot: the dict is rewritten while looping
        if (
            "_forward_module" in key
            and "decay" not in key
            and "num_updates" not in key
        ):
            state[key.replace("_forward_module.", "")] = state[key]
        # NOTE(review): the delete is unconditional, so entries whose key
        # mentions decay / num_updates are dropped rather than kept -- confirm
        # against the upstream script.
        del state[key]


def _combine_ema_weights(state: dict[str, torch.Tensor]) -> None:
    """Overwrite regular weights with their EMA counterparts, in place.

    An EMA entry is matched to the regular entry whose key, after its first
    six characters and with all dots removed, equals the EMA key after its
    first ten characters. The regular entry takes the EMA tensor and the EMA
    entry is removed.

    Raises:
        KeyError: if an EMA entry has no matching regular entry.
    """
    # Built once instead of rescanning all keys per EMA entry. The loop below
    # never adds or removes a regular key, so the table stays valid. When two
    # regular keys collapse to the same dot-free name the later one wins, as
    # in a full scan without an early exit.
    # Slice offsets presumably skip "model." (6 chars) -- TODO confirm.
    regular_by_flat_name = {
        key[6:].replace(".", ""): key for key in state if "model_ema" not in key
    }
    for ema_key in list(state):  # snapshot: EMA entries are deleted below
        if "model_ema" not in ema_key:
            continue
        # Presumably skips "model_ema." (10 chars) -- TODO confirm.
        target_key = regular_by_flat_name.get(ema_key[10:])
        if target_key is None:
            # Explicit raise: an assert would vanish under ``python -O``.
            raise KeyError(f"No regular weight matches EMA entry {ema_key!r}")
        state[target_key] = state[ema_key]
        del state[ema_key]
        print("Replace", target_key, "with", ema_key)


def convert_state_dict(state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Prepare a loaded checkpoint for inference.

    Runs, in order: LoRA merging, prefix stripping, EMA substitution. The
    order matters because the later steps rely on the key names the earlier
    ones produce.

    Args:
        state: mapping from parameter name to tensor; modified in place.

    Returns:
        The same ``state`` object, for convenience.

    Raises:
        KeyError: if a LoRA partner, base weight, or EMA target is missing.
        ValueError: if an adapter key matches no known projection.
    """
    _merge_lora_weights(state)
    _strip_wrapper_prefix(state)
    _combine_ema_weights(state)
    return state


def main(argv: list[str] | None = None) -> None:
    """Convert a ``.bin`` checkpoint into a safetensors file.

    Args:
        argv: command-line arguments; ``None`` reads ``sys.argv``. With no
            arguments the script reads ``ckpt`` and writes
            ``ckpts/vista.safetensors``.
    """
    parser = argparse.ArgumentParser(
        description="Merge LoRA/EMA weights and export a checkpoint to safetensors."
    )
    parser.add_argument("--ckpt", default=ckpt, help="input checkpoint (.bin)")
    parser.add_argument(
        "--output", default=_DEFAULT_OUTPUT, help="output .safetensors file"
    )
    args = parser.parse_args(argv)

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only run this on checkpoints you trust (or consider passing
    # weights_only=True on torch versions that support it).
    vista_bin = torch.load(args.ckpt, map_location="cpu")  # only contains model weights
    convert_state_dict(vista_bin)

    # Hand save_file a plain dict holding exactly the converted entries.
    vista_st = dict(vista_bin)

    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    save_file(vista_st, args.output)


if __name__ == "__main__":
    main()