hpoghos committed on
Commit
1a105d4
·
1 Parent(s): 052250f

remove open_clip from FrozenOpenCLIPImageEmbedder

Browse files
t2v_enhanced/model/diffusers_conditional/models/controlnet/image_embedder.py CHANGED
@@ -1,10 +1,11 @@
1
  import math
2
  from typing import Any, Mapping
3
  import torch
 
4
  import torch.nn as nn
5
  import kornia
6
- from open_clip import create_model_and_transforms
7
- from transformers import AutoImageProcessor, AutoModel
8
  from transformers.models.bit.image_processing_bit import BitImageProcessor
9
  from einops import rearrange, repeat
10
  # FFN
@@ -72,13 +73,16 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
72
  output_tokens=False,
73
  ):
74
  super().__init__()
75
- model, _, _ = create_model_and_transforms(
76
- arch,
77
- device=torch.device("cpu"),
78
- pretrained=version,
79
- )
80
- del model.transformer
81
- self.model = model
 
 
 
82
  self.max_crops = num_image_crops
83
  self.pad_to_max_len = self.max_crops > 0
84
  self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
@@ -98,7 +102,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
98
  self.ucg_rate = ucg_rate
99
  self.unsqueeze_dim = unsqueeze_dim
100
  self.stored_batch = None
101
- self.model.visual.output_tokens = output_tokens
102
  self.output_tokens = output_tokens
103
 
104
  def preprocess(self, x):
@@ -116,9 +120,10 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
116
  return x
117
 
118
  def freeze(self):
119
- self.model = self.model.eval()
120
  for param in self.parameters():
121
  param.requires_grad = False
 
122
 
123
  def forward(self, image, no_dropout=False):
124
  z = self.encode_with_vision_transformer(image)
@@ -174,38 +179,42 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
174
  return z
175
 
176
  def encode_with_vision_transformer(self, img):
177
- # if self.max_crops > 0:
178
- # img = self.preprocess_by_cropping(img)
179
- if img.dim() == 5:
180
- assert self.max_crops == img.shape[1]
181
- img = rearrange(img, "b n c h w -> (b n) c h w")
182
- img = self.preprocess(img)
183
- if not self.output_tokens:
184
- assert not self.model.visual.output_tokens
185
- x = self.model.visual(img)
186
- tokens = None
187
- else:
188
- assert self.model.visual.output_tokens
189
- x, tokens = self.model.visual(img)
190
  if self.max_crops > 0:
191
- x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
192
- # drop out between 0 and all along the sequence axis
193
- x = (
194
- torch.bernoulli(
195
- (1.0 - self.ucg_rate)
196
- * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
197
- )
198
- * x
199
- )
200
- if tokens is not None:
201
- tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
202
- print(
203
- f"You are running very experimental token-concat in {self.__class__.__name__}. "
204
- f"Check what you are doing, and then remove this message."
205
- )
206
- if self.output_tokens:
207
- return x, tokens
208
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  def encode(self, text):
211
  return self(text)
 
1
  import math
2
  from typing import Any, Mapping
3
  import torch
4
+ from torchvision.transforms.functional import to_pil_image
5
  import torch.nn as nn
6
  import kornia
7
+ # import open_clip
8
+ from transformers import CLIPVisionModelWithProjection, AutoProcessor
9
  from transformers.models.bit.image_processing_bit import BitImageProcessor
10
  from einops import rearrange, repeat
11
  # FFN
 
73
  output_tokens=False,
74
  ):
75
  super().__init__()
76
+ # model, _, _ = open_clip.create_model_and_transforms(
77
+ # arch,
78
+ # device=torch.device("cpu"),
79
+ # pretrained=version,
80
+ # )
81
+ # del model.transformer
82
+ # self.model = model
83
+ self.model_t = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
84
+ self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
85
+
86
  self.max_crops = num_image_crops
87
  self.pad_to_max_len = self.max_crops > 0
88
  self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
 
102
  self.ucg_rate = ucg_rate
103
  self.unsqueeze_dim = unsqueeze_dim
104
  self.stored_batch = None
105
+ # self.model.visual.output_tokens = output_tokens
106
  self.output_tokens = output_tokens
107
 
108
  def preprocess(self, x):
 
120
  return x
121
 
122
  def freeze(self):
123
+ # self.model = self.model.eval()
124
  for param in self.parameters():
125
  param.requires_grad = False
126
+ self.model_t = self.model_t.eval()
127
 
128
  def forward(self, image, no_dropout=False):
129
  z = self.encode_with_vision_transformer(image)
 
179
  return z
180
 
181
  def encode_with_vision_transformer(self, img):
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  if self.max_crops > 0:
183
+ img = self.preprocess_by_cropping(img)
184
+ pil_img = to_pil_image(img[0]*0.5 + 0.5)
185
+ inputs = self.processor(images=pil_img, return_tensors="pt").to("cuda")
186
+ outputs = self.model_t(**inputs)
187
+ return outputs.image_embeds
188
+ # if img.dim() == 5:
189
+ # assert self.max_crops == img.shape[1]
190
+ # img = rearrange(img, "b n c h w -> (b n) c h w")
191
+ # img = self.preprocess(img)
192
+ # if not self.output_tokens:
193
+ # assert not self.model.visual.output_tokens
194
+ # x = self.model.visual(img)
195
+ # tokens = None
196
+ # else:
197
+ # assert self.model.visual.output_tokens
198
+ # x, tokens = self.model.visual(img)
199
+ # if self.max_crops > 0:
200
+ # x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
201
+ # # drop out between 0 and all along the sequence axis
202
+ # x = (
203
+ # torch.bernoulli(
204
+ # (1.0 - self.ucg_rate)
205
+ # * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
206
+ # )
207
+ # * x
208
+ # )
209
+ # if tokens is not None:
210
+ # tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
211
+ # print(
212
+ # f"You are running very experimental token-concat in {self.__class__.__name__}. "
213
+ # f"Check what you are doing, and then remove this message."
214
+ # )
215
+ # if self.output_tokens:
216
+ # return x, tokens
217
+ # return x
218
 
219
  def encode(self, text):
220
  return self(text)