Spaces:
Build error
Build error
# import argparse | |
import sys | |
from pathlib import Path | |
from pytorch_lightning.cli import LightningCLI | |
from PIL import Image | |
# For streaming | |
import yaml | |
from copy import deepcopy | |
from typing import List, Optional | |
from jsonargparse.typing import restricted_string_type | |
# -------------------------------------- | |
# ----------- For Streaming ------------ | |
# -------------------------------------- | |
class CustomCLI(LightningCLI): | |
def add_arguments_to_parser(self, parser): | |
parser.add_argument("--result_fol", type=Path, | |
help="Set the path to the result folder", default="results") | |
parser.add_argument("--exp_name", type=str, help="Experiment name") | |
parser.add_argument("--run_name", type=str, | |
help="Current run name") | |
parser.add_argument("--prompts", type=Optional[List[str]]) | |
parser.add_argument("--scale_lr", type=bool, | |
help="Scale lr", default=False) | |
CodeType = restricted_string_type( | |
'CodeType', '(medium)|(high)|(highest)') | |
parser.add_argument("--matmul_precision", type=CodeType) | |
parser.add_argument("--ckpt", type=Path,) | |
parser.add_argument("--n_predictions", type=int) | |
return parser | |
def remove_value(dictionary, x): | |
for key, value in list(dictionary.items()): | |
if key == x: | |
del dictionary[key] | |
elif isinstance(value, dict): | |
remove_value(value, x) | |
return dictionary | |
def legacy_transformation(cfg: yaml): | |
cfg = deepcopy(cfg) | |
cfg["trainer"]["devices"] = "1" | |
cfg["trainer"]['num_nodes'] = 1 | |
if not "class_path" in cfg["model"]["inference_params"]: | |
cfg["model"]["inference_params"] = { | |
"class_path": "t2v_enhanced.model.pl_module_params.InferenceParams", "init_args": cfg["model"]["inference_params"]} | |
return cfg | |
# --------------------------------------------- | |
# ----------- For enhancement ----------- | |
# --------------------------------------------- | |
def add_margin(pil_img, top, right, bottom, left, color): | |
width, height = pil_img.size | |
new_width = width + right + left | |
new_height = height + top + bottom | |
result = Image.new(pil_img.mode, (new_width, new_height), color) | |
result.paste(pil_img, (left, top)) | |
return result | |
def resize_to_fit(image, size): | |
W, H = size | |
w, h = image.size | |
if H / h > W / w: | |
H_ = int(h * W / w) | |
W_ = W | |
else: | |
W_ = int(w * H / h) | |
H_ = H | |
return image.resize((W_, H_)) | |
def pad_to_fit(image, size): | |
W, H = size | |
w, h = image.size | |
pad_h = (H - h) // 2 | |
pad_w = (W - w) // 2 | |
return add_margin(image, pad_h, pad_w, pad_h, pad_w, (0, 0, 0)) | |
def resize_and_keep(pil_img): | |
myheight = 576 | |
hpercent = (myheight/float(pil_img.size[1])) | |
wsize = int((float(pil_img.size[0])*float(hpercent))) | |
pil_img = pil_img.resize((wsize, myheight)) | |
return pil_img | |
def center_crop(pil_img): | |
width, height = pil_img.size | |
new_width = 576 | |
new_height = 576 | |
left = (width - new_width)/2 | |
top = (height - new_height)/2 | |
right = (width + new_width)/2 | |
bottom = (height + new_height)/2 | |
# Crop the center of the image | |
pil_img = pil_img.crop((left, top, right, bottom)) | |
return pil_img |