Spaces:
Running
on
Zero
Running
on
Zero
wondervictor
committed on
Update model_new.py
Browse files - model_new.py +32 -40
model_new.py
CHANGED
@@ -12,33 +12,13 @@ from condition.canny import CannyDetector
|
|
12 |
import time
|
13 |
from autoregressive.models.generate import generate
|
14 |
from condition.midas.depth import MidasDetector
|
15 |
-
|
16 |
|
17 |
models = {
|
18 |
-
"
|
19 |
-
"depth": "checkpoints/
|
20 |
}
|
21 |
-
|
22 |
-
|
23 |
-
def resize_image_to_16_multiple(image, condition_type='canny'):
|
24 |
-
if isinstance(image, np.ndarray):
|
25 |
-
image = Image.fromarray(image)
|
26 |
-
# image = Image.open(image_path)
|
27 |
-
width, height = image.size
|
28 |
-
|
29 |
-
if condition_type == 'depth': # The depth model requires a side length that is a multiple of 32
|
30 |
-
new_width = (width + 31) // 32 * 32
|
31 |
-
new_height = (height + 31) // 32 * 32
|
32 |
-
else:
|
33 |
-
new_width = (width + 15) // 16 * 16
|
34 |
-
new_height = (height + 15) // 16 * 16
|
35 |
-
|
36 |
-
resized_image = image.resize((new_width, new_height))
|
37 |
-
return resized_image
|
38 |
-
|
39 |
-
|
40 |
class Model:
|
41 |
-
|
42 |
def __init__(self):
|
43 |
self.device = torch.device(
|
44 |
"cuda")
|
@@ -46,8 +26,9 @@ class Model:
|
|
46 |
self.task_name = ""
|
47 |
self.vq_model = self.load_vq()
|
48 |
self.t5_model = self.load_t5()
|
49 |
-
self.
|
50 |
-
|
|
|
51 |
|
52 |
def to(self, device):
|
53 |
self.gpt_model_canny.to('cuda')
|
@@ -67,19 +48,17 @@ class Model:
|
|
67 |
gpt_ckpt = models[condition_type]
|
68 |
# precision = torch.bfloat16
|
69 |
precision = torch.float32
|
70 |
-
latent_size =
|
71 |
gpt_model = GPT_models["GPT-XL"](
|
72 |
block_size=latent_size**2,
|
73 |
cls_token_num=120,
|
74 |
model_type='t2i',
|
75 |
condition_type=condition_type,
|
|
|
76 |
).to(device='cpu', dtype=precision)
|
77 |
-
|
78 |
model_weight = load_file(gpt_ckpt)
|
79 |
-
print("prev:", model_weight['adapter.model.embeddings.patch_embeddings.projection.weight'])
|
80 |
gpt_model.load_state_dict(model_weight, strict=True)
|
81 |
gpt_model.eval()
|
82 |
-
print("loaded:", gpt_model.adapter.model.embeddings.patch_embeddings.projection.weight)
|
83 |
print("gpt model is loaded")
|
84 |
return gpt_model
|
85 |
|
@@ -109,22 +88,35 @@ class Model:
|
|
109 |
seed: int,
|
110 |
low_threshold: int,
|
111 |
high_threshold: int,
|
|
|
|
|
112 |
) -> list[PIL.Image.Image]:
|
113 |
-
print(image)
|
114 |
-
image = resize_image_to_16_multiple(image, 'canny')
|
115 |
-
W, H = image.size
|
116 |
-
print(W, H)
|
117 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
118 |
self.gpt_model_canny.to('cuda').to(torch.bfloat16)
|
119 |
self.vq_model.to('cuda')
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
condition_img = condition_img.to(self.device)
|
127 |
-
condition_img = 2
|
128 |
prompts = [prompt] * 2
|
129 |
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
130 |
|
|
|
12 |
import time
|
13 |
from autoregressive.models.generate import generate
|
14 |
from condition.midas.depth import MidasDetector
|
15 |
+
from preprocessor import Preprocessor
|
16 |
|
17 |
models = {
|
18 |
+
"edge": "checkpoints/edge_base.safetensors",
|
19 |
+
"depth": "checkpoints/depth_base.safetensors",
|
20 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
class Model:
|
|
|
22 |
def __init__(self):
|
23 |
self.device = torch.device(
|
24 |
"cuda")
|
|
|
26 |
self.task_name = ""
|
27 |
self.vq_model = self.load_vq()
|
28 |
self.t5_model = self.load_t5()
|
29 |
+
self.gpt_model_edge = self.load_gpt(condition_type='edge')
|
30 |
+
self.gpt_model_depth = self.load_gpt(condition_type='depth')
|
31 |
+
self.preprocessor = Preprocessor()
|
32 |
|
33 |
def to(self, device):
|
34 |
self.gpt_model_canny.to('cuda')
|
|
|
48 |
gpt_ckpt = models[condition_type]
|
49 |
# precision = torch.bfloat16
|
50 |
precision = torch.float32
|
51 |
+
latent_size = 512 // 16
|
52 |
gpt_model = GPT_models["GPT-XL"](
|
53 |
block_size=latent_size**2,
|
54 |
cls_token_num=120,
|
55 |
model_type='t2i',
|
56 |
condition_type=condition_type,
|
57 |
+
adapter_size='base',
|
58 |
).to(device='cpu', dtype=precision)
|
|
|
59 |
model_weight = load_file(gpt_ckpt)
|
|
|
60 |
gpt_model.load_state_dict(model_weight, strict=True)
|
61 |
gpt_model.eval()
|
|
|
62 |
print("gpt model is loaded")
|
63 |
return gpt_model
|
64 |
|
|
|
88 |
seed: int,
|
89 |
low_threshold: int,
|
90 |
high_threshold: int,
|
91 |
+
control_strength: float,
|
92 |
+
preprocessor_name: str,
|
93 |
) -> list[PIL.Image.Image]:
|
|
|
|
|
|
|
|
|
94 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
95 |
self.gpt_model_canny.to('cuda').to(torch.bfloat16)
|
96 |
self.vq_model.to('cuda')
|
97 |
+
if isinstance(image, np.ndarray):
|
98 |
+
image = Image.fromarray(image)
|
99 |
+
origin_W, origin_H = image.size
|
100 |
+
if preprocessor_name == 'Canny':
|
101 |
+
self.preprocessor.load("Canny")
|
102 |
+
condition_img = self.preprocessor(
|
103 |
+
image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=512)
|
104 |
+
elif preprocessor_name == 'Hed':
|
105 |
+
self.preprocessor.load("HED")
|
106 |
+
condition_img = self.preprocessor(
|
107 |
+
image=image,image_resolution=512, detect_resolution=512)
|
108 |
+
elif preprocessor_name == 'Lineart':
|
109 |
+
self.preprocessor.load("Lineart")
|
110 |
+
condition_img = self.preprocessor(
|
111 |
+
image=image,image_resolution=512, detect_resolution=512)
|
112 |
+
elif preprocessor_name == 'No preprocess':
|
113 |
+
condition_img = image
|
114 |
+
condition_img = condition_img.resize((512,512))
|
115 |
+
W, H = condition_img.size
|
116 |
+
|
117 |
+
condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(2,1,1,1)
|
118 |
condition_img = condition_img.to(self.device)
|
119 |
+
condition_img = 2*(condition_img/255 - 0.5)
|
120 |
prompts = [prompt] * 2
|
121 |
caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
|
122 |
|