import gc
import os
import logging
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
from torch import nn
from safetensors import safe_open
from safetensors.torch import load_file
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline

from .convert_from_ckpt import (
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    convert_ldm_clip_checkpoint,
)
from .convert_lora_safetensor_to_diffusers import convert_motion_lora_ckpt_to_diffusers

logger = logging.getLogger(__name__)


def update_pipeline_model_parameters(
    pipeline: DiffusionPipeline,
    model_path: Optional[str] = None,
    lora_dict: Optional[Dict[str, Dict]] = None,
    text_model_path: Optional[str] = None,
    device: str = "cuda",
    need_unload: bool = False,
):
    """Update the pipeline's base model and/or merge LoRA weights into it."""
    # initialize so the tuple return below is safe even when lora_dict is None
    unload_dict = []
    if model_path is not None:
        pipeline = update_pipeline_basemodel(
            pipeline, model_path, text_sd_model_path=text_model_path, device=device
        )
    if lora_dict is not None:
        pipeline, unload_dict = update_pipeline_lora_models(
            pipeline,
            lora_dict,
            device=device,
            need_unload=need_unload,
        )
    if need_unload:
        return pipeline, unload_dict
    return pipeline
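
# Usage sketch (hypothetical paths): swap in a new base checkpoint and merge a
# LoRA in one call. With need_unload=True, the returned unload_dict can later
# be handed to unload_lora() to revert the merge:
#
#   pipeline, unload_dict = update_pipeline_model_parameters(
#       pipeline,
#       model_path="models/base.safetensors",
#       lora_dict={"loras/style.safetensors": {"strength": 0.8}},
#       need_unload=True,
#   )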


def update_pipeline_basemodel(
    pipeline: DiffusionPipeline,
    model_path: str,
    text_sd_model_path: str,
    device: str = "cuda",
):
    """Update the pipeline's base weights (VAE, UNet, text encoder) from model_path.

    Args:
        pipeline (DiffusionPipeline): pipeline whose weights are replaced in place.
        model_path (str): path to a ``.ckpt`` or ``.safetensors`` base checkpoint.
        text_sd_model_path (str): path forwarded to convert_ldm_clip_checkpoint
            when rebuilding the text encoder.
        device (str, optional): device to load tensors onto. Defaults to "cuda".

    Returns:
        DiffusionPipeline: the updated pipeline, moved to ``device``.
    """
    # load base
    if model_path.endswith(".ckpt"):
        state_dict = torch.load(model_path, map_location=device)
        pipeline.unet.load_state_dict(state_dict)
        print("update sd_model", model_path)
    elif model_path.endswith(".safetensors"):
        base_state_dict = {}
        with safe_open(model_path, framework="pt", device=device) as f:
            for key in f.keys():
                base_state_dict[key] = f.get_tensor(key)
        is_lora = all("lora" in k for k in base_state_dict.keys())
        assert not is_lora, "Base model cannot be LoRA: {}".format(model_path)

        # vae
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
            base_state_dict, pipeline.vae.config
        )
        pipeline.vae.load_state_dict(converted_vae_checkpoint)

        # unet
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(
            base_state_dict, pipeline.unet.config
        )
        pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # text_model
        pipeline.text_encoder = convert_ldm_clip_checkpoint(
            base_state_dict, text_sd_model_path
        )
        print("update sd_model", model_path)
    pipeline.to(device)
    return pipeline
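
# Usage sketch (hypothetical paths; text_sd_model_path is forwarded untouched
# to convert_ldm_clip_checkpoint, so it should point at whatever CLIP
# text-encoder weights that helper expects):
#
#   pipeline = update_pipeline_basemodel(
#       pipeline, "models/base.safetensors",
#       text_sd_model_path="models/text_encoder", device="cuda")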


# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/cfg.yaml
LORA_BLOCK_WEIGHT_MAP = {
    "FACE": [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
    "DEFACE": [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
    "ALL": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    "MIDD": [1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
    "OUTALL": [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
}
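
# Each preset has 17 entries: index 0 scales the text-encoder delta, and
# indices 1-16 scale the 16 UNet attention blocks listed in lora_unet_layers
# below (down blocks, mid block, then up blocks, in that order).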


# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/draw_pipe.py
def update_pipeline_lora_model(
    pipeline: DiffusionPipeline,
    lora: Union[str, Dict],
    alpha: float = 0.75,
    device: str = "cuda",
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
    lora_block_weight_str: Literal["FACE", "DEFACE", "ALL", "MIDD", "OUTALL"] = "ALL",
    need_unload: bool = False,
):
"""使用 lora 更新pipeline中的unet相关参数 | |
Args: | |
pipeline (DiffusionPipeline): _description_ | |
lora (Union[str, Dict]): _description_ | |
alpha (float, optional): _description_. Defaults to 0.75. | |
device (str, optional): _description_. Defaults to "cuda". | |
lora_prefix_unet (str, optional): _description_. Defaults to "lora_unet". | |
lora_prefix_text_encoder (str, optional): _description_. Defaults to "lora_te". | |
lora_unet_layers (list, optional): _description_. Defaults to [ "lora_unet_down_blocks_0_attentions_0", "lora_unet_down_blocks_0_attentions_1", "lora_unet_down_blocks_1_attentions_0", "lora_unet_down_blocks_1_attentions_1", "lora_unet_down_blocks_2_attentions_0", "lora_unet_down_blocks_2_attentions_1", "lora_unet_mid_block_attentions_0", "lora_unet_up_blocks_1_attentions_0", "lora_unet_up_blocks_1_attentions_1", "lora_unet_up_blocks_1_attentions_2", "lora_unet_up_blocks_2_attentions_0", "lora_unet_up_blocks_2_attentions_1", "lora_unet_up_blocks_2_attentions_2", "lora_unet_up_blocks_3_attentions_0", "lora_unet_up_blocks_3_attentions_1", "lora_unet_up_blocks_3_attentions_2", ]. | |
lora_block_weight_str (Literal["FACE", "ALL"], optional): _description_. Defaults to "ALL". | |
need_unload (bool, optional): _description_. Defaults to False. | |
Returns: | |
_type_: _description_ | |
""" | |
    # ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/tool.py#L20
    lora_block_weight = None
    if lora_block_weight_str is not None:
        lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()]
    if lora_block_weight:
        assert len(lora_block_weight) == 17

    # load lora weight
    if isinstance(lora, str):
        state_dict = load_file(lora, device=device)
    else:
        for k in lora:
            lora[k] = lora[k].to(device)
        state_dict = lora

    visited = set()
    unload_dict = []
    # directly update weights in the diffusers model
    for key in state_dict:
        # Keys typically look like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight".
        # ".alpha" entries are read via alpha_key below, so skip them here.
        if ".alpha" in key or key in visited:
            continue
        if "text" in key:
            layer_infos = (
                key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_")
            )
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer: underscore-joined name fragments are rebuilt
        # greedily because module names themselves may contain underscores
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
            alpha_key = key.replace("lora_down.weight", "alpha")
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))
            alpha_key = key.replace("lora_up.weight", "alpha")

        # update weight
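        # The merged delta follows the standard LoRA update:
        #   W' = W + alpha * scale * (up @ down),  scale = lora_alpha / rank
        # where rank == weight_up.shape[1] and lora_alpha comes from the
        # checkpoint's ".alpha" entry when present (scale falls back to 1.0).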
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
                state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            )
            if alpha_key in state_dict:
                weight_scale = state_dict[alpha_key].item() / weight_up.shape[1]
            else:
                weight_scale = 1.0
            if len(weight_up.shape) == len(weight_down.shape):
                # 1x1 conv: both factors squeeze down to 2-D matrices
                adding_weight = (
                    alpha
                    * weight_scale
                    * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                )
            else:
                # 3x3 conv: the down factor keeps its spatial dims
                adding_weight = (
                    alpha
                    * weight_scale
                    * torch.einsum("a b, b c h w -> a c h w", weight_up, weight_down)
                )
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            if alpha_key in state_dict:
                weight_scale = state_dict[alpha_key].item() / weight_up.shape[1]
            else:
                weight_scale = 1.0
            adding_weight = alpha * weight_scale * torch.mm(weight_up, weight_down)
        adding_weight = adding_weight.to(torch.float16)
        if lora_block_weight:
            if "text" in key:
                adding_weight *= lora_block_weight[0]
            else:
                for idx, layer in enumerate(lora_unet_layers):
                    if layer in key:
                        adding_weight *= lora_block_weight[idx + 1]
                        break
        curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight}
        curr_layer.weight.data += adding_weight
        unload_dict.append(curr_layer_unload_data)

        # update visited list
        for item in pair_keys:
            visited.add(item)

    if need_unload:
        return pipeline, unload_dict
    else:
        return pipeline
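
# Usage sketch (hypothetical path): merge a single LoRA at strength 0.8, then
# revert it later by subtracting the recorded per-layer deltas:
#
#   pipeline, unload_dict = update_pipeline_lora_model(
#       pipeline, "loras/style.safetensors", alpha=0.8,
#       lora_block_weight_str="MIDD", need_unload=True)
#   ...
#   unload_lora(unload_dict)  # restores the original weights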


# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/draw_pipe.py
def update_pipeline_lora_model_old(
    pipeline: DiffusionPipeline,
    lora: Union[str, Dict],
    alpha: float = 0.75,
    device: str = "cuda",
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
    lora_block_weight_str: Literal["FACE", "DEFACE", "ALL", "MIDD", "OUTALL"] = "ALL",
    need_unload: bool = False,
):
"""使用 lora 更新pipeline中的unet相关参数 | |
Args: | |
pipeline (DiffusionPipeline): _description_ | |
lora (Union[str, Dict]): _description_ | |
alpha (float, optional): _description_. Defaults to 0.75. | |
device (str, optional): _description_. Defaults to "cuda". | |
lora_prefix_unet (str, optional): _description_. Defaults to "lora_unet". | |
lora_prefix_text_encoder (str, optional): _description_. Defaults to "lora_te". | |
lora_unet_layers (list, optional): _description_. Defaults to [ "lora_unet_down_blocks_0_attentions_0", "lora_unet_down_blocks_0_attentions_1", "lora_unet_down_blocks_1_attentions_0", "lora_unet_down_blocks_1_attentions_1", "lora_unet_down_blocks_2_attentions_0", "lora_unet_down_blocks_2_attentions_1", "lora_unet_mid_block_attentions_0", "lora_unet_up_blocks_1_attentions_0", "lora_unet_up_blocks_1_attentions_1", "lora_unet_up_blocks_1_attentions_2", "lora_unet_up_blocks_2_attentions_0", "lora_unet_up_blocks_2_attentions_1", "lora_unet_up_blocks_2_attentions_2", "lora_unet_up_blocks_3_attentions_0", "lora_unet_up_blocks_3_attentions_1", "lora_unet_up_blocks_3_attentions_2", ]. | |
lora_block_weight_str (Literal["FACE", "ALL"], optional): _description_. Defaults to "ALL". | |
need_unload (bool, optional): _description_. Defaults to False. | |
Returns: | |
_type_: _description_ | |
""" | |
    # ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/tool.py#L20
    lora_block_weight = None
    if lora_block_weight_str is not None:
        lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()]
    if lora_block_weight:
        assert len(lora_block_weight) == 17

    # load lora weight
    if isinstance(lora, str):
        state_dict = load_file(lora, device=device)
    else:
        for k in lora:
            lora[k] = lora[k].to(device)
        state_dict = lora

    visited = set()
    unload_dict = []
    # directly update weights in the diffusers model
    for key in state_dict:
        # Keys typically look like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight".
        # ".alpha" entries are not used by this variant, so skip them.
        if ".alpha" in key or key in visited:
            continue
        if "text" in key:
            layer_infos = (
                key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_")
            )
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer: underscore-joined name fragments are rebuilt
        # greedily because module names themselves may contain underscores
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))
        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
                state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            )
            adding_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(
                2
            ).unsqueeze(3)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            adding_weight = alpha * torch.mm(weight_up, weight_down)
        if lora_block_weight:
            if "text" in key:
                adding_weight *= lora_block_weight[0]
            else:
                for idx, layer in enumerate(lora_unet_layers):
                    if layer in key:
                        adding_weight *= lora_block_weight[idx + 1]
                        break
        curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight}
        curr_layer.weight.data += adding_weight
        unload_dict.append(curr_layer_unload_data)

        # update visited list
        for item in pair_keys:
            visited.add(item)

    if need_unload:
        return pipeline, unload_dict
    else:
        return pipeline


def update_pipeline_lora_models(
    pipeline: DiffusionPipeline,
    lora_dict: Dict[str, Dict],
    device: str = "cuda",
    need_unload: bool = True,
    lora_prefix_unet: str = "lora_unet",
    lora_prefix_text_encoder: str = "lora_te",
    lora_unet_layers=[
        "lora_unet_down_blocks_0_attentions_0",
        "lora_unet_down_blocks_0_attentions_1",
        "lora_unet_down_blocks_1_attentions_0",
        "lora_unet_down_blocks_1_attentions_1",
        "lora_unet_down_blocks_2_attentions_0",
        "lora_unet_down_blocks_2_attentions_1",
        "lora_unet_mid_block_attentions_0",
        "lora_unet_up_blocks_1_attentions_0",
        "lora_unet_up_blocks_1_attentions_1",
        "lora_unet_up_blocks_1_attentions_2",
        "lora_unet_up_blocks_2_attentions_0",
        "lora_unet_up_blocks_2_attentions_1",
        "lora_unet_up_blocks_2_attentions_2",
        "lora_unet_up_blocks_3_attentions_0",
        "lora_unet_up_blocks_3_attentions_1",
        "lora_unet_up_blocks_3_attentions_2",
    ],
):
"""使用 lora 更新pipeline中的unet相关参数 | |
Args: | |
pipeline (DiffusionPipeline): _description_ | |
lora_dict (Dict[str, Dict]): _description_ | |
device (str, optional): _description_. Defaults to "cuda". | |
lora_prefix_unet (str, optional): _description_. Defaults to "lora_unet". | |
lora_prefix_text_encoder (str, optional): _description_. Defaults to "lora_te". | |
lora_unet_layers (list, optional): _description_. Defaults to [ "lora_unet_down_blocks_0_attentions_0", "lora_unet_down_blocks_0_attentions_1", "lora_unet_down_blocks_1_attentions_0", "lora_unet_down_blocks_1_attentions_1", "lora_unet_down_blocks_2_attentions_0", "lora_unet_down_blocks_2_attentions_1", "lora_unet_mid_block_attentions_0", "lora_unet_up_blocks_1_attentions_0", "lora_unet_up_blocks_1_attentions_1", "lora_unet_up_blocks_1_attentions_2", "lora_unet_up_blocks_2_attentions_0", "lora_unet_up_blocks_2_attentions_1", "lora_unet_up_blocks_2_attentions_2", "lora_unet_up_blocks_3_attentions_0", "lora_unet_up_blocks_3_attentions_1", "lora_unet_up_blocks_3_attentions_2", ]. | |
Returns: | |
_type_: _description_ | |
""" | |
    unload_dicts = []
    for lora, value in lora_dict.items():
        lora_name = os.path.basename(lora).replace(".safetensors", "")
        strength_offset = value.get("strength_offset", 0.0)
        alpha = value.get("strength", 1.0)
        alpha += strength_offset
        lora_weight_str = value.get("lora_block_weight", "ALL")
        lora = load_file(lora)
        pipeline, unload_dict = update_pipeline_lora_model(
            pipeline,
            lora=lora,
            device=device,
            alpha=alpha,
            lora_prefix_unet=lora_prefix_unet,
            lora_prefix_text_encoder=lora_prefix_text_encoder,
            lora_unet_layers=lora_unet_layers,
            lora_block_weight_str=lora_weight_str,
            need_unload=True,
        )
        print(
            "Update LoRA {} with alpha {} and weight {}".format(
                lora_name, alpha, lora_weight_str
            )
        )
        unload_dicts += unload_dict
    return pipeline, unload_dicts
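
# Example lora_dict layout (hypothetical paths), matching the keys read above:
#
#   lora_dict = {
#       "loras/style.safetensors": {"strength": 0.8, "lora_block_weight": "ALL"},
#       "loras/face.safetensors": {"strength": 0.5, "strength_offset": 0.1,
#                                  "lora_block_weight": "FACE"},
#   }
#   pipeline, unload_dicts = update_pipeline_lora_models(pipeline, lora_dict)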


def unload_lora(unload_dict: List[Dict[str, nn.Module]]):
    """Subtract previously merged LoRA deltas, restoring the original weights."""
    for layer_data in unload_dict:
        layer = layer_data["layer"]
        added_weight = layer_data["added_weight"]
        layer.weight.data -= added_weight

    gc.collect()
    torch.cuda.empty_cache()


def load_motion_lora_weights(
    animation_pipeline,
    motion_module_lora_configs: List[Dict] = [],
):
    """Apply motion-module LoRA checkpoints to the animation pipeline."""
    for motion_module_lora_config in motion_module_lora_configs:
        path, alpha = (
            motion_module_lora_config["path"],
            motion_module_lora_config["alpha"],
        )
        print(f"load motion LoRA from {path}")
        motion_lora_state_dict = torch.load(path, map_location="cpu")
        motion_lora_state_dict = (
            motion_lora_state_dict["state_dict"]
            if "state_dict" in motion_lora_state_dict
            else motion_lora_state_dict
        )
        animation_pipeline = convert_motion_lora_ckpt_to_diffusers(
            animation_pipeline, motion_lora_state_dict, alpha
        )
    return animation_pipeline
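
# Example motion_module_lora_configs (hypothetical checkpoint path):
#
#   animation_pipeline = load_motion_lora_weights(
#       animation_pipeline,
#       motion_module_lora_configs=[{"path": "motion_loras/zoom_in.ckpt", "alpha": 0.9}],
#   )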