Spaces:
Sleeping
Sleeping
File size: 5,370 Bytes
02cc20b a29cf91 ad88a0b 02cc20b a29cf91 02cc20b a29cf91 02cc20b b0b5a77 02cc20b ad88a0b 02cc20b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from diffusers import AutoencoderKL, DDIMScheduler
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from animatediff.models.unet import UNet3DConditionModel
from omegaconf import OmegaConf
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import load_weights
from safetensors import safe_open
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from faceadapter.face_adapter import FaceAdapterPlusForVideoLora
# Style name -> single-file base ("dreambooth") checkpoint used by load_model().
# load_model() runs whichever file is selected through the convert_ldm_* helpers
# (VAE, UNet and CLIP text encoder), so every entry is expected to be an
# LDM-layout .safetensors checkpoint, not a diffusers folder.
model_style_type2base_model_path = dict(
    realistic="models/rv51/realisticVisionV51_v51VAE_dste8.safetensors",
    anime="models/aingdiffusion/aingdiffusion_v170_ar.safetensors",
    photorealistic="models/sar/sar.safetensors",
)
def _build_base_pipeline(inference_config, sd_version, device):
    """Load the fp16 SD1.5 components from *sd_version* and assemble an AnimationPipeline on *device*."""
    tokenizer = CLIPTokenizer.from_pretrained(
        sd_version, subfolder="tokenizer", torch_dtype=torch.float16,
    )
    text_encoder = CLIPTextModel.from_pretrained(
        sd_version, subfolder="text_encoder", torch_dtype=torch.float16,
    ).to(device=device)
    vae = AutoencoderKL.from_pretrained(
        sd_version, subfolder="vae", torch_dtype=torch.float16,
    ).to(device=device)
    unet = UNet3DConditionModel.from_pretrained_2d(
        sd_version,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).to(device=device)
    # unet.to(dtype=torch.float16) does not work on hf spaces.
    unet = unet.half()

    # DDIM configured from the config's noise_scheduler_kwargs section.
    # (DPMSolverMultistepScheduler / EulerAncestralDiscreteScheduler were tried here before.)
    scheduler = DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs))
    return AnimationPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        controlnet=None,
        scheduler=scheduler,
        torch_dtype=torch.float16,
    ).to(device=device)


def _rename_vae_attention_keys(vae_state_dict):
    """Rename VAE mid-block attention keys in place to the to_q/to_k/to_v/to_out.0 naming; return the dict."""
    # First matching substring wins, in this order.
    renames = (
        ("key", "to_k"),
        ("query", "to_q"),
        ("value", "to_v"),
        ("proj_attn", "to_out.0"),
    )
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for old_key in list(vae_state_dict):
        if "encoder.mid_block.attentions" not in old_key and "decoder.mid_block.attentions" not in old_key:
            continue
        for legacy, modern in renames:
            if legacy in old_key:
                vae_state_dict[old_key.replace(legacy, modern)] = vae_state_dict.pop(old_key)
                break
    return vae_state_dict


def _load_dreambooth_checkpoint(pipeline, base_model_path, device):
    """Overwrite the pipeline's VAE, UNet and text encoder with weights from an LDM-layout safetensors file."""
    print(f"load dreambooth model from {base_model_path}")
    with safe_open(base_model_path, framework="pt", device="cpu") as f:
        dreambooth_state_dict = {key: f.get_tensor(key) for key in f.keys()}

    converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, pipeline.vae.config)
    pipeline.vae.load_state_dict(_rename_vae_attention_keys(converted_vae_checkpoint))

    converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, pipeline.unet.config)
    # strict=False: presumably because the 3D UNet carries motion-module parameters
    # that a 2D image checkpoint does not contain -- TODO confirm.
    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)

    pipeline.text_encoder = convert_ldm_clip_checkpoint(
        dreambooth_state_dict, dtype=torch.float16,
    ).to(device=device)
    # Returning drops the raw and converted state dicts (and the replaced text
    # encoder) instead of keeping them alive for the rest of load_model.


def load_model(model_style_type: str = "realistic", device: str = "cuda"):
    """Build the face-adapter animation model for the requested base-model style.

    Steps:
      1. Load the SD1.5 components from ``animatediff/sd`` in fp16.
      2. Attach the v3 motion module and the domain-adapter LoRA via ``load_weights``.
      3. Replace the VAE / UNet / CLIP weights with the style's single-file checkpoint.
      4. Wrap the fp16 pipeline in ``FaceAdapterPlusForVideoLora``.

    Args:
        model_style_type: key into ``model_style_type2base_model_path``.
        device: torch device string the models are moved to.

    Returns:
        The ``FaceAdapterPlusForVideoLora`` instance wrapping the pipeline.

    Raises:
        KeyError: if ``model_style_type`` is not a known style.
    """
    # Fail fast, before any model is loaded, with a message that lists valid styles.
    try:
        base_model_path = model_style_type2base_model_path[model_style_type]
    except KeyError:
        raise KeyError(
            f"Unknown model_style_type {model_style_type!r}; "
            f"expected one of {list(model_style_type2base_model_path)}"
        ) from None

    inference_config_path = "inference-v2.yaml"
    sd_version = "animatediff/sd"
    id_ckpt = "models/animator.ckpt"
    image_encoder_path = "models/image_encoder"
    motion_module_path = "models/v3_sd15_mm.ckpt"
    motion_lora_path = "models/v3_sd15_adapter.ckpt"

    inference_config = OmegaConf.load(inference_config_path)
    pipeline = _build_base_pipeline(inference_config, sd_version, device)

    pipeline = load_weights(
        pipeline,
        # motion module
        motion_module_path=motion_module_path,
        motion_module_lora_configs=[],
        # domain adapter
        adapter_lora_path=motion_lora_path,
        adapter_lora_scale=1,
        # image layers
        dreambooth_model_path=None,
        lora_model_path="",
        lora_alpha=0.8,
    ).to(device=device)

    if base_model_path != "":
        _load_dreambooth_checkpoint(pipeline, base_model_path, device)

    pipeline = pipeline.to(torch.float16)
    return FaceAdapterPlusForVideoLora(
        pipeline,
        image_encoder_path,
        id_ckpt,
        num_tokens=16,
        device=torch.device(device),
        torch_type=torch.float16,
    )
|