# Spaces: Running on Zero
import argparse
import sys

import torch
from safetensors.torch import load_file as safetensors_load_file
from safetensors.torch import save_file as safetensors_save_file
def load_ckpt(ckpt_filepath):
    """Load a checkpoint from disk and return ``(ckpt, state_dict)``.

    * ``.safetensors`` files hold only a flat name -> tensor mapping, so
      ``ckpt`` is ``None`` and ``state_dict`` is that mapping.
    * Any other file is read with ``torch.load``. If the loaded object is a
      dict with a ``"state_dict"`` key, that inner dict is returned as
      ``state_dict`` and the whole object as ``ckpt``, so its other fields
      can be written back when saving.
    * Otherwise the loaded object itself is treated as the state dict (e.g.
      a file produced by saving ``model.state_dict()`` directly). It is
      returned as both ``ckpt`` and ``state_dict``. They are the same object,
      so in-place edits to ``state_dict`` are visible when ``ckpt`` is saved.

    Args:
        ckpt_filepath: Path to a ``.safetensors`` file or a ``torch.save`` file.

    Returns:
        Tuple ``(ckpt, state_dict)`` as described above.
    """
    print(f"Loading model from {ckpt_filepath}")
    if ckpt_filepath.endswith(".safetensors"):
        return None, safetensors_load_file(ckpt_filepath, device="cpu")

    # SECURITY NOTE: torch.load unpickles the file. Depending on the installed
    # torch version's `weights_only` default, a malicious checkpoint may be
    # able to execute code. Only load files from trusted sources.
    ckpt = torch.load(ckpt_filepath, map_location="cpu")
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        state_dict = ckpt["state_dict"]
    else:
        # No wrapper to unwrap: the file already is the state dict.
        state_dict = ckpt
    return ckpt, state_dict
def save_ckpt(ckpt, ckpt_state_dict, ckpt_filepath):
    """Write a checkpoint to ``ckpt_filepath``, picking the format by extension.

    Args:
        ckpt: The full checkpoint object to pickle, or ``None`` when only the
            state dict is available. Ignored for ``.safetensors`` output.
        ckpt_state_dict: Mapping of parameter name -> tensor. It is written for
            ``.safetensors`` output, and for other outputs when ``ckpt`` is
            ``None``.
        ckpt_filepath: Destination path; a ``.safetensors`` suffix selects the
            safetensors writer, anything else uses ``torch.save``.
    """
    if not ckpt_filepath.endswith(".safetensors"):
        # Prefer the whole checkpoint so its non-weight fields survive.
        payload = ckpt_state_dict if ckpt is None else ckpt
        torch.save(payload, ckpt_filepath)
    else:
        # safetensors holds only the flat tensor mapping; any wrapper is dropped.
        safetensors_save_file(ckpt_state_dict, ckpt_filepath)
    print(f"Saved to {ckpt_filepath}")
def load_two_models(base_ckpt_filepath, vae_ckpt_filepath):
    """Load the base checkpoint and the checkpoint that supplies the VAE.

    The base file is loaded first, then the VAE file.

    Returns:
        Tuple ``(base_ckpt, base_state_dict, vae_state_dict)``. ``base_ckpt``
        is the full object returned by ``load_ckpt`` (``None`` for
        safetensors input). It is kept so its other fields can be written
        back when saving. The VAE checkpoint's wrapper object is discarded;
        only its state dict is returned.
    """
    base_pair = load_ckpt(base_ckpt_filepath)
    vae_state_dict = load_ckpt(vae_ckpt_filepath)[1]
    base_ckpt, base_state_dict = base_pair
    return base_ckpt, base_state_dict, vae_state_dict
def replace_vae_weights(base_state_dict, vae_state_dict, prefix="first_stage_model."):
    """Copy VAE tensors from ``vae_state_dict`` into ``base_state_dict`` in place.

    Every key of ``base_state_dict`` starting with ``prefix`` is looked up in
    ``vae_state_dict``. The lookup first uses the key with that leading
    prefix removed, as in a standalone VAE checkpoint. If that fails, it
    tries the full key, as when a complete model checkpoint donates its VAE.
    The base entry is replaced only if a match is found and the shapes agree.
    Missing keys and shape mismatches are reported and skipped.

    Args:
        base_state_dict: Mapping to update in place.
        vae_state_dict: Mapping providing the replacement tensors.
        prefix: Key prefix marking the VAE parameters inside the base model.

    Returns:
        Number of entries that were replaced.
    """
    repl_count = 0
    # Assigning to existing keys does not resize the dict, so iterating it
    # while updating values is safe.
    for k in base_state_dict:
        if not k.startswith(prefix):
            continue
        # Strip only the leading prefix; str.replace would also remove any
        # later occurrence of the same substring inside the key.
        k2 = k[len(prefix):]
        if k2 in vae_state_dict:
            src = vae_state_dict[k2]
        elif k in vae_state_dict:
            src = vae_state_dict[k]
        else:
            print(f"!!!! '{k2}' not in VAE checkpoint")
            continue
        if base_state_dict[k].shape != src.shape:
            print(f"!!!! '{k}' shape mismatch: {base_state_dict[k].shape} vs {src.shape} !!!!")
            continue
        print(k)
        base_state_dict[k] = src
        repl_count += 1
    return repl_count


def main(argv=None):
    """Command-line entry point: splice VAE weights into a base checkpoint.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    If no parameter was replaced, nothing is saved. The process exits with a
    non-zero status and the error on stderr, so callers can detect failure.
    """
    parser = argparse.ArgumentParser(
        description="Replace the first_stage_model (VAE) weights of a base checkpoint."
    )
    parser.add_argument("--base_ckpt", type=str, required=True, help="Path to the base checkpoint")
    parser.add_argument("--vae_ckpt", type=str, required=True, help="Path to the checkpoint providing vae")
    parser.add_argument("--out_ckpt", type=str, required=True, help="Path to the output checkpoint")
    args = parser.parse_args(argv)

    base_ckpt, base_state_dict, vae_state_dict = load_two_models(args.base_ckpt, args.vae_ckpt)
    repl_count = replace_vae_weights(base_state_dict, vae_state_dict)
    if repl_count == 0:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit("ERROR: No parameter replaced")

    print(f"{repl_count} parameters replaced")
    save_ckpt(base_ckpt, base_state_dict, args.out_ckpt)


if __name__ == "__main__":
    main()