wondervictor committed on
Commit
180819f
·
verified ·
1 Parent(s): 2d1e0bb

Update model_new.py

Browse files
Files changed (1) hide show
  1. model_new.py +32 -40
model_new.py CHANGED
@@ -12,33 +12,13 @@ from condition.canny import CannyDetector
12
  import time
13
  from autoregressive.models.generate import generate
14
  from condition.midas.depth import MidasDetector
15
-
16
 
17
  models = {
18
- "canny": "checkpoints/canny_MR.safetensors",
19
- "depth": "checkpoints/depth_MR.safetensors",
20
  }
21
-
22
-
23
- def resize_image_to_16_multiple(image, condition_type='canny'):
24
- if isinstance(image, np.ndarray):
25
- image = Image.fromarray(image)
26
- # image = Image.open(image_path)
27
- width, height = image.size
28
-
29
- if condition_type == 'depth': # The depth model requires a side length that is a multiple of 32
30
- new_width = (width + 31) // 32 * 32
31
- new_height = (height + 31) // 32 * 32
32
- else:
33
- new_width = (width + 15) // 16 * 16
34
- new_height = (height + 15) // 16 * 16
35
-
36
- resized_image = image.resize((new_width, new_height))
37
- return resized_image
38
-
39
-
40
  class Model:
41
-
42
  def __init__(self):
43
  self.device = torch.device(
44
  "cuda")
@@ -46,8 +26,9 @@ class Model:
46
  self.task_name = ""
47
  self.vq_model = self.load_vq()
48
  self.t5_model = self.load_t5()
49
- self.gpt_model_canny = self.load_gpt(condition_type='canny')
50
- # self.gpt_model_depth = self.load_gpt(condition_type='depth')
 
51
 
52
  def to(self, device):
53
  self.gpt_model_canny.to('cuda')
@@ -67,19 +48,17 @@ class Model:
67
  gpt_ckpt = models[condition_type]
68
  # precision = torch.bfloat16
69
  precision = torch.float32
70
- latent_size = 768 // 16
71
  gpt_model = GPT_models["GPT-XL"](
72
  block_size=latent_size**2,
73
  cls_token_num=120,
74
  model_type='t2i',
75
  condition_type=condition_type,
 
76
  ).to(device='cpu', dtype=precision)
77
-
78
  model_weight = load_file(gpt_ckpt)
79
- print("prev:", model_weight['adapter.model.embeddings.patch_embeddings.projection.weight'])
80
  gpt_model.load_state_dict(model_weight, strict=True)
81
  gpt_model.eval()
82
- print("loaded:", gpt_model.adapter.model.embeddings.patch_embeddings.projection.weight)
83
  print("gpt model is loaded")
84
  return gpt_model
85
 
@@ -109,22 +88,35 @@ class Model:
109
  seed: int,
110
  low_threshold: int,
111
  high_threshold: int,
 
 
112
  ) -> list[PIL.Image.Image]:
113
- print(image)
114
- image = resize_image_to_16_multiple(image, 'canny')
115
- W, H = image.size
116
- print(W, H)
117
  self.t5_model.model.to('cuda').to(torch.bfloat16)
118
  self.gpt_model_canny.to('cuda').to(torch.bfloat16)
119
  self.vq_model.to('cuda')
120
-
121
- condition_img = self.get_control_canny(np.array(image), low_threshold,
122
- high_threshold)
123
- condition_img = torch.from_numpy(condition_img[None, None,
124
- ...]).repeat(
125
- 2, 3, 1, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  condition_img = condition_img.to(self.device)
127
- condition_img = 2 * (condition_img / 255 - 0.5)
128
  prompts = [prompt] * 2
129
  caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
130
 
 
12
  import time
13
  from autoregressive.models.generate import generate
14
  from condition.midas.depth import MidasDetector
15
+ from preprocessor import Preprocessor
16
 
17
  models = {
18
+ "edge": "checkpoints/edge_base.safetensors",
19
+ "depth": "checkpoints/depth_base.safetensors",
20
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class Model:
 
22
  def __init__(self):
23
  self.device = torch.device(
24
  "cuda")
 
26
  self.task_name = ""
27
  self.vq_model = self.load_vq()
28
  self.t5_model = self.load_t5()
29
+ self.gpt_model_edge = self.load_gpt(condition_type='edge')
30
+ self.gpt_model_depth = self.load_gpt(condition_type='depth')
31
+ self.preprocessor = Preprocessor()
32
 
33
  def to(self, device):
34
  self.gpt_model_canny.to('cuda')
 
48
  gpt_ckpt = models[condition_type]
49
  # precision = torch.bfloat16
50
  precision = torch.float32
51
+ latent_size = 512 // 16
52
  gpt_model = GPT_models["GPT-XL"](
53
  block_size=latent_size**2,
54
  cls_token_num=120,
55
  model_type='t2i',
56
  condition_type=condition_type,
57
+ adapter_size='base',
58
  ).to(device='cpu', dtype=precision)
 
59
  model_weight = load_file(gpt_ckpt)
 
60
  gpt_model.load_state_dict(model_weight, strict=True)
61
  gpt_model.eval()
 
62
  print("gpt model is loaded")
63
  return gpt_model
64
 
 
88
  seed: int,
89
  low_threshold: int,
90
  high_threshold: int,
91
+ control_strength: float,
92
+ preprocessor_name: str,
93
  ) -> list[PIL.Image.Image]:
 
 
 
 
94
  self.t5_model.model.to('cuda').to(torch.bfloat16)
95
  self.gpt_model_canny.to('cuda').to(torch.bfloat16)
96
  self.vq_model.to('cuda')
97
+ if isinstance(image, np.ndarray):
98
+ image = Image.fromarray(image)
99
+ origin_W, origin_H = image.size
100
+ if preprocessor_name == 'Canny':
101
+ self.preprocessor.load("Canny")
102
+ condition_img = self.preprocessor(
103
+ image=image, low_threshold=low_threshold, high_threshold=high_threshold, detect_resolution=512)
104
+ elif preprocessor_name == 'Hed':
105
+ self.preprocessor.load("HED")
106
+ condition_img = self.preprocessor(
107
+ image=image,image_resolution=512, detect_resolution=512)
108
+ elif preprocessor_name == 'Lineart':
109
+ self.preprocessor.load("Lineart")
110
+ condition_img = self.preprocessor(
111
+ image=image,image_resolution=512, detect_resolution=512)
112
+ elif preprocessor_name == 'No preprocess':
113
+ condition_img = image
114
+ condition_img = condition_img.resize((512,512))
115
+ W, H = condition_img.size
116
+
117
+ condition_img = torch.from_numpy(np.array(condition_img)).unsqueeze(0).permute(0,3,1,2).repeat(2,1,1,1)
118
  condition_img = condition_img.to(self.device)
119
+ condition_img = 2*(condition_img/255 - 0.5)
120
  prompts = [prompt] * 2
121
  caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
122