ujin-song committed on
Commit
8e12b4e
·
verified ·
1 Parent(s): 1c6e4b7

upload mixofshow and orthogonal_mats folder

Browse files
Files changed (50) hide show
  1. mixofshow/.DS_Store +0 -0
  2. mixofshow/data/__init__.py +0 -0
  3. mixofshow/data/__pycache__/__init__.cpython-38.pyc +0 -0
  4. mixofshow/data/__pycache__/__init__.cpython-39.pyc +0 -0
  5. mixofshow/data/__pycache__/lora_dataset.cpython-38.pyc +0 -0
  6. mixofshow/data/__pycache__/lora_dataset.cpython-39.pyc +0 -0
  7. mixofshow/data/__pycache__/pil_transform.cpython-38.pyc +0 -0
  8. mixofshow/data/__pycache__/pil_transform.cpython-39.pyc +0 -0
  9. mixofshow/data/__pycache__/prompt_dataset.cpython-38.pyc +0 -0
  10. mixofshow/data/__pycache__/prompt_dataset.cpython-39.pyc +0 -0
  11. mixofshow/data/lora_dataset.py +102 -0
  12. mixofshow/data/pil_transform.py +366 -0
  13. mixofshow/data/prompt_dataset.py +67 -0
  14. mixofshow/models/__pycache__/edlora.cpython-310.pyc +0 -0
  15. mixofshow/models/__pycache__/edlora.cpython-38.pyc +0 -0
  16. mixofshow/models/__pycache__/edlora.cpython-39.pyc +0 -0
  17. mixofshow/models/edlora.py +259 -0
  18. mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-310.pyc +0 -0
  19. mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-38.pyc +0 -0
  20. mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-39.pyc +0 -0
  21. mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-310.pyc +0 -0
  22. mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-38.pyc +0 -0
  23. mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-39.pyc +0 -0
  24. mixofshow/pipelines/__pycache__/trainer_edlora.cpython-38.pyc +0 -0
  25. mixofshow/pipelines/__pycache__/trainer_edlora.cpython-39.pyc +0 -0
  26. mixofshow/pipelines/pipeline_edlora.py +322 -0
  27. mixofshow/pipelines/pipeline_regionally_t2iadapter.py +608 -0
  28. mixofshow/pipelines/trainer_edlora.py +380 -0
  29. mixofshow/utils/__init__.py +0 -0
  30. mixofshow/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  31. mixofshow/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  32. mixofshow/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  33. mixofshow/utils/__pycache__/convert_edlora_to_diffusers.cpython-38.pyc +0 -0
  34. mixofshow/utils/__pycache__/convert_edlora_to_diffusers.cpython-39.pyc +0 -0
  35. mixofshow/utils/__pycache__/ptp_util.cpython-38.pyc +0 -0
  36. mixofshow/utils/__pycache__/ptp_util.cpython-39.pyc +0 -0
  37. mixofshow/utils/__pycache__/registry.cpython-38.pyc +0 -0
  38. mixofshow/utils/__pycache__/registry.cpython-39.pyc +0 -0
  39. mixofshow/utils/__pycache__/util.cpython-310.pyc +0 -0
  40. mixofshow/utils/__pycache__/util.cpython-38.pyc +0 -0
  41. mixofshow/utils/__pycache__/util.cpython-39.pyc +0 -0
  42. mixofshow/utils/arial.ttf +0 -0
  43. mixofshow/utils/convert_edlora_to_diffusers.py +99 -0
  44. mixofshow/utils/ptp_util.py +200 -0
  45. mixofshow/utils/registry.py +79 -0
  46. mixofshow/utils/util.py +313 -0
  47. orthogonal_mats/1280.npy +3 -0
  48. orthogonal_mats/320.npy +3 -0
  49. orthogonal_mats/640.npy +3 -0
  50. orthogonal_mats/768.npy +3 -0
mixofshow/.DS_Store ADDED
Binary file (6.15 kB). View file
 
mixofshow/data/__init__.py ADDED
File without changes
mixofshow/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (148 Bytes). View file
 
mixofshow/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (148 Bytes). View file
 
mixofshow/data/__pycache__/lora_dataset.cpython-38.pyc ADDED
Binary file (3.02 kB). View file
 
mixofshow/data/__pycache__/lora_dataset.cpython-39.pyc ADDED
Binary file (3.07 kB). View file
 
mixofshow/data/__pycache__/pil_transform.cpython-38.pyc ADDED
Binary file (10.9 kB). View file
 
mixofshow/data/__pycache__/pil_transform.cpython-39.pyc ADDED
Binary file (10.8 kB). View file
 
mixofshow/data/__pycache__/prompt_dataset.cpython-38.pyc ADDED
Binary file (2.35 kB). View file
 
mixofshow/data/__pycache__/prompt_dataset.cpython-39.pyc ADDED
Binary file (2.36 kB). View file
 
mixofshow/data/lora_dataset.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import re
5
+ from pathlib import Path
6
+
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+
10
+ from mixofshow.data.pil_transform import PairCompose, build_transform
11
+
12
+
13
class LoraDataset(Dataset):
    """Dataset of instance images, prompts and optional masks for fine-tuning.

    ``opt`` is a mapping that must provide ``concept_list`` (path to a JSON
    list of concepts), ``instance_transform`` (a list of transform configs)
    and ``dataset_enlarge_ratio``. Optional keys are ``replace_mapping``,
    ``use_caption`` and ``use_mask``.

    Each concept entry must provide ``instance_prompt`` and
    ``instance_data_dir``; ``caption_dir`` and ``mask_dir`` are optional.
    """

    def __init__(self, opt):
        self.opt = opt
        self.instance_images_path = []

        with open(opt['concept_list'], 'r') as f:
            concept_list = json.load(f)

        replace_mapping = opt.get('replace_mapping', {})
        use_caption = opt.get('use_caption', False)
        use_mask = opt.get('use_mask', False)

        for concept in concept_list:
            caption_dir = concept.get('caption_dir')
            mask_dir = concept.get('mask_dir')
            instance_prompt = self.process_text(concept['instance_prompt'], replace_mapping)

            inst_img_path = []
            for x in Path(concept['instance_data_dir']).iterdir():
                if not x.is_file() or x.name == '.DS_Store':
                    continue
                basename = os.path.splitext(os.path.basename(x))[0]

                # A per-image caption file overrides the concept-level prompt.
                instance_prompt_image = instance_prompt
                if use_caption and caption_dir is not None:
                    caption_path = os.path.join(caption_dir, f'{basename}.txt')
                    if os.path.exists(caption_path):
                        instance_prompt_image = self._read_caption(
                            caption_path, replace_mapping, instance_prompt)

                if use_mask and mask_dir is not None:
                    mask_path = os.path.join(mask_dir, f'{basename}.png')
                else:
                    mask_path = None

                inst_img_path.append((x, instance_prompt_image, mask_path))

            self.instance_images_path.extend(inst_img_path)

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)

        self.instance_transform = PairCompose([
            build_transform(transform_opt)
            for transform_opt in opt['instance_transform']
        ])

    def _read_caption(self, caption_path, replace_mapping, default):
        """Return the processed first line of ``caption_path``.

        Falls back to ``default`` when the file is completely empty, instead
        of raising ``IndexError`` as ``readlines()[0]`` would.
        """
        with open(caption_path, 'r') as fr:
            first_line = fr.readline()
        if first_line == '':
            return default
        return self.process_text(first_line, replace_mapping)

    def process_text(self, instance_prompt, replace_mapping):
        """Apply placeholder replacements, strip, and collapse runs of spaces."""
        for k, v in replace_mapping.items():
            instance_prompt = instance_prompt.replace(k, v)
        instance_prompt = instance_prompt.strip()
        instance_prompt = re.sub(' +', ' ', instance_prompt)
        return instance_prompt

    def __len__(self):
        return self.num_instance_images * self.opt['dataset_enlarge_ratio']

    def __getitem__(self, index):
        """Load, transform and package the sample at ``index``.

        The index wraps modulo the number of real images, which is how the
        enlarged length reported by ``__len__`` is served.

        Raises:
            NotImplementedError: if the transform pipeline did not produce an
                ``img_mask`` entry.
        """
        image_path, instance_prompt, mask_path = self.instance_images_path[index % self.num_instance_images]

        # convert() returns a new, fully loaded image, so the file handle can
        # be released right away.
        with Image.open(image_path) as raw_image:
            instance_image = raw_image.convert('RGB')

        extra_args = {'prompts': instance_prompt}
        if mask_path is not None:
            with Image.open(mask_path) as raw_mask:
                extra_args['mask'] = raw_mask.convert('L')

        instance_image, extra_args = self.instance_transform(instance_image, **extra_args)

        example = {'images': instance_image}
        if 'mask' in extra_args:
            example['masks'] = extra_args['mask'].unsqueeze(0)

        if 'img_mask' not in extra_args:
            raise NotImplementedError(
                "instance_transform must produce an 'img_mask' entry")
        example['img_masks'] = extra_args['img_mask'].unsqueeze(0)

        example['prompts'] = extra_args['prompts']
        return example
mixofshow/data/pil_transform.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import random
3
+ from copy import deepcopy
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision.transforms.functional as F
10
+ from PIL import Image
11
+ from torchvision.transforms import CenterCrop, Normalize, RandomCrop, RandomHorizontalFlip, Resize
12
+ from torchvision.transforms.functional import InterpolationMode
13
+
14
+ from mixofshow.utils.registry import TRANSFORM_REGISTRY
15
+
16
+
17
def build_transform(opt):
    """Instantiate a registered transform from its configuration.

    Args:
        opt (dict): Configuration. ``type`` names the registered transform;
            all remaining entries are passed to its constructor as keyword
            arguments. The caller's dict is not modified.
    """
    cfg = deepcopy(opt)
    transform_cls = TRANSFORM_REGISTRY.get(cfg.pop('type'))
    return transform_cls(**cfg)
26
+
27
+
28
# Expose the stock torchvision transforms under their own class names.
for _torchvision_transform in (Normalize, Resize, RandomHorizontalFlip, CenterCrop, RandomCrop):
    TRANSFORM_REGISTRY.register(_torchvision_transform)
33
+
34
+
35
@TRANSFORM_REGISTRY.register()
class BILINEARResize(Resize):
    """``Resize`` with the interpolation mode fixed to bilinear."""

    def __init__(self, size):
        super().__init__(size, interpolation=InterpolationMode.BILINEAR)
40
+
41
+
42
@TRANSFORM_REGISTRY.register()
class PairRandomCrop(nn.Module):
    """Randomly crop an image and, when present, its mask at the same window.

    Args:
        size (int | sequence): Crop size; an int gives a square crop,
            otherwise ``(height, width)``.
    """

    def __init__(self, size):
        super().__init__()
        if isinstance(size, int):
            self.height, self.width = size, size
        else:
            self.height, self.width = size

    def forward(self, img, **kwargs):
        """Crop ``img`` and ``kwargs['mask']`` (if given) with one random offset.

        Raises:
            ValueError: if the image is smaller than the crop, or the mask
                size differs from the image size.
        """
        # NOTE(review): assumes PIL-style inputs whose .size is (width, height).
        img_width, img_height = img.size
        mask = kwargs.get('mask')

        if img_height < self.height or img_width < self.width:
            raise ValueError(
                f'image of size {(img_height, img_width)} is smaller than the crop '
                f'size {(self.height, self.width)}')
        if mask is not None:
            mask_width, mask_height = mask.size
            if (mask_width, mask_height) != (img_width, img_height):
                raise ValueError('image and mask must have the same size')

        x = random.randint(0, img_width - self.width)
        y = random.randint(0, img_height - self.height)
        img = F.crop(img, y, x, self.height, self.width)
        # The mask is optional: the dataset only supplies it when masks are enabled.
        if mask is not None:
            kwargs['mask'] = F.crop(mask, y, x, self.height, self.width)
        return img, kwargs
63
+
64
+
65
@TRANSFORM_REGISTRY.register()
class ToTensor(nn.Module):
    """Module wrapper around ``torchvision.transforms.functional.to_tensor``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, pic):
        tensor = F.to_tensor(pic)
        return tensor

    def __repr__(self) -> str:
        return self.__class__.__name__ + '()'
75
+
76
+
77
@TRANSFORM_REGISTRY.register()
class PairRandomHorizontalFlip(torch.nn.Module):
    """Flip an image and, when present, its mask horizontally with probability ``p``."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img, **kwargs):
        """Return ``(img, kwargs)``, both flipped together or both untouched."""
        if torch.rand(1) < self.p:
            # The mask is optional: the dataset only supplies it when masks are enabled.
            if 'mask' in kwargs:
                kwargs['mask'] = F.hflip(kwargs['mask'])
            return F.hflip(img), kwargs
        return img, kwargs
88
+
89
+
90
@TRANSFORM_REGISTRY.register()
class PairResize(nn.Module):
    """Resize an image and, when present, its mask with the same ``Resize``."""

    def __init__(self, size):
        super().__init__()
        self.resize = Resize(size=size)

    def forward(self, img, **kwargs):
        """Return the resized image and ``kwargs`` with a resized ``mask`` if given."""
        # The mask is optional: the dataset only supplies it when masks are enabled.
        if 'mask' in kwargs:
            kwargs['mask'] = self.resize(kwargs['mask'])
        img = self.resize(img)
        return img, kwargs
100
+
101
+
102
class PairCompose(nn.Module):
    """Chain transforms, passing keyword extras only to those that accept them.

    A transform whose ``forward`` takes a single parameter is called with the
    image alone; any other transform is called with the image plus the current
    keyword extras and must return ``(img, kwargs)``.
    """

    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def __call__(self, img, **kwargs):
        for transform in self.transforms:
            # The signature of a bound method does not include ``self``.
            arity = len(inspect.signature(transform.forward).parameters)
            if arity == 1:
                img = transform(img)
                continue
            img, kwargs = transform(img, **kwargs)
        return img, kwargs

    def __repr__(self) -> str:
        parts = [self.__class__.__name__ + '(']
        parts.extend(f' {t}' for t in self.transforms)
        parts.append(')')
        return '\n'.join(parts)
123
+
124
+
125
@TRANSFORM_REGISTRY.register()
class HumanResizeCropFinalV3(nn.Module):
    """Resize, optionally crop, then paste an image at a random spot on a square canvas.

    The output image is a ``size`` x ``size`` canvas. ``kwargs['img_mask']``
    marks where the image was pasted and ``kwargs['mask']`` (if a mask was
    given) is pasted at the same spot; both are returned as tensors at 1/8 of
    the canvas resolution.

    Args:
        size (int): Canvas edge length.
        crop_p (float): Probability of applying the random crop step.
    """

    def __init__(self, size, crop_p=0.5):
        super().__init__()
        self.size = size
        self.crop_p = crop_p
        self.random_crop = RandomCrop(size=size)
        self.paired_random_crop = PairRandomCrop(size=size)

    def _downsample_mask(self, mask):
        """Resize a canvas-sized array to 1/8 resolution with nearest-neighbour."""
        target = (self.size // 8, self.size // 8)
        # interpolation must be passed by keyword: the third positional
        # argument of cv2.resize is ``dst``, not the interpolation flag.
        return torch.from_numpy(cv2.resize(mask, target, interpolation=cv2.INTER_NEAREST))

    def forward(self, img, **kwargs):
        # NOTE(review): assumes PIL inputs (``.size`` is (width, height)).
        has_mask = 'mask' in kwargs

        # step 1: short edge resize to self.size
        img = F.resize(img, size=self.size)
        if has_mask:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size)

        # step 2: random crop
        width, height = img.size
        if random.random() < self.crop_p:
            if height > width:
                # Portrait: keep the top, cutting off a random amount of the bottom.
                crop_pos = random.randint(0, height - width)
                img = F.crop(img, 0, 0, width + crop_pos, width)
                if has_mask:
                    kwargs['mask'] = F.crop(kwargs['mask'], 0, 0, width + crop_pos, width)
            elif has_mask:
                img, kwargs = self.paired_random_crop(img, **kwargs)
            else:
                img = self.random_crop(img)

        # step 3: long edge resize so the result fits inside the canvas
        img = F.resize(img, size=self.size - 1, max_size=self.size)
        if has_mask:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size)

        new_width, new_height = img.size

        img = np.array(img)
        if has_mask:
            kwargs['mask'] = np.array(kwargs['mask']) / 255
            # Guard against the image and mask differing in size after resizing.
            new_width = min(new_width, kwargs['mask'].shape[1])
            new_height = min(new_height, kwargs['mask'].shape[0])

        # Random placement on the canvas (previously hard-coded to 512).
        start_y = random.randint(0, self.size - new_height)
        start_x = random.randint(0, self.size - new_width)

        res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8)
        res_mask = np.zeros((self.size, self.size))
        res_img_mask = np.zeros((self.size, self.size))

        res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img[:new_height, :new_width]
        if has_mask:
            res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask'][:new_height, :new_width]
            kwargs['mask'] = res_mask

        res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1
        kwargs['img_mask'] = res_img_mask

        img = Image.fromarray(res_img)

        if has_mask:
            kwargs['mask'] = self._downsample_mask(kwargs['mask'])
        kwargs['img_mask'] = self._downsample_mask(kwargs['img_mask'])
        return img, kwargs
192
+
193
+
194
@TRANSFORM_REGISTRY.register()
class ResizeFillMaskNew(nn.Module):
    """Resize, optionally crop, rescale, then paste an image on a square canvas.

    The output image is a ``size`` x ``size`` canvas. ``kwargs['img_mask']``
    marks where the image was pasted and ``kwargs['mask']`` (if a mask was
    given) is pasted at the same spot; both are returned as tensors at 1/8 of
    the canvas resolution.

    Args:
        size (int): Canvas edge length.
        crop_p (float): Probability of random cropping instead of long-edge resizing.
        scale_ratio (sequence): ``(low, high)`` bounds of the random rescale factor.
            NOTE(review): a factor above 1 makes the image larger than the
            canvas and the placement step fail — confirm configs keep it <= 1.
    """

    def __init__(self, size, crop_p, scale_ratio):
        super().__init__()
        self.size = size
        self.crop_p = crop_p
        self.scale_ratio = scale_ratio
        self.random_crop = RandomCrop(size=size)
        self.paired_random_crop = PairRandomCrop(size=size)

    def _downsample_mask(self, mask):
        """Resize a canvas-sized array to 1/8 resolution with nearest-neighbour."""
        target = (self.size // 8, self.size // 8)
        # interpolation must be passed by keyword: the third positional
        # argument of cv2.resize is ``dst``, not the interpolation flag.
        return torch.from_numpy(cv2.resize(mask, target, interpolation=cv2.INTER_NEAREST))

    def forward(self, img, **kwargs):
        # NOTE(review): assumes PIL inputs (``.size`` is (width, height)).
        has_mask = 'mask' in kwargs

        # step 1: short edge resize to self.size
        img = F.resize(img, size=self.size)
        if has_mask:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size)

        # step 2: random crop, or long edge resize
        if random.random() < self.crop_p:
            if has_mask:
                img, kwargs = self.paired_random_crop(img, **kwargs)
            else:
                img = self.random_crop(img)
        else:
            img = F.resize(img, size=self.size - 1, max_size=self.size)
            if has_mask:
                kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size)

        # step 3: random rescale (both edges use the same factor)
        width, height = img.size
        ratio = random.uniform(*self.scale_ratio)
        scaled_hw = (int(height * ratio), int(width * ratio))

        img = F.resize(img, size=scaled_hw)
        if has_mask:
            # InterpolationMode.NEAREST replaces the deprecated integer code 0.
            kwargs['mask'] = F.resize(kwargs['mask'], size=scaled_hw, interpolation=InterpolationMode.NEAREST)

        # step 4: random placement on the canvas (previously hard-coded to 512)
        new_width, new_height = img.size

        img = np.array(img)
        if has_mask:
            kwargs['mask'] = np.array(kwargs['mask']) / 255

        start_y = random.randint(0, self.size - new_height)
        start_x = random.randint(0, self.size - new_width)

        res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8)
        res_mask = np.zeros((self.size, self.size))
        res_img_mask = np.zeros((self.size, self.size))

        res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img
        if has_mask:
            res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask']
            kwargs['mask'] = res_mask

        res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1
        kwargs['img_mask'] = res_img_mask

        img = Image.fromarray(res_img)

        if has_mask:
            kwargs['mask'] = self._downsample_mask(kwargs['mask'])
        kwargs['img_mask'] = self._downsample_mask(kwargs['img_mask'])

        return img, kwargs
263
+
264
+
265
@TRANSFORM_REGISTRY.register()
class ShuffleCaption(nn.Module):
    """Shuffle the comma-separated parts of the prompt, keeping a fixed prefix.

    Args:
        keep_token_num (int): Number of leading comma-separated parts that
            stay in place; values <= 0 shuffle everything.
    """

    def __init__(self, keep_token_num):
        super().__init__()
        self.keep_token_num = keep_token_num

    def forward(self, img, **kwargs):
        parts = [piece.strip() for piece in kwargs['prompts'].strip().split(',')]

        split_at = self.keep_token_num if self.keep_token_num > 0 else 0
        pinned, shuffled = parts[:split_at], parts[split_at:]

        random.shuffle(shuffled)
        kwargs['prompts'] = ', '.join(pinned + shuffled)
        return img, kwargs
284
+
285
+
286
@TRANSFORM_REGISTRY.register()
class EnhanceText(nn.Module):
    """Wrap the prompt in a randomly chosen sentence template.

    Args:
        enhance_type (str): One of ``'object'``, ``'style'`` or ``'human'``;
            selects the template family. Any other value raises
            ``NotImplementedError``.
    """

    def __init__(self, enhance_type='object'):
        super().__init__()
        style_templates = [
            'a painting in the style of {}',
            'a rendering in the style of {}',
            'a cropped painting in the style of {}',
            'the painting in the style of {}',
            'a clean painting in the style of {}',
            'a dirty painting in the style of {}',
            'a dark painting in the style of {}',
            'a picture in the style of {}',
            'a cool painting in the style of {}',
            'a close-up painting in the style of {}',
            'a bright painting in the style of {}',
            'a cropped painting in the style of {}',
            'a good painting in the style of {}',
            'a close-up painting in the style of {}',
            'a rendition in the style of {}',
            'a nice painting in the style of {}',
            'a small painting in the style of {}',
            'a weird painting in the style of {}',
            'a large painting in the style of {}',
        ]

        object_templates = [
            'a photo of a {}',
            'a rendering of a {}',
            'a cropped photo of the {}',
            'the photo of a {}',
            'a photo of a clean {}',
            'a photo of a dirty {}',
            'a dark photo of the {}',
            'a photo of my {}',
            'a photo of the cool {}',
            'a close-up photo of a {}',
            'a bright photo of the {}',
            'a cropped photo of a {}',
            'a photo of the {}',
            'a good photo of the {}',
            'a photo of one {}',
            'a close-up photo of the {}',
            'a rendition of the {}',
            'a photo of the clean {}',
            'a rendition of a {}',
            'a photo of a nice {}',
            'a good photo of a {}',
            'a photo of the nice {}',
            'a photo of the small {}',
            'a photo of the weird {}',
            'a photo of the large {}',
            'a photo of a cool {}',
            'a photo of a small {}',
        ]

        human_templates = [
            'a photo of a {}', 'a photo of one {}', 'a photo of the {}',
            'the photo of a {}', 'a rendering of a {}',
            'a rendition of the {}', 'a rendition of a {}',
            'a cropped photo of the {}', 'a cropped photo of a {}',
            'a bad photo of the {}', 'a bad photo of a {}',
            'a photo of a weird {}', 'a weird photo of a {}',
            'a bright photo of the {}', 'a good photo of the {}',
            'a photo of a nice {}', 'a good photo of a {}',
            'a photo of a cool {}', 'a bright photo of the {}'
        ]

        families = (
            ('object', object_templates),
            ('style', style_templates),
            ('human', human_templates),
        )
        for family_name, family_templates in families:
            if enhance_type == family_name:
                self.templates = family_templates
                break
        else:
            raise NotImplementedError

    def forward(self, img, **kwargs):
        template = random.choice(self.templates)
        kwargs['prompts'] = template.format(kwargs['prompts'].strip())
        return img, kwargs
mixofshow/data/prompt_dataset.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import re
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+
8
+
9
class PromptDataset(Dataset):
    """A simple dataset to prepare the prompts to generate class images on multiple GPUs.

    ``opt`` must provide ``prompts`` (a list of prompts or the path of a text
    file with one prompt per line), ``num_samples_per_prompt`` and
    ``latent_size``. Optional keys are ``replace_mapping`` and
    ``share_latent_across_prompt`` (default ``True``).
    """

    def __init__(self, opt):
        self.opt = opt

        prompts = opt['prompts']
        if isinstance(prompts, list):
            pass
        elif os.path.exists(prompts):
            # is file
            with open(prompts, 'r') as fr:
                prompts = [item.strip() for item in fr.readlines()]
        else:
            raise ValueError(
                'prompts should be a prompt file path or prompt list, please check!'
            )

        self.prompts = self.replace_placeholder(prompts)

        self.num_samples_per_prompt = opt['num_samples_per_prompt']
        # Sample index is the outer loop: every prompt is listed once per index.
        self.prompts_to_generate = [
            (p, i) for i in range(1, self.num_samples_per_prompt + 1)
            for p in self.prompts
        ]
        self.latent_size = opt['latent_size']  # (4,64,64)
        self.share_latent_across_prompt = opt.get('share_latent_across_prompt', True)  # (true, false)

    def replace_placeholder(self, prompts):
        """Return ``prompts`` with placeholders replaced and whitespace normalized.

        Blank prompts are dropped. The ``prompts`` argument is what gets
        processed; the original ignored it and read ``self.prompts`` instead.
        """
        replace_mapping = self.opt.get('replace_mapping', {})
        new_lines = []
        for line in prompts:
            if len(line.strip()) == 0:
                continue
            for k, v in replace_mapping.items():
                line = line.replace(k, v)
            line = line.strip()
            line = re.sub(' +', ' ', line)
            new_lines.append(line)
        return new_lines

    def __len__(self):
        return len(self.prompts_to_generate)

    def __getitem__(self, index):
        """Return the prompt, its sample index and a seeded noise tensor."""
        prompt, indice = self.prompts_to_generate[index]
        example = {}
        example['prompts'] = prompt
        example['indices'] = indice
        if self.share_latent_across_prompt:
            # Same sample index -> same noise for every prompt.
            seed = indice
        else:
            seed = random.randint(0, 1000)
        # A private generator gives the same values as seeding the default one,
        # without reseeding the process-wide torch RNG on every sample.
        generator = torch.Generator().manual_seed(seed)
        example['latents'] = torch.randn(self.latent_size, generator=generator)
        return example
mixofshow/models/__pycache__/edlora.cpython-310.pyc ADDED
Binary file (6.96 kB). View file
 
mixofshow/models/__pycache__/edlora.cpython-38.pyc ADDED
Binary file (6.96 kB). View file
 
mixofshow/models/__pycache__/edlora.cpython-39.pyc ADDED
Binary file (6.95 kB). View file
 
mixofshow/models/edlora.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.models.attention_processor import AttnProcessor
7
+ from diffusers.utils.import_utils import is_xformers_available
8
+
9
+ if is_xformers_available():
10
+ import xformers
11
+
12
+
13
def remove_edlora_unet_attention_forward(unet):
    """Reset every ``attn2`` Attention layer under ``unet`` to the stock ``AttnProcessor``.

    The module tree is walked recursively; a matched layer is not descended into.
    """
    def _reset(module):  # omit processor in new diffusers
        for child_name, child in module.named_children():
            matched = child.__class__.__name__ == 'Attention' and child_name == 'attn2'
            if not matched:
                _reset(child)
                continue
            child.set_processor(AttnProcessor())

    _reset(unet)
21
+
22
+
23
class EDLoRA_Control_AttnProcessor:
    r"""
    Attention processor that picks this layer's slice of a multi-layer text
    embedding and passes the attention probabilities through a controller.

    Args:
        cross_attention_idx: Index used to select this layer's embedding when
            ``encoder_hidden_states`` is 4-D.
        place_in_unet: Label forwarded to the controller.
        controller: Callable invoked as
            ``controller(attention_probs, is_cross, place_in_unet)``.
        attention_op: Stored but not used by ``__call__``.
    """
    def __init__(self, cross_attention_idx, place_in_unet, controller, attention_op=None):
        self.cross_attention_idx = cross_attention_idx
        self.place_in_unet = place_in_unet
        self.controller = controller
        self.attention_op = attention_op

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial dimensions into a sequence dimension.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        is_cross = encoder_hidden_states is not None
        if not is_cross:
            encoder_hidden_states = hidden_states
        elif len(encoder_hidden_states.shape) == 4:  # multi-layer embedding
            encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]
        # otherwise: single layer embedding, used as given

        assert not attn.norm_cross

        batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states)).contiguous()
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states)).contiguous()
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states)).contiguous()

        if is_xformers_available() and not is_cross:
            # This path never computes explicit probabilities, so the
            # controller is not invoked here.
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states / attn.rescale_output_factor
102
+
103
+
104
class EDLoRA_AttnProcessor:
    """Attention processor that picks this layer's slice of a multi-layer text embedding.

    Args:
        cross_attention_idx: Index used to select this layer's embedding when
            ``encoder_hidden_states`` is 4-D.
        attention_op: Stored but not used by ``__call__``.
    """
    def __init__(self, cross_attention_idx, attention_op=None):
        self.attention_op = attention_op
        self.cross_attention_idx = cross_attention_idx

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the spatial dimensions into a sequence dimension.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif len(encoder_hidden_states.shape) == 4:  # multi-layer embedding
            encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]
        # otherwise: single layer embedding, used as given

        assert not attn.norm_cross

        batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states)).contiguous()
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states)).contiguous()
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states)).contiguous()

        if is_xformers_available():
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states / attn.rescale_output_factor
175
+
176
+
177
def revise_edlora_unet_attention_forward(unet):
    """Install an ``EDLoRA_AttnProcessor`` on every ``attn2`` Attention layer of ``unet``.

    Layers are numbered consecutively across ``down_blocks``, ``mid_block``
    and ``up_blocks``; the number becomes the processor's ``cross_attention_idx``.
    """
    def _register(module, next_idx):
        for child_name, child in module.named_children():
            if child.__class__.__name__ != 'Attention' or 'attn2' not in child_name:
                next_idx = _register(child, next_idx)
                continue
            child.set_processor(EDLoRA_AttnProcessor(next_idx))
            next_idx += 1
        return next_idx

    # use this to ensure the order
    cross_attention_idx = 0
    for block_attr in ('down_blocks', 'mid_block', 'up_blocks'):
        cross_attention_idx = _register(getattr(unet, block_attr), cross_attention_idx)
    print(f'Number of attention layer registered {cross_attention_idx}')
192
+
193
+
194
def revise_edlora_unet_attention_controller_forward(unet, controller):
    """Install an ``EDLoRA_Control_AttnProcessor`` on every ``attn2`` Attention layer.

    Layers are numbered consecutively across ``down_blocks``, ``mid_block``
    and ``up_blocks`` and tagged ``'down'``, ``'mid'`` or ``'up'``. The final
    count is stored on ``controller.num_att_layers``. When ``controller`` is
    ``None`` a pass-through controller is used.
    """
    class DummyController:
        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def _register(module, next_idx, place_in_unet):
        for child_name, child in module.named_children():
            # only register controller for cross-attention
            if child.__class__.__name__ != 'Attention' or 'attn2' not in child_name:
                next_idx = _register(child, next_idx, place_in_unet)
                continue
            child.set_processor(EDLoRA_Control_AttnProcessor(next_idx, place_in_unet, controller))
            next_idx += 1
        return next_idx

    # use this to ensure the order
    cross_attention_idx = 0
    for block_attr, place in (('down_blocks', 'down'), ('mid_block', 'mid'), ('up_blocks', 'up')):
        cross_attention_idx = _register(getattr(unet, block_attr), cross_attention_idx, place)
    print(f'Number of attention layer registered {cross_attention_idx}')
    controller.num_att_layers = cross_attention_idx
220
+
221
+
222
class LoRALinearLayer(nn.Module):
    """LoRA adapter with a frozen down-projection taken from a stored matrix.

    The down-projection is initialised from randomly chosen rows of
    ``orthogonal_mats/<in_features>.npy`` (halved) and frozen; only the
    zero-initialised up-projection stays trainable. Constructing the layer
    replaces ``original_module.forward`` so the wrapped module returns its
    original output plus ``alpha * lora_up(lora_down(x))``.

    Args:
        name: Identifier stored on the layer.
        original_module: ``nn.Linear`` or ``nn.Conv2d`` to wrap.
        rank: Currently ignored; the rank is hard-coded to 32 below.
        alpha: Scale of the LoRA branch.
    """

    def __init__(self, name, original_module, rank=4, alpha=1):
        super().__init__()

        self.name = name

        ### Hard coded LoRA rank
        rank = 32

        is_conv = original_module.__class__.__name__ == 'Conv2d'
        if is_conv:
            # Use the input channel count as ``in_features`` so the shared
            # initialisation below works; the original left it undefined here
            # and raised NameError.
            in_features, out_features = original_module.in_channels, original_module.out_channels
            self.lora_down = torch.nn.Conv2d(in_features, rank, (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(rank, out_features, (1, 1), bias=False)
        else:
            in_features, out_features = original_module.in_features, original_module.out_features
            self.lora_down = nn.Linear(in_features, rank, bias=False)
            self.lora_up = nn.Linear(rank, out_features, bias=False)

        self.register_buffer('alpha', torch.tensor(alpha))

        ### Load and initialize orthogonal B
        # NOTE(review): the path is relative to the current working directory.
        m = np.load(f"orthogonal_mats/{in_features}.npy")
        idxs = np.random.choice(in_features, size=rank, replace=False)
        down_weight = torch.tensor(m[idxs] / 2, dtype=self.lora_down.weight.dtype)
        if is_conv:
            # A 1x1 conv weight is 4-D; give the selected rows that layout.
            down_weight = down_weight.reshape(self.lora_down.weight.shape)
        with torch.no_grad():
            self.lora_down.weight = torch.nn.Parameter(down_weight)

        torch.nn.init.zeros_(self.lora_up.weight)

        # The down-projection stays fixed; only lora_up is trained.
        for param in self.lora_down.parameters():
            param.requires_grad = False

        self.original_forward = original_module.forward
        original_module.forward = self.forward

    def forward(self, hidden_states):
        hidden_states = self.original_forward(hidden_states) + self.alpha * self.lora_up(self.lora_down(hidden_states))
        return hidden_states
mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-310.pyc ADDED
Binary file (8.81 kB). View file
 
mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-38.pyc ADDED
Binary file (8.69 kB). View file
 
mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-39.pyc ADDED
Binary file (8.7 kB). View file
 
mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-310.pyc ADDED
Binary file (19.1 kB). View file
 
mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-38.pyc ADDED
Binary file (19 kB). View file
 
mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-39.pyc ADDED
Binary file (19 kB). View file
 
mixofshow/pipelines/__pycache__/trainer_edlora.cpython-38.pyc ADDED
Binary file (10.9 kB). View file
 
mixofshow/pipelines/__pycache__/trainer_edlora.cpython-39.pyc ADDED
Binary file (10.9 kB). View file
 
mixofshow/pipelines/pipeline_edlora.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ from diffusers import StableDiffusionPipeline
5
+ from diffusers.configuration_utils import FrozenDict
6
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
7
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
8
+ from diffusers.schedulers import KarrasDiffusionSchedulers
9
+ from diffusers.utils import deprecate
10
+ from einops import rearrange
11
+ from packaging import version
12
+ from transformers import CLIPTextModel, CLIPTokenizer
13
+
14
+ from mixofshow.models.edlora import (revise_edlora_unet_attention_controller_forward,
15
+ revise_edlora_unet_attention_forward)
16
+
17
+
18
def bind_concept_prompt(prompts, new_concept_cfg):
    """Expand every prompt into 16 variants with concept names swapped for per-variant tokens.

    Each prompt is copied 16 times. For every concept in ``new_concept_cfg``
    the i-th copy has ``concept_name`` replaced by the i-th entry of that
    concept's ``'concept_token_names'`` list.

    Args:
        prompts: a single prompt string or a list of prompt strings.
        new_concept_cfg: mapping ``concept_name -> cfg`` where
            ``cfg['concept_token_names']`` is a sequence of replacement tokens.

    Returns:
        A flat list holding the variants of the first prompt, then the
        variants of the second prompt, and so on.

    Note:
        The pairing uses ``zip``, so a ``concept_token_names`` list shorter
        than 16 silently shortens the variants produced for that prompt.
    """
    prompt_list = [prompts] if isinstance(prompts, str) else prompts

    expanded = []
    for text in prompt_list:
        variants = [text for _ in range(16)]
        for concept_name, token_cfg in new_concept_cfg.items():
            layer_tokens = token_cfg['concept_token_names']
            variants = [
                variant.replace(concept_name, layer_token)
                for variant, layer_token in zip(variants, layer_tokens)
            ]
        expanded += variants
    return expanded
30
+
31
+
32
class EDLoRAPipeline(StableDiffusionPipeline):
    """Stable Diffusion text-to-image pipeline driven by layer-wise prompt embeddings.

    Prompts are expanded with ``bind_concept_prompt`` using the configuration
    stored by ``set_new_concept_cfg`` and encoded into a 4-D tensor
    ``(batch, layer_num, seq_len, dim)`` that is passed to the UNet as
    ``encoder_hidden_states``.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker: bool = False,
    ):
        """Patch outdated configs, revise the UNet attention forward and register the modules.

        ``safety_checker``, ``feature_extractor`` and ``requires_safety_checker``
        are accepted for signature compatibility but are not used or
        registered by this constructor.
        """
        # Outdated scheduler config: force `steps_offset` to 1 (with a deprecation notice).
        if hasattr(scheduler.config, 'steps_offset') and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f'The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`'
                f' should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure '
                'to update the config accordingly as leaving `steps_offset` might led to incorrect results'
                ' in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,'
                ' it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`'
                ' file'
            )
            deprecate('steps_offset!=1', '1.0.0', deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config['steps_offset'] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Outdated scheduler config: force `clip_sample` to False.
        if hasattr(scheduler.config, 'clip_sample') and scheduler.config.clip_sample is True:
            deprecation_message = (
                f'The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.'
                ' `clip_sample` should be set to False in the configuration file. Please make sure to update the'
                ' config accordingly as not setting `clip_sample` in the config might lead to incorrect results in'
                ' future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very'
                ' nice if you could open a Pull request for the `scheduler/scheduler_config.json` file'
            )
            deprecate('clip_sample not set', '1.0.0', deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config['clip_sample'] = False
            scheduler._internal_dict = FrozenDict(new_config)

        # Old UNet checkpoints (< 0.9.0) that declare sample_size < 64 get it bumped to 64.
        is_unet_version_less_0_9_0 = hasattr(unet.config, '_diffusers_version') and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse('0.9.0.dev0')
        is_unet_sample_size_less_64 = hasattr(unet.config, 'sample_size') and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                'The configuration file of the unet has set the default `sample_size` to smaller than'
                ' 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the'
                ' following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-'
                ' CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5'
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                ' configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`'
                ' in the config might lead to incorrect results in future versions. If you have downloaded this'
                ' checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for'
                ' the `unet/config.json` file'
            )
            deprecate('sample_size<64', '1.0.0', deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config['sample_size'] = 64
            unet._internal_dict = FrozenDict(new_config)

        # NOTE(review): presumably installs attention forwards that understand the
        # 4-D layer-wise embeddings -- confirm in mixofshow.models.edlora.
        revise_edlora_unet_attention_forward(unet)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.new_concept_cfg = None

    def set_new_concept_cfg(self, new_concept_cfg=None):
        """Store the concept configuration used to expand prompts; required before ``__call__``."""
        self.new_concept_cfg = new_concept_cfg

    def set_controller(self, controller):
        """Attach ``controller`` and revise the UNet attention forward to use it.

        Once set, ``controller.step_callback(latents)`` is applied after every
        scheduler step in ``__call__``.
        """
        self.controller = controller
        revise_edlora_unet_attention_controller_forward(self.unet, controller)

    def _encode_prompt(self,
                       prompt,
                       new_concept_cfg,
                       device,
                       num_images_per_prompt,
                       do_classifier_free_guidance,
                       negative_prompt=None,
                       prompt_embeds: Optional[torch.FloatTensor] = None,
                       negative_prompt_embeds: Optional[torch.FloatTensor] = None
                       ):
        """Encode ``prompt`` into layer-wise text embeddings.

        Args:
            prompt: prompt string or list of prompt strings; may be ``None``
                when ``prompt_embeds`` is given.
            new_concept_cfg: concept configuration forwarded to
                ``bind_concept_prompt``.
            device: device the token ids and embeddings are moved to.
            num_images_per_prompt: must be 1.
            do_classifier_free_guidance: when true, unconditional embeddings
                are prepended along the batch dimension.
            negative_prompt: optional negative prompt(s); defaults to empty
                strings.
            prompt_embeds: optional precomputed embeddings; unpacked as a 4-D
                tensor ``(batch, layer_num, seq_len, dim)``.
            negative_prompt_embeds: optional precomputed negative embeddings;
                reshaped to ``(batch, 1, seq_len, dim)`` and repeated over
                ``layer_num``.

        Returns:
            A 4-D tensor ``(batch, layer_num, seq_len, dim)``, or with
            guidance ``(2 * batch, ...)`` with the negative half first.

        Raises:
            AssertionError: if ``num_images_per_prompt`` is not 1.
            TypeError: if ``negative_prompt`` and ``prompt`` differ in type.
            ValueError: if a ``negative_prompt`` list does not match the batch size.
        """

        assert num_images_per_prompt == 1, 'only support num_images_per_prompt=1 now'

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:

            # One text variant per layer for every prompt (see bind_concept_prompt).
            prompt_extend = bind_concept_prompt(prompt, new_concept_cfg)

            text_inputs = self.tokenizer(
                prompt_extend,
                padding='max_length',
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors='pt',
            )
            text_input_ids = text_inputs.input_ids

            prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
            # Regroup the flat (batch * layers) axis into separate batch and layer axes.
            prompt_embeds = rearrange(prompt_embeds, '(b n) m c -> b n m c', b=batch_size)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, layer_num, seq_len, _ = prompt_embeds.shape

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [''] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(f'`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !='
                                f' {type(prompt)}.')
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:'
                    f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches'
                    ' the batch size of `prompt`.')
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding='max_length',
                max_length=seq_len,
                truncation=True,
                return_tensors='pt',
            )

            negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
            # The same unconditional embedding is shared by every layer.
            negative_prompt_embeds = (negative_prompt_embeds).view(batch_size, 1, seq_len, -1).repeat(1, layer_num, 1, 1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = 'pil',
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Generate images for ``prompt``.

        Args:
            prompt: prompt or list of prompts; may be omitted when
                ``prompt_embeds`` is given.
            height, width: output size in pixels; default to
                ``unet.config.sample_size * vae_scale_factor``.
            num_inference_steps: number of scheduler timesteps requested.
            guidance_scale: classifier-free guidance weight; guidance is
                active only when this is greater than 1.0.
            negative_prompt: optional negative prompt(s).
            num_images_per_prompt: must be 1 (enforced by ``_encode_prompt``).
            eta, generator: forwarded to ``prepare_extra_step_kwargs``.
            latents: optional initial latents forwarded to ``prepare_latents``.
            prompt_embeds, negative_prompt_embeds: optional precomputed
                embeddings, see ``_encode_prompt``.
            output_type: ``'latent'`` returns the raw latents, ``'pil'``
                returns PIL images, anything else returns the output of
                ``decode_latents``.
            return_dict: when false the images object itself is returned.
            callback, callback_steps: ``callback(i, t, latents)`` is invoked
                on progress-bar updates where ``i % callback_steps == 0``.
            cross_attention_kwargs: forwarded to the UNet.

        Returns:
            ``StableDiffusionPipelineOutput`` with ``nsfw_content_detected=None``,
            or, when ``return_dict`` is false, the images object directly
            (``(image)`` is not a tuple).

        Raises:
            AssertionError: if ``set_new_concept_cfg`` has not supplied a
                configuration.
        """

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt, this support pplus and edlora (layer-wise embedding)
        assert self.new_concept_cfg is not None
        prompt_embeds = self._encode_prompt(
            prompt,
            self.new_concept_cfg,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        # Read the channel count from the config: direct `unet.in_channels`
        # attribute access is deprecated in diffusers.
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if hasattr(self, 'controller'):
                    # Let the controller edit the latents, then restore their dtype.
                    dtype = latents.dtype
                    latents = self.controller.step_callback(latents)
                    latents = latents.to(dtype)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == 'latent':
            image = latents
        elif output_type == 'pil':
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

        # Offload last model to CPU
        if hasattr(self, 'final_offload_hook') and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            # Kept as-is for callers: this returns `image` itself, not a 1-tuple.
            return (image)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
mixofshow/pipelines/pipeline_regionally_t2iadapter.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import PIL
5
+ import torch
6
+ from diffusers.image_processor import VaeImageProcessor
7
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
8
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
9
+ from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter import (StableDiffusionAdapterPipeline,
10
+ StableDiffusionAdapterPipelineOutput,
11
+ _preprocess_adapter_image)
12
+ from diffusers.schedulers import KarrasDiffusionSchedulers
13
+ from diffusers.utils import logging
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from einops import rearrange
16
+ from torch import einsum
17
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
18
+
19
+ if is_xformers_available():
20
+ import xformers
21
+
22
+ from mixofshow.pipelines.pipeline_edlora import bind_concept_prompt
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
class RegionT2I_AttnProcessor:
    """Attention processor that recomputes cross-attention inside user-given boxes.

    For cross-attention calls, the ordinary attention output is kept where no
    region applies. Inside each region box it is replaced by attention
    computed against that region's own text embeddings.
    """

    def __init__(self, cross_attention_idx, attention_op=None):
        """Remember which layer slice of a 4-D text embedding this processor reads."""
        self.cross_attention_idx = cross_attention_idx
        self.attention_op = attention_op

    def region_rewrite(self, attn, hidden_states, query, region_list, height, width):
        """Replace the attention output inside every region box.

        Args:
            attn: attention module; ``upcast_attention``, ``upcast_softmax``
                and ``scale`` are read from it.
            hidden_states: attention output, ``(b, h*w, c)``.
            query: query tensor with the same ``(b, h*w, c)`` layout.
            region_list: iterable of ``(region_key, region_value, box)`` where
                ``box`` is ``(start_h, start_w, end_h, end_w)`` given as
                fractions of the feature-map size.
            height, width: image size used to recover the feature-map size
                from the token count.

        Returns:
            Tensor shaped like ``hidden_states``. Cells covered by no region
            keep their value; covered cells hold the region attention output
            averaged over the number of regions covering them.
        """

        def to_feature_box(box, grid_h, grid_w):
            # Shrink the fractional box to whole cells: ceil the start, floor the end.
            top, left, bottom, right = box
            return (math.ceil(top * grid_h), math.ceil(left * grid_w),
                    math.floor(bottom * grid_h), math.floor(right * grid_w))

        out_dtype = query.dtype
        num_tokens = query.shape[1]
        downscale = math.sqrt(height * width / num_tokens)

        feat_height = int(height // downscale)
        feat_width = int(width // downscale)

        # Per-cell count of covering regions: 0 means context only, >1 means overlap.
        coverage = torch.zeros((feat_height, feat_width))
        for entry in region_list:
            y0, x0, y1, x1 = to_feature_box(entry[-1], feat_height, feat_width)
            coverage[y0:y1, x0:x1] += 1

        query = rearrange(query, 'b (h w) c -> b h w c', h=feat_height, w=feat_width)
        hidden_states = rearrange(hidden_states, 'b (h w) c -> b h w c', h=feat_height, w=feat_width)

        replace_ratio = 1.0
        uncovered = coverage == 0
        covered = coverage != 0

        rewritten = torch.zeros_like(hidden_states)
        rewritten[:, uncovered, :] = hidden_states[:, uncovered, :]
        rewritten[:, covered, :] = (1 - replace_ratio) * hidden_states[:, covered, :]

        coverage_4d = coverage.reshape(1, *coverage.shape, 1)
        for region_key, region_value, region_box in region_list:
            if attn.upcast_attention:
                query = query.float()
                region_key = region_key.float()

            y0, x0, y1, x1 = to_feature_box(region_box, feat_height, feat_width)

            region_query = query[:, y0:y1, x0:x1, :]
            scores = einsum('b h w c, b n c -> b h w n', region_query, region_key) * attn.scale
            if attn.upcast_softmax:
                scores = scores.float()

            probs = scores.softmax(dim=-1)
            probs = probs.to(out_dtype)

            region_out = einsum('b h w n, b n c -> b h w c', probs, region_value)
            # Divide by the coverage count so overlapping regions are averaged.
            overlap = (coverage_4d[:, y0:y1, x0:x1, :]).to(query.device)
            rewritten[:, y0:y1, x0:x1, :] += replace_ratio * (region_out / overlap)

        return rearrange(rewritten, 'b h w c -> b (h w) c')

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, **cross_attention_kwargs):
        """Run attention; for cross-attention additionally apply ``region_rewrite``.

        ``cross_attention_kwargs`` must provide ``'region_list'`` (entries
        whose item 0 is a text embedding and item 1 a box), ``'height'`` and
        ``'width'`` whenever ``encoder_hidden_states`` is given.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        is_cross = encoder_hidden_states is not None
        if not is_cross:
            encoder_hidden_states = hidden_states

        if len(encoder_hidden_states.shape) == 4:  # multi-layer embedding
            encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        if is_xformers_available() and not is_cross:
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)

        if is_cross:
            projected_regions = []
            for region in cross_attention_kwargs['region_list']:
                region_embeds, region_box = region[0], region[1]
                if len(region_embeds.shape) == 4:
                    # Layer-wise embedding: pick the slice for this cross-attention layer.
                    region_embeds = region_embeds[:, self.cross_attention_idx, ...]
                region_key = attn.to_k(region_embeds)
                region_value = attn.to_v(region_embeds)
                region_key = attn.head_to_batch_dim(region_key)
                region_value = attn.head_to_batch_dim(region_value)
                projected_regions.append((region_key, region_value, region_box))

            hidden_states = self.region_rewrite(
                attn=attn,
                hidden_states=hidden_states,
                query=query,
                region_list=projected_regions,
                height=cross_attention_kwargs['height'],
                width=cross_attention_kwargs['width'])

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
146
+
147
+
148
def revise_regionally_t2iadapter_attention_forward(unet):
    """Install a ``RegionT2I_AttnProcessor`` on every ``Attention`` layer of ``unet``.

    The module tree is walked in ``down_blocks``, ``mid_block``, ``up_blocks``
    order. Each layer whose class is named ``'Attention'`` receives a
    processor carrying the current cross-attention index; the index advances
    only after layers whose child name contains ``'attn2'``. The final index
    is printed.
    """

    def install(module, next_idx):
        # Depth-first walk; returns the index to hand to the next layer.
        for child_name, child in module.named_children():
            if child.__class__.__name__ != 'Attention':
                next_idx = install(child, next_idx)
                continue
            child.set_processor(RegionT2I_AttnProcessor(next_idx))
            if 'attn2' in child_name:
                next_idx += 1
        return next_idx

    # Fixed section order keeps the indices aligned with the layer order.
    registered = 0
    for section_name in ('down_blocks', 'mid_block', 'up_blocks'):
        registered = install(getattr(unet, section_name), registered)
    print(f'Number of attention layer registered {registered}')
164
+
165
+
166
+ class RegionallyT2IAdapterPipeline(StableDiffusionAdapterPipeline):
167
+ _optional_components = ['safety_checker', 'feature_extractor']
168
+
169
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPFeatureExtractor,
    requires_safety_checker: bool = False,
):
    """Register the pipeline modules and install the region-aware attention processors.

    Args:
        vae, text_encoder, tokenizer, unet, scheduler: pipeline components
            passed to ``register_modules``.
        safety_checker: optional safety checker; may be ``None``.
        feature_extractor: required whenever ``safety_checker`` is given.
        requires_safety_checker: stored in the pipeline config; when true
            and ``safety_checker`` is ``None`` a warning is logged.

    Raises:
        ValueError: if ``safety_checker`` is given without a
            ``feature_extractor``.
    """

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f'You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure'
            ' that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered'
            ' results in services or applications open to the public. Both the diffusers team and Hugging Face'
            ' strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling'
            ' it only for use-cases that involve analyzing network behavior or auditing its results. For more'
            ' information, please have a look at https://github.com/huggingface/diffusers/pull/254 .'
        )

    if safety_checker is not None and feature_extractor is None:
        # The first literal needs the f-prefix, otherwise `{self.__class__}`
        # is printed verbatim instead of the class.
        raise ValueError(
            f'Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety'
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
    )
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.register_to_config(requires_safety_checker=requires_safety_checker)
    self.new_concept_cfg = None
    # Swap every Attention layer's processor for the region-aware one.
    revise_regionally_t2iadapter_attention_forward(self.unet)
211
+
212
def set_new_concept_cfg(self, new_concept_cfg=None):
    """Store the concept configuration on the pipeline.

    Args:
        new_concept_cfg: configuration object kept as
            ``self.new_concept_cfg``; ``None`` clears it.
    """
    self.new_concept_cfg = new_concept_cfg
214
+
215
def _encode_region_prompt(self,
                          prompt,
                          new_concept_cfg,
                          device,
                          num_images_per_prompt,
                          do_classifier_free_guidance,
                          negative_prompt=None,
                          prompt_embeds: Optional[torch.FloatTensor] = None,
                          negative_prompt_embeds: Optional[torch.FloatTensor] = None,
                          height=512,
                          width=512
                          ):
    """Encode the context prompt and every region prompt into layer-wise embeddings.

    Args:
        prompt: list holding one ``(context_prompt, region_list)`` pair, where
            each region is ``(region_prompt, region_neg_prompt, pos)``.
        new_concept_cfg: concept configuration forwarded to
            ``bind_concept_prompt``.
        device: device token ids and embeddings are moved to.
        negative_prompt: negative prompt for the context; defaults to an
            empty string.
        prompt_embeds: optional precomputed 4-D context embeddings
            ``(batch, layer_num, seq_len, dim)``.
        num_images_per_prompt, do_classifier_free_guidance,
        negative_prompt_embeds, height, width: accepted but not used here;
            negative embeddings are always computed and prepended.

    Returns:
        ``(prompt_embeds, region_list)``: ``prompt_embeds`` is the negative
        and positive context embeddings concatenated along dim 0, and
        ``region_list`` is a new list of ``(embeds, pos)`` tuples built the
        same way for each region.

    Raises:
        AssertionError: if the batch size is not 1.
        ValueError: if ``prompt`` is ``None`` (the regions are carried by it).
    """
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    assert batch_size == 1, 'only sample one prompt once in this version'

    if prompt is None:
        # Previously this surfaced later as an UnboundLocalError on `region_list`.
        raise ValueError('`prompt` is required: it carries the region list even when `prompt_embeds` is given.')

    def token_ids(texts):
        # Shared tokenizer settings for context, negative and region prompts.
        return self.tokenizer(
            texts,
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors='pt').input_ids

    target_dtype = self.text_encoder.dtype

    # Read the regions up front so they are available with precomputed `prompt_embeds` too.
    context_prompt, region_list = prompt[0][0], prompt[0][1]

    if prompt_embeds is None:
        context_prompt = bind_concept_prompt([context_prompt], new_concept_cfg)
        context_prompt_input_ids = token_ids(context_prompt)

        prompt_embeds = self.text_encoder(context_prompt_input_ids.to(device), attention_mask=None)[0]
        prompt_embeds = rearrange(prompt_embeds, '(b n) m c -> b n m c', b=batch_size)
        # `Tensor.to` is not in-place: keep the converted tensor.
        prompt_embeds = prompt_embeds.to(dtype=target_dtype, device=device)

    _, layer_num, seq_len, _ = prompt_embeds.shape

    if negative_prompt is None:
        negative_prompt = [''] * batch_size

    negative_prompt_input_ids = token_ids(negative_prompt)

    negative_prompt_embeds = self.text_encoder(
        negative_prompt_input_ids.to(device),
        attention_mask=None,
    )[0]

    # The same negative embedding is shared by every layer.
    negative_prompt_embeds = (negative_prompt_embeds).view(batch_size, 1, seq_len, -1).repeat(1, layer_num, 1, 1)
    negative_prompt_embeds = negative_prompt_embeds.to(dtype=target_dtype, device=device)
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Build a fresh list so the caller's prompt structure is left untouched
    # and can be passed again on a later call.
    encoded_regions = []
    for region in region_list:
        region_prompt, region_neg_prompt, pos = region
        region_prompt = bind_concept_prompt([region_prompt], new_concept_cfg)
        region_prompt_input_ids = token_ids(region_prompt)
        region_embeds = self.text_encoder(region_prompt_input_ids.to(device), attention_mask=None)[0]
        region_embeds = rearrange(region_embeds, '(b n) m c -> b n m c', b=batch_size)
        region_embeds = region_embeds.to(dtype=target_dtype, device=device)
        _, layer_num, seq_len, _ = region_embeds.shape

        if region_neg_prompt is None:
            region_neg_prompt = [''] * batch_size
        region_negprompt_input_ids = token_ids(region_neg_prompt)
        region_neg_embeds = self.text_encoder(region_negprompt_input_ids.to(device), attention_mask=None)[0]
        region_neg_embeds = (region_neg_embeds).view(batch_size, 1, seq_len, -1).repeat(1, layer_num, 1, 1)
        region_neg_embeds = region_neg_embeds.to(dtype=target_dtype, device=device)
        encoded_regions.append((torch.cat([region_neg_embeds, region_embeds]), pos))

    return prompt_embeds, encoded_regions
300
+
301
+ @torch.no_grad()
302
+ def __call__(
303
+ self,
304
+ prompt: Union[str, List[str]] = None,
305
+ keypose_adapter_input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
306
+ keypose_adaptor_weight=1.0,
307
+ region_keypose_adaptor_weight='',
308
+ sketch_adapter_input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
309
+ sketch_adaptor_weight=1.0,
310
+ region_sketch_adaptor_weight='',
311
+ height: Optional[int] = None,
312
+ width: Optional[int] = None,
313
+ num_inference_steps: int = 50,
314
+ guidance_scale: float = 7.5,
315
+ negative_prompt: Optional[Union[str, List[str]]] = None,
316
+ num_images_per_prompt: Optional[int] = 1,
317
+ eta: float = 0.0,
318
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
319
+ latents: Optional[torch.FloatTensor] = None,
320
+ prompt_embeds: Optional[torch.FloatTensor] = None,
321
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
322
+ output_type: Optional[str] = 'pil',
323
+ return_dict: bool = True,
324
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
325
+ callback_steps: int = 1,
326
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
327
+ ):
328
+ r"""
329
+ Function invoked when calling the pipeline for generation.
330
+
331
+ Args:
332
+ prompt (`str` or `List[str]`, *optional*):
333
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
334
+ instead.
335
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
336
+ The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the
337
+ type is specified as `Torch.FloatTensor`, it is passed to Adapter as is. PIL.Image.Image` can also be
338
+ accepted as an image. The control image is automatically resized to fit the output image.
339
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
340
+ The height in pixels of the generated image.
341
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
342
+ The width in pixels of the generated image.
343
+ num_inference_steps (`int`, *optional*, defaults to 50):
344
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
345
+ expense of slower inference.
346
+ guidance_scale (`float`, *optional*, defaults to 7.5):
347
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
348
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
349
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
350
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
351
+ usually at the expense of lower image quality.
352
+ negative_prompt (`str` or `List[str]`, *optional*):
353
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
354
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
355
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
356
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
357
+ The number of images to generate per prompt.
358
+ eta (`float`, *optional*, defaults to 0.0):
359
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
360
+ [`schedulers.DDIMScheduler`], will be ignored for others.
361
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
362
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
363
+ to make generation deterministic.
364
+ latents (`torch.FloatTensor`, *optional*):
365
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
366
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
367
+ tensor will be generated by sampling using the supplied random `generator`.
368
+ prompt_embeds (`torch.FloatTensor`, *optional*):
369
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
370
+ provided, text embeddings will be generated from `prompt` input argument.
371
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
372
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
373
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
374
+ argument.
375
+ output_type (`str`, *optional*, defaults to `"pil"`):
376
+ The output format of the generate image. Choose between
377
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
378
+ return_dict (`bool`, *optional*, defaults to `True`):
379
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead
380
+ of a plain tuple.
381
+ callback (`Callable`, *optional*):
382
+ A function that will be called every `callback_steps` steps during inference. The function will be
383
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
384
+ callback_steps (`int`, *optional*, defaults to 1):
385
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
386
+ called at every step.
387
+ cross_attention_kwargs (`dict`, *optional*):
388
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
389
+ `self.processor` in
390
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
391
+ adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
392
+ The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
393
+ residual in the original unet. If multiple adapters are specified in init, you can set the
394
+ corresponding scale as a list.
395
+
396
+ Examples:
397
+
398
+ Returns:
399
+ [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`:
400
+ [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a
401
+ `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
402
+ element is a list of `bool`s denoting whether the corresponding generated image likely represents
403
+ "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
404
+ """
405
+ # 0. Default height and width to unet
406
+ device = self._execution_device
407
+
408
+ # 1. Check inputs. Raise error if not correct
409
+ self.check_inputs(
410
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
411
+ )
412
+
413
+ if keypose_adapter_input is not None:
414
+ keypose_input = _preprocess_adapter_image(keypose_adapter_input, height, width).to(self.device)
415
+ keypose_input = keypose_input.to(self.keypose_adapter.dtype)
416
+ else:
417
+ keypose_input = None
418
+
419
+ if sketch_adapter_input is not None:
420
+ sketch_input = _preprocess_adapter_image(sketch_adapter_input, height, width).to(self.device)
421
+ sketch_input = sketch_input.to(self.sketch_adapter.dtype)
422
+ else:
423
+ sketch_input = None
424
+
425
+ # 2. Define call parameters
426
+ if prompt is not None and isinstance(prompt, str):
427
+ batch_size = 1
428
+ elif prompt is not None and isinstance(prompt, list):
429
+ batch_size = len(prompt)
430
+ else:
431
+ batch_size = prompt_embeds.shape[0]
432
+
433
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
434
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
435
+ # corresponds to doing no classifier free guidance.
436
+ do_classifier_free_guidance = guidance_scale > 1.0
437
+
438
+ # 3. Encode input prompt
439
+ assert self.new_concept_cfg is not None
440
+ prompt_embeds, region_list = self._encode_region_prompt(
441
+ prompt,
442
+ self.new_concept_cfg,
443
+ device,
444
+ num_images_per_prompt,
445
+ do_classifier_free_guidance,
446
+ negative_prompt,
447
+ prompt_embeds=prompt_embeds,
448
+ negative_prompt_embeds=negative_prompt_embeds,
449
+ height=height,
450
+ width=width
451
+ )
452
+
453
+ # 4. Prepare timesteps
454
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
455
+ timesteps = self.scheduler.timesteps
456
+
457
+ # 5. Prepare latent variables
458
+ num_channels_latents = self.unet.config.in_channels
459
+ latents = self.prepare_latents(
460
+ batch_size * num_images_per_prompt,
461
+ num_channels_latents,
462
+ height,
463
+ width,
464
+ prompt_embeds.dtype,
465
+ device,
466
+ generator,
467
+ latents,
468
+ )
469
+
470
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
471
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
472
+
473
+ # 7. Denoising loop
474
+ if keypose_input is not None:
475
+ keypose_adapter_state = self.keypose_adapter(keypose_input)
476
+ else:
477
+ keypose_adapter_state = None
478
+
479
+ if sketch_input is not None:
480
+ sketch_adapter_state = self.sketch_adapter(sketch_input)
481
+ else:
482
+ sketch_adapter_state = None
483
+
484
+ num_states = len(keypose_adapter_state) if keypose_adapter_state is not None else len(sketch_adapter_state)
485
+
486
+ adapter_state = []
487
+
488
+ for idx in range(num_states):
489
+ if keypose_adapter_state is not None:
490
+ feat_keypose = keypose_adapter_state[idx]
491
+
492
+ spatial_adaptor_weight = keypose_adaptor_weight * torch.ones(*feat_keypose.shape[2:]).to(
493
+ feat_keypose.dtype).to(feat_keypose.device)
494
+
495
+ if region_keypose_adaptor_weight != '':
496
+ region_list = region_keypose_adaptor_weight.split('|')
497
+
498
+ for region_weight in region_list:
499
+ region, weight = region_weight.split('-')
500
+ region = eval(region)
501
+ weight = eval(weight)
502
+ feat_height, feat_width = feat_keypose.shape[2:]
503
+ start_h, start_w, end_h, end_w = region
504
+ start_h, end_h = start_h / height, end_h / height
505
+ start_w, end_w = start_w / width, end_w / width
506
+
507
+ start_h, start_w, end_h, end_w = math.ceil(start_h * feat_height), math.ceil(
508
+ start_w * feat_width), math.floor(end_h * feat_height), math.floor(end_w * feat_width)
509
+
510
+ spatial_adaptor_weight[start_h:end_h, start_w:end_w] = weight
511
+ feat_keypose = spatial_adaptor_weight * feat_keypose
512
+
513
+ else:
514
+ feat_keypose = 0
515
+
516
+ if sketch_adapter_state is not None:
517
+ feat_sketch = sketch_adapter_state[idx]
518
+ # print(feat_keypose.shape) # torch.Size([1, 320, 64, 128])
519
+ spatial_adaptor_weight = sketch_adaptor_weight * torch.ones(*feat_sketch.shape[2:]).to(
520
+ feat_sketch.dtype).to(feat_sketch.device)
521
+
522
+ if region_sketch_adaptor_weight != '':
523
+ region_list = region_sketch_adaptor_weight.split('|')
524
+
525
+ for region_weight in region_list:
526
+ region, weight = region_weight.split('-')
527
+ region = eval(region)
528
+ weight = eval(weight)
529
+ feat_height, feat_width = feat_sketch.shape[2:]
530
+ start_h, start_w, end_h, end_w = region
531
+ start_h, end_h = start_h / height, end_h / height
532
+ start_w, end_w = start_w / width, end_w / width
533
+
534
+ start_h, start_w, end_h, end_w = math.ceil(start_h * feat_height), math.ceil(
535
+ start_w * feat_width), math.floor(end_h * feat_height), math.floor(end_w * feat_width)
536
+
537
+ spatial_adaptor_weight[start_h:end_h, start_w:end_w] = weight
538
+ feat_sketch = spatial_adaptor_weight * feat_sketch
539
+ else:
540
+ feat_sketch = 0
541
+
542
+ adapter_state.append(feat_keypose + feat_sketch)
543
+
544
+ if do_classifier_free_guidance:
545
+ for k, v in enumerate(adapter_state):
546
+ adapter_state[k] = torch.cat([v] * 2, dim=0)
547
+
548
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
549
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
550
+ for i, t in enumerate(timesteps):
551
+ # expand the latents if we are doing classifier free guidance
552
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
553
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
554
+
555
+ # predict the noise residual
556
+ noise_pred = self.unet(
557
+ latent_model_input,
558
+ t,
559
+ encoder_hidden_states=prompt_embeds,
560
+ cross_attention_kwargs={
561
+ 'region_list': region_list,
562
+ 'height': height,
563
+ 'width': width,
564
+ },
565
+ down_block_additional_residuals=[state.clone() for state in adapter_state],
566
+ ).sample
567
+
568
+ # perform guidance
569
+ if do_classifier_free_guidance:
570
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
571
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
572
+
573
+ # compute the previous noisy sample x_t -> x_t-1
574
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
575
+
576
+ # call the callback, if provided
577
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
578
+ progress_bar.update()
579
+ if callback is not None and i % callback_steps == 0:
580
+ callback(i, t, latents)
581
+
582
+ if output_type == 'latent':
583
+ image = latents
584
+ has_nsfw_concept = None
585
+ elif output_type == 'pil':
586
+ # 8. Post-processing
587
+ image = self.decode_latents(latents)
588
+
589
+ # 9. Run safety checker
590
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
591
+
592
+ # 10. Convert to PIL
593
+ image = self.numpy_to_pil(image)
594
+ else:
595
+ # 8. Post-processing
596
+ image = self.decode_latents(latents)
597
+
598
+ # 9. Run safety checker
599
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
600
+
601
+ # Offload last model to CPU
602
+ if hasattr(self, 'final_offload_hook') and self.final_offload_hook is not None:
603
+ self.final_offload_hook.offload()
604
+
605
+ if not return_dict:
606
+ return (image, has_nsfw_concept)
607
+
608
+ return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
mixofshow/pipelines/trainer_edlora.py ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import math
3
+ import re
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from accelerate.logging import get_logger
9
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange
12
+ from transformers import CLIPTextModel, CLIPTokenizer
13
+
14
+ from mixofshow.models.edlora import (LoRALinearLayer, revise_edlora_unet_attention_controller_forward,
15
+ revise_edlora_unet_attention_forward)
16
+ from mixofshow.pipelines.pipeline_edlora import bind_concept_prompt
17
+ from mixofshow.utils.ptp_util import AttentionStore
18
+
19
+
20
class EDLoRATrainer(nn.Module):
    """Training wrapper that tunes new-concept embeddings and LoRA layers.

    The module loads a VAE, CLIP tokenizer/text encoder, UNet and DDPM
    scheduler from ``pretrained_path``, registers new concept tokens in the
    tokenizer, optionally attaches LoRA layers, and computes the denoising
    loss (plus an optional cross-attention regularization term) in
    ``forward``.
    """

    def __init__(
        self,
        pretrained_path,
        new_concept_token,
        initializer_token,
        enable_edlora,  # true for ED-LoRA, false for LoRA
        finetune_cfg=None,
        noise_offset=None,
        attn_reg_weight=None,
        reg_full_identity=True,  # True for thanos, False for real person (don't need to encode clothes)
        use_mask_loss=True,
        enable_xformers=False,
        gradient_checkpoint=False
    ):
        """Load the pretrained components and prepare the trainable parts.

        Args:
            pretrained_path: Location passed to every ``from_pretrained`` call.
            new_concept_token: ``'+'``-separated concept names.
            initializer_token: ``'+'``-separated initializer tokens (one per
                concept) or ``None`` for random initialization.
            enable_edlora: If true each concept gets 16 embedding tokens,
                otherwise 1.
            finetune_cfg: Optional config consumed by ``set_finetune_cfg``.
            noise_offset: Optional scale of the per-channel noise offset.
            attn_reg_weight: Weight of the attention regularization; ``None``
                disables it.
            reg_full_identity: Selects the subject-token loss in
                ``cal_attn_reg``.
            use_mask_loss: Use ``masks`` (true) or ``img_masks`` (false) to
                weight the denoising loss.
            enable_xformers: Only checks that xformers is installed.
            gradient_checkpoint: Enable gradient checkpointing on the UNet.
        """
        super().__init__()

        # 1. Load the model.
        self.vae = AutoencoderKL.from_pretrained(pretrained_path, subfolder='vae')
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path, subfolder='tokenizer')
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_path, subfolder='text_encoder')
        self.unet = UNet2DConditionModel.from_pretrained(pretrained_path, subfolder='unet')

        if gradient_checkpoint:
            self.unet.enable_gradient_checkpointing()

        if enable_xformers:
            # NOTE(review): this only verifies availability; nothing in this
            # class switches the UNet to memory-efficient attention.
            assert is_xformers_available(), 'need to install xformer first'

        # 2. Define train scheduler
        self.scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder='scheduler')

        # 3. define training cfg
        self.enable_edlora = enable_edlora
        self.new_concept_cfg = self.init_new_concept(new_concept_token, initializer_token, enable_edlora=enable_edlora)

        self.attn_reg_weight = attn_reg_weight
        self.reg_full_identity = reg_full_identity
        if self.attn_reg_weight is not None:
            self.controller = AttentionStore(training=True)
            revise_edlora_unet_attention_controller_forward(self.unet, self.controller)  # support both lora and edlora forward
        else:
            revise_edlora_unet_attention_forward(self.unet)  # support both lora and edlora forward

        # Always defined so get_params_to_optimize() works without a finetune_cfg.
        self.params_to_optimize_iterator = []
        if finetune_cfg:
            self.set_finetune_cfg(finetune_cfg)

        self.noise_offset = noise_offset
        self.use_mask_loss = use_mask_loss

    def set_finetune_cfg(self, finetune_cfg):
        """Freeze all weights, then enable the trainable parts named in the config.

        Builds the optimizer parameter groups (text embedding, text-encoder
        LoRA, UNet LoRA) and stores them for ``get_params_to_optimize``. The
        passed config is not modified.
        """
        logger = get_logger('mixofshow', log_level='INFO')
        params_to_freeze = [self.vae.parameters(), self.text_encoder.parameters(), self.unet.parameters()]

        # step 1: close all parameters, required_grad to False
        for params in itertools.chain(*params_to_freeze):
            params.requires_grad = False

        # step 2: begin to add trainable parameters
        params_group_list = []

        # 1. text embedding
        if finetune_cfg['text_embedding']['enable_tuning']:
            text_embedding_cfg = finetune_cfg['text_embedding']

            params_list = []
            for params in self.text_encoder.get_input_embeddings().parameters():
                params.requires_grad = True
                params_list.append(params)

            params_group = {'params': params_list, 'lr': text_embedding_cfg['lr']}
            if 'weight_decay' in text_embedding_cfg:
                params_group.update({'weight_decay': text_embedding_cfg['weight_decay']})
            params_group_list.append(params_group)
            logger.info(f"optimizing embedding using lr: {text_embedding_cfg['lr']}")

        # 2. text encoder
        if finetune_cfg['text_encoder']['enable_tuning'] and finetune_cfg['text_encoder'].get('lora_cfg'):
            text_encoder_cfg = finetune_cfg['text_encoder']

            # Work on a copy: popping from the caller's config would make a
            # second call with the same config fail with KeyError.
            lora_cfg = dict(text_encoder_cfg['lora_cfg'])
            where = lora_cfg.pop('where')
            assert where in ['CLIPEncoderLayer', 'CLIPAttention']

            self.text_encoder_lora = nn.ModuleList()
            params_list = []

            for name, module in self.text_encoder.named_modules():
                if module.__class__.__name__ == where:
                    for child_name, child_module in module.named_modules():
                        if child_module.__class__.__name__ == 'Linear':
                            lora_module = LoRALinearLayer(name + '.' + child_name, child_module, **lora_cfg)
                            self.text_encoder_lora.append(lora_module)
                            params_list.extend(list(lora_module.parameters()))

            params_group_list.append({'params': params_list, 'lr': text_encoder_cfg['lr']})
            logger.info(f"optimizing text_encoder ({len(self.text_encoder_lora)} LoRAs), using lr: {text_encoder_cfg['lr']}")

        # 3. unet
        if finetune_cfg['unet']['enable_tuning'] and finetune_cfg['unet'].get('lora_cfg'):
            unet_cfg = finetune_cfg['unet']

            # Same copy-before-pop as for the text encoder.
            lora_cfg = dict(unet_cfg['lora_cfg'])
            where = lora_cfg.pop('where')
            assert where in ['Transformer2DModel', 'Attention']

            self.unet_lora = nn.ModuleList()
            params_list = []

            for name, module in self.unet.named_modules():
                if module.__class__.__name__ == where:
                    for child_name, child_module in module.named_modules():
                        # Linear layers and 1x1 convolutions both receive a LoRA layer.
                        if child_module.__class__.__name__ == 'Linear' or (child_module.__class__.__name__ == 'Conv2d' and child_module.kernel_size == (1, 1)):
                            lora_module = LoRALinearLayer(name + '.' + child_name, child_module, **lora_cfg)
                            self.unet_lora.append(lora_module)
                            params_list.extend(list(lora_module.parameters()))

            params_group_list.append({'params': params_list, 'lr': unet_cfg['lr']})
            logger.info(f"optimizing unet ({len(self.unet_lora)} LoRAs), using lr: {unet_cfg['lr']}")

        # 4. optimize params
        self.params_to_optimize_iterator = params_group_list

    def get_params_to_optimize(self):
        """Return the list of optimizer parameter groups built by ``set_finetune_cfg``."""
        return self.params_to_optimize_iterator

    def init_new_concept(self, new_concept_tokens, initializer_tokens, enable_edlora=True):
        """Add the new concept tokens to the tokenizer and initialize their embeddings.

        Args:
            new_concept_tokens: ``'+'``-separated concept names.
            initializer_tokens: ``'+'``-separated initializers, one per concept.
                ``'<rand-SIGMA>'`` draws a Gaussian vector scaled by SIGMA; any
                other value must encode to exactly one tokenizer id. ``None``
                means ``'<rand-0.017>'`` for every concept.
            enable_edlora: 16 tokens per concept if true, otherwise 1.

        Returns:
            Dict mapping each concept name to its ``concept_token_ids`` and
            ``concept_token_names``.

        Raises:
            ValueError: If a non-random initializer does not encode to a
                single accepted token.
        """
        logger = get_logger('mixofshow', log_level='INFO')
        new_concept_cfg = {}
        new_concept_tokens = new_concept_tokens.split('+')

        if initializer_tokens is None:
            initializer_tokens = ['<rand-0.017>'] * len(new_concept_tokens)
        else:
            initializer_tokens = initializer_tokens.split('+')
        assert len(new_concept_tokens) == len(initializer_tokens), 'concept token should match init token.'

        num_new_embedding = 16 if enable_edlora else 1
        for idx, (concept_name, init_token) in enumerate(zip(new_concept_tokens, initializer_tokens)):
            new_token_names = [f'<new{idx * num_new_embedding + layer_id}>' for layer_id in range(num_new_embedding)]

            num_added_tokens = self.tokenizer.add_tokens(new_token_names)
            assert num_added_tokens == len(new_token_names), 'some token is already in tokenizer'
            new_token_ids = [self.tokenizer.convert_tokens_to_ids(token_name) for token_name in new_token_names]

            # init embedding
            self.text_encoder.resize_token_embeddings(len(self.tokenizer))
            token_embeds = self.text_encoder.get_input_embeddings().weight.data

            if init_token.startswith('<rand'):
                sigma_val = float(re.findall(r'<rand-(.*)>', init_token)[0])
                # One random vector is drawn and shared by all tokens of this concept.
                init_feature = torch.randn_like(token_embeds[0]) * sigma_val
                logger.info(f'{concept_name} ({min(new_token_ids)}-{max(new_token_ids)}) is random initialized by: {init_token}')
            else:
                # Convert the initializer_token, placeholder_token to ids
                init_token_ids = self.tokenizer.encode(init_token, add_special_tokens=False)
                # The initializer must encode to exactly one id; an empty
                # encoding is rejected here instead of failing with IndexError.
                # NOTE(review): 40497 is rejected explicitly; presumably a
                # sentinel id of the tokenizer -- confirm its meaning.
                if len(init_token_ids) != 1 or init_token_ids[0] == 40497:
                    raise ValueError('The initializer token must be a single existing token.')
                init_feature = token_embeds[init_token_ids]
                logger.info(f'{concept_name} ({min(new_token_ids)}-{max(new_token_ids)}) is random initialized by existing token ({init_token}): {init_token_ids[0]}')

            for token_id in new_token_ids:
                token_embeds[token_id] = init_feature.clone()

            new_concept_cfg.update({
                concept_name: {
                    'concept_token_ids': new_token_ids,
                    'concept_token_names': new_token_names
                }
            })

        return new_concept_cfg

    def get_all_concept_token_ids(self):
        """Return the token ids of every registered concept as one flat list."""
        new_concept_token_ids = []
        for new_token_cfg in self.new_concept_cfg.values():
            new_concept_token_ids.extend(new_token_cfg['concept_token_ids'])
        return new_concept_token_ids

    def forward(self, images, prompts, masks, img_masks):
        """Compute the training loss for one batch.

        Args:
            images: Input passed to ``self.vae.encode``.
            prompts: Text prompts; expanded by ``bind_concept_prompt`` when
                ED-LoRA is enabled.
            masks: Loss weights used when ``use_mask_loss`` is true; also the
                target of the attention regularization.
            img_masks: Loss weights used when ``use_mask_loss`` is false.

        Returns:
            Scalar loss: masked MSE between the UNet prediction and its
            target, plus the attention regularization when enabled and
            not NaN.

        Raises:
            ValueError: If the scheduler's prediction type is neither
                ``'epsilon'`` nor ``'v_prediction'``.
        """
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * 0.18215  # fixed latent scaling constant

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        if self.noise_offset is not None:
            # One extra random offset per (sample, channel), broadcast over space.
            noise += self.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bsz, ), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        if self.enable_edlora:
            prompts = bind_concept_prompt(prompts, new_concept_cfg=self.new_concept_cfg)  # edlora

        # get text ids
        text_input_ids = self.tokenizer(
            prompts,
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors='pt').input_ids.to(latents.device)

        # Get the text embedding for conditioning
        encoder_hidden_states = self.text_encoder(text_input_ids)[0]
        if self.enable_edlora:
            # Regroup the expanded prompts back into one group per sample.
            encoder_hidden_states = rearrange(encoder_hidden_states, '(b n) m c -> b n m c', b=latents.shape[0])  # edlora

        # Predict the noise residual
        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Get the target for loss depending on the prediction type
        if self.scheduler.config.prediction_type == 'epsilon':
            target = noise
        elif self.scheduler.config.prediction_type == 'v_prediction':
            target = self.scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f'Unknown prediction type {self.scheduler.config.prediction_type}')

        loss_mask = masks if self.use_mask_loss else img_masks
        loss = F.mse_loss(model_pred.float(), target.float(), reduction='none')
        # Mask-weighted mean per sample, then mean over the batch.
        loss = ((loss * loss_mask).sum([1, 2, 3]) / loss_mask.sum([1, 2, 3])).mean()

        if self.attn_reg_weight is not None:
            attention_maps = self.controller.get_average_attention()
            attention_loss = self.cal_attn_reg(attention_maps, masks, text_input_ids)
            if not torch.isnan(attention_loss):  # full mask
                loss = loss + attention_loss
            self.controller.reset()

        return loss

    def cal_attn_reg(self, attention_maps, masks, text_input_ids):
        '''Compute the cross-attention regularization for the concept tokens.

        attention_maps: {down_cross:[], mid_cross:[], up_cross:[]}
        masks: torch.Size([1, 1, 64, 64])
        text_input_ids: torch.Size([16, 77])

        Returns the sum over the 64/32/16/8 resolutions of
        ``attn_reg_weight * (loss_subject + loss_adjective)``.
        '''
        # step 1: find token position
        batch_size = masks.shape[0]
        text_input_ids = rearrange(text_input_ids, '(b l) n -> b l n', b=batch_size)

        new_token_pos = []
        all_concept_token_ids = self.get_all_concept_token_ids()
        for text in text_input_ids:
            text = text[0]  # even multi-layer embedding, we extract the first one
            new_token_pos.append([idx for idx in range(len(text)) if text[idx] in all_concept_token_ids])

        # step2: aggregate attention maps with resolution and concat heads
        attention_groups = {'64': [], '32': [], '16': [], '8': []}
        for attention_list in attention_maps.values():
            for attn in attention_list:
                res = int(math.sqrt(attn.shape[1]))
                cross_map = attn.reshape(batch_size, -1, res, res, attn.shape[-1])
                attention_groups[str(res)].append(cross_map)

        for k, cross_map in attention_groups.items():
            cross_map = torch.cat(cross_map, dim=-4)  # concat heads
            cross_map = cross_map.sum(-4) / cross_map.shape[-4]  # e.g., 64 torch.Size([2, 64, 64, 77])
            # Keep only the attention channels at the concept-token positions.
            cross_map = torch.stack([batch_map[..., batch_pos] for batch_pos, batch_map in zip(new_token_pos, cross_map)])  # torch.Size([2, 64, 64, 2])
            attention_groups[k] = cross_map

        attn_reg_total = 0
        # step3: calculate loss for each resolution: <new1> <new2> -> <new1> is to penalize outside mask, <new2> to align with mask
        for k, cross_map in attention_groups.items():
            map_adjective, map_subject = cross_map[..., 0], cross_map[..., 1]

            map_subject = map_subject / map_subject.max()
            map_adjective = map_adjective / map_adjective.max()

            gt_mask = F.interpolate(masks, size=map_subject.shape[1:], mode='nearest').squeeze(1)

            if self.reg_full_identity:
                loss_subject = F.mse_loss(map_subject.float(), gt_mask.float(), reduction='mean')
            else:
                loss_subject = map_subject[gt_mask == 0].mean()

            loss_adjective = map_adjective[gt_mask == 0].mean()

            attn_reg_total += self.attn_reg_weight * (loss_subject + loss_adjective)
        return attn_reg_total

    def load_delta_state_dict(self, delta_state_dict):
        """Load concept embeddings and text-encoder/UNet weights from a delta checkpoint.

        Concepts registered in this model but absent from the checkpoint are
        skipped with a warning. For ``text_encoder`` / ``unet``, the entries
        are loaded into the LoRA modules when the key count equals twice the
        number of LoRA modules; otherwise keys matching regular parameter
        names are copied into the model directly.
        """
        # load embedding
        logger = get_logger('mixofshow', log_level='INFO')

        if 'new_concept_embedding' in delta_state_dict and len(delta_state_dict['new_concept_embedding']) != 0:
            saved_embeddings = delta_state_dict['new_concept_embedding']
            new_concept_tokens = list(saved_embeddings.keys())

            # check whether new concept is initialized
            token_embeds = self.text_encoder.get_input_embeddings().weight.data
            if set(new_concept_tokens) != set(self.new_concept_cfg.keys()):
                logger.warning('Your checkpoint have different concept with your model, loading existing concepts')

            for concept_name, concept_cfg in self.new_concept_cfg.items():
                if concept_name not in saved_embeddings:
                    # Honor the warning above: only concepts present in the
                    # checkpoint are loaded (previously this raised KeyError).
                    continue
                logger.info(f'load: concept_{concept_name}')
                token_embeds[concept_cfg['concept_token_ids']] = saved_embeddings[concept_name].to(
                    device=token_embeds.device, dtype=token_embeds.dtype)

        # load text_encoder
        if 'text_encoder' in delta_state_dict and len(delta_state_dict['text_encoder']) != 0:
            load_keys = delta_state_dict['text_encoder'].keys()
            if hasattr(self, 'text_encoder_lora') and len(load_keys) == 2 * len(self.text_encoder_lora):
                logger.info('loading LoRA for text encoder:')
                for lora_module in self.text_encoder_lora:
                    for name, param in lora_module.named_parameters():
                        logger.info(f'load: {lora_module.name}.{name}')
                        param.data.copy_(delta_state_dict['text_encoder'][f'{lora_module.name}.{name}'])
            else:
                for name, param in self.text_encoder.named_parameters():
                    if name in load_keys and 'token_embedding' not in name:
                        logger.info(f'load: {name}')
                        param.data.copy_(delta_state_dict['text_encoder'][f'{name}'])

        # load unet
        if 'unet' in delta_state_dict and len(delta_state_dict['unet']) != 0:
            load_keys = delta_state_dict['unet'].keys()
            if hasattr(self, 'unet_lora') and len(load_keys) == 2 * len(self.unet_lora):
                logger.info('loading LoRA for unet:')
                for lora_module in self.unet_lora:
                    for name, param in lora_module.named_parameters():
                        logger.info(f'load: {lora_module.name}.{name}')
                        param.data.copy_(delta_state_dict['unet'][f'{lora_module.name}.{name}'])
            else:
                for name, param in self.unet.named_parameters():
                    if name in load_keys:
                        logger.info(f'load: {name}')
                        param.data.copy_(delta_state_dict['unet'][f'{name}'])

    def delta_state_dict(self):
        """Collect the learned concept embeddings and LoRA weights.

        Returns:
            Dict with keys ``'new_concept_embedding'``, ``'text_encoder'`` and
            ``'unet'``. The LoRA sections stay empty when the corresponding
            LoRA modules were never created. All tensors are detached CPU
            copies.
        """
        delta_dict = {'new_concept_embedding': {}, 'text_encoder': {}, 'unet': {}}

        # save_embedding
        for concept_name, concept_cfg in self.new_concept_cfg.items():
            learned_embeds = self.text_encoder.get_input_embeddings().weight[concept_cfg['concept_token_ids']]
            delta_dict['new_concept_embedding'][concept_name] = learned_embeds.detach().cpu()

        # save text model / unet model; the LoRA lists exist only when
        # set_finetune_cfg enabled them, so fall back to an empty sequence.
        lora_sections = (
            ('text_encoder', getattr(self, 'text_encoder_lora', ())),
            ('unet', getattr(self, 'unet_lora', ())),
        )
        for section, lora_modules in lora_sections:
            for lora_module in lora_modules:
                for name, param in lora_module.named_parameters():
                    delta_dict[section][f'{lora_module.name}.{name}'] = param.detach().cpu().clone()

        return delta_dict
mixofshow/utils/__init__.py ADDED
File without changes
mixofshow/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes). View file
 
mixofshow/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (149 Bytes). View file
 
mixofshow/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (149 Bytes). View file
 
mixofshow/utils/__pycache__/convert_edlora_to_diffusers.cpython-38.pyc ADDED
Binary file (3.64 kB). View file
 
mixofshow/utils/__pycache__/convert_edlora_to_diffusers.cpython-39.pyc ADDED
Binary file (3.58 kB). View file
 
mixofshow/utils/__pycache__/ptp_util.cpython-38.pyc ADDED
Binary file (7.21 kB). View file
 
mixofshow/utils/__pycache__/ptp_util.cpython-39.pyc ADDED
Binary file (7.2 kB). View file
 
mixofshow/utils/__pycache__/registry.cpython-38.pyc ADDED
Binary file (2.49 kB). View file
 
mixofshow/utils/__pycache__/registry.cpython-39.pyc ADDED
Binary file (2.48 kB). View file
 
mixofshow/utils/__pycache__/util.cpython-310.pyc ADDED
Binary file (9.59 kB). View file
 
mixofshow/utils/__pycache__/util.cpython-38.pyc ADDED
Binary file (9.51 kB). View file
 
mixofshow/utils/__pycache__/util.cpython-39.pyc ADDED
Binary file (9.51 kB). View file
 
mixofshow/utils/arial.ttf ADDED
Binary file (367 kB). View file
 
mixofshow/utils/convert_edlora_to_diffusers.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+
4
def load_new_concept(pipe, new_concept_embedding, enable_edlora=True):
    """Register concept tokens on ``pipe`` and write their saved embeddings.

    Args:
        pipe: Object exposing ``tokenizer`` and ``text_encoder``.
        new_concept_embedding: Mapping of concept name to embedding tensor.
        enable_edlora: 16 tokens per concept if true, otherwise 1.

    Returns:
        ``(pipe, new_concept_cfg)`` where ``new_concept_cfg`` maps each
        concept name to its ``concept_token_ids`` / ``concept_token_names``.
    """
    tokens_per_concept = 16 if enable_edlora else 1
    new_concept_cfg = {}

    for idx, (concept_name, concept_embedding) in enumerate(new_concept_embedding.items()):
        first_index = idx * tokens_per_concept
        new_token_names = [f'<new{first_index + offset}>' for offset in range(tokens_per_concept)]

        num_added_tokens = pipe.tokenizer.add_tokens(new_token_names)
        assert num_added_tokens == len(new_token_names), 'some token is already in tokenizer'
        new_token_ids = [pipe.tokenizer.convert_tokens_to_ids(token_name) for token_name in new_token_names]

        # Grow the embedding table, then overwrite the freshly added rows.
        pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
        token_embeds = pipe.text_encoder.get_input_embeddings().weight.data
        embedding_copy = concept_embedding.clone()
        token_embeds[new_token_ids] = embedding_copy.to(token_embeds.device, dtype=token_embeds.dtype)
        print(f'load embedding: {concept_name}')

        new_concept_cfg[concept_name] = {
            'concept_token_ids': new_token_ids,
            'concept_token_names': new_token_names,
        }

    return pipe, new_concept_cfg
31
+
32
+
33
def merge_lora_into_weight(original_state_dict, lora_state_dict, model_type, alpha):
    """Return a copy of ``original_state_dict`` with LoRA deltas merged in.

    For every weight whose name maps to a ``lora_down`` / ``lora_up`` pair in
    ``lora_state_dict``, the merged weight is
    ``original + alpha * (lora_up @ lora_down)``. The input state dict is not
    modified.

    Args:
        original_state_dict: State dict of the pretrained model.
        lora_state_dict: Mapping holding ``*.lora_down.weight`` and
            ``*.lora_up.weight`` tensors.
        model_type: ``'unet'`` or ``'text_encoder'``; selects the layer-name
            patterns.
        alpha: Scale applied to the LoRA delta.

    Returns:
        New state dict with the merged weights.
    """
    assert model_type in ['unet', 'text_encoder']

    if model_type == 'text_encoder':
        layer_names = ('q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2')
    else:
        layer_names = ('to_q', 'to_k', 'to_v', 'to_out.0', 'ff.net.0.proj', 'ff.net.2', 'proj_out', 'proj_in')

    def get_lora_down_name(original_layer_name):
        # Uses its own argument; the previous unet branch read the enclosing
        # loop variable and only worked by accident.
        lora_down_name = original_layer_name
        for layer_name in layer_names:
            lora_down_name = lora_down_name.replace(f'{layer_name}.weight', f'{layer_name}.lora_down.weight')
        return lora_down_name

    new_state_dict = copy.deepcopy(original_state_dict)

    load_cnt = 0
    for k in list(new_state_dict.keys()):
        lora_down_name = get_lora_down_name(k)
        if lora_down_name == k:
            # Not a LoRA-capable layer name; never merge a key with itself.
            continue
        lora_up_name = lora_down_name.replace('lora_down', 'lora_up')

        if lora_up_name in lora_state_dict:
            load_cnt += 1
            original_params = new_state_dict[k]
            lora_down_params = lora_state_dict[lora_down_name].to(original_params.device)
            lora_up_params = lora_state_dict[lora_up_name].to(original_params.device)

            # Flatten to 2-D explicitly. A bare squeeze() also drops a rank or
            # channel dimension of size 1, turning the matrix product into a
            # dot product whose scalar is silently broadcast over the weight.
            up_matrix = lora_up_params.reshape(lora_up_params.shape[0], -1)
            down_matrix = lora_down_params.reshape(lora_down_params.shape[0], -1)
            lora_param = up_matrix @ down_matrix
            if len(original_params.shape) == 4:
                # 1x1 convolution weight: restore the two trailing kernel dims.
                lora_param = lora_param.unsqueeze(-1).unsqueeze(-1)

            new_state_dict[k] = original_params + alpha * lora_param

    print(f'load {load_cnt} LoRAs of {model_type}')
    return new_state_dict
77
+
78
+
79
def convert_edlora(pipe, state_dict, enable_edlora, alpha=0.6):
    """Load an ED-LoRA / LoRA checkpoint into a pipeline by merging weights.

    Args:
        pipe: Pipeline exposing ``tokenizer``, ``text_encoder`` and ``unet``.
        state_dict: Checkpoint dict; if it has a ``'params'`` entry, that
            entry is used as the actual state dict.
        enable_edlora: Forwarded to ``load_new_concept``.
        alpha: Scale of the merged LoRA deltas.

    Returns:
        ``(pipe, new_concept_cfg)``. ``new_concept_cfg`` is an empty dict when
        the checkpoint carries no concept embeddings (previously this case
        raised UnboundLocalError).
    """
    state_dict = state_dict['params'] if 'params' in state_dict.keys() else state_dict

    # step 1: load embedding
    new_concept_cfg = {}
    if 'new_concept_embedding' in state_dict and len(state_dict['new_concept_embedding']) != 0:
        pipe, new_concept_cfg = load_new_concept(pipe, state_dict['new_concept_embedding'], enable_edlora)

    # step 2: merge lora weight to unet (a missing section is treated as empty)
    unet_lora_state_dict = state_dict.get('unet', {})
    pretrained_unet_state_dict = pipe.unet.state_dict()
    updated_unet_state_dict = merge_lora_into_weight(pretrained_unet_state_dict, unet_lora_state_dict, model_type='unet', alpha=alpha)
    pipe.unet.load_state_dict(updated_unet_state_dict)

    # step 3: merge lora weight to text_encoder (a missing section is treated as empty)
    text_encoder_lora_state_dict = state_dict.get('text_encoder', {})
    pretrained_text_encoder_state_dict = pipe.text_encoder.state_dict()
    updated_text_encoder_state_dict = merge_lora_into_weight(pretrained_text_encoder_state_dict, text_encoder_lora_state_dict, model_type='text_encoder', alpha=alpha)
    pipe.text_encoder.load_state_dict(updated_text_encoder_state_dict)

    return pipe, new_concept_cfg
mixofshow/utils/ptp_util.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ from typing import List, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from IPython.display import display
8
+ from PIL import Image
9
+
10
+
11
class EmptyControl:
    """No-op controller: every hook hands its input back unchanged."""

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        # The attention value is passed through untouched.
        return attn

    def between_steps(self):
        # Nothing to do between steps.
        return None

    def step_callback(self, x_t):
        # The value is returned as received.
        return x_t
20
+
21
+
22
class AttentionControl(abc.ABC):
    """Abstract base for objects invoked once per attention layer.

    Each call optionally routes the attention value through ``forward`` and
    advances a layer counter; when the counter reaches
    ``num_att_layers + num_uncond_att_layers`` it wraps to zero, ``cur_step``
    is incremented and ``between_steps`` is invoked.
    """

    def __init__(self, low_resource, training):
        self.low_resource = low_resource
        self.training = training
        self.cur_step = 0
        self.cur_att_layer = 0
        # Placeholder value; NOTE(review): presumably overwritten by whoever
        # attaches the controller to a model -- not visible in this file.
        self.num_att_layers = -1

    @property
    def num_uncond_att_layers(self):
        """Number of leading calls that bypass ``forward`` (0 unless low_resource)."""
        if self.low_resource:
            return self.num_att_layers
        return 0

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return None

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if self.low_resource or self.training:
                # Whole input goes through forward.
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Only the second half along dim 0 is processed, written back in place.
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)

        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == layers_per_step:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        """Rewind the step and layer counters to zero."""
        self.cur_att_layer = 0
        self.cur_step = 0
65
+
66
+
67
class AttentionStore(AttentionControl):
    """Controller that records attention values and sums them across steps.

    ``step_store`` collects the values seen during the current step, keyed by
    ``'<place>_<cross|self>'``; ``attention_store`` holds the running
    element-wise sum over completed steps.
    """

    def __init__(self, low_resource=False, training=False):
        super().__init__(low_resource, training)
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with one empty list per storage key."""
        slots = ('down_cross', 'mid_cross', 'up_cross',
                 'down_self', 'mid_self', 'up_self')
        return {slot: [] for slot in slots}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        kind = 'cross' if is_cross else 'self'
        self.step_store[f'{place_in_unet}_{kind}'].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            # First completed step: adopt the step store as the accumulator.
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                summed = self.attention_store[key]
                latest = self.step_store[key]
                for idx in range(len(summed)):
                    summed[idx] = summed[idx] + latest[idx]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated values divided by ``cur_step``, per key."""
        averaged = {}
        for key in self.attention_store:
            averaged[key] = [total / self.cur_step for total in self.attention_store[key]]
        return averaged

    def reset(self):
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
109
+
110
+
111
def text_under_image(image: np.ndarray,
                     text: str,
                     text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of ``image`` with a white strip appended below it that holds ``text``.

    The strip is 20% of the image height; the text is centred horizontally
    using OpenCV's Hershey simplex font at scale 1, thickness 2.
    """
    height, width, channels = image.shape
    strip = int(height * .2)
    canvas = np.ones((height + strip, width, channels), dtype=np.uint8) * 255
    font_face = cv2.FONT_HERSHEY_SIMPLEX
    canvas[:height] = image
    text_w, text_h = cv2.getTextSize(text, font_face, 1, 2)[0]
    origin = ((width - text_w) // 2, height + strip - text_h // 2)
    cv2.putText(canvas, text, origin, font_face, 1, text_color, 2)
    return canvas
124
+
125
+
126
def view_images(images, num_rows=1, offset_ratio=0.02, notebook=True):
    """Tile images into a ``num_rows``-row grid on a white background.

    Args:
        images: A list of HxWxC arrays, a 4-D array (first axis indexes
            images), or a single image array.
        num_rows (int): Number of grid rows.
        offset_ratio (float): Gap between tiles, as a fraction of tile height.
        notebook (bool): When exactly ``True`` the grid is shown with
            ``display`` and nothing is returned; otherwise the PIL image is
            returned.

    Returns:
        The composed PIL image when ``notebook`` is not ``True``, else ``None``.
    """
    if isinstance(images, list):
        num_images = len(images)
    elif images.ndim == 4:
        num_images = images.shape[0]
    else:
        images = [images]
        num_images = 1

    # Pad up to the next multiple of num_rows. The previous
    # `num_images % num_rows` under-padded (e.g. 7 images / 3 rows), so the
    # integer column count silently dropped trailing images.
    num_empty = (-num_images) % num_rows

    blank = np.ones(images[0].shape, dtype=np.uint8) * 255
    tiles = [image.astype(np.uint8) for image in images] + [blank] * num_empty
    num_cols = len(tiles) // num_rows

    h, w, _ = tiles[0].shape
    offset = int(h * offset_ratio)
    canvas = np.ones(
        (h * num_rows + offset * (num_rows - 1),
         w * num_cols + offset * (num_cols - 1), 3),
        dtype=np.uint8) * 255
    for row in range(num_rows):
        top = row * (h + offset)
        for col in range(num_cols):
            left = col * (w + offset)
            canvas[top:top + h, left:left + w] = tiles[row * num_cols + col]

    pil_img = Image.fromarray(canvas)
    if notebook is True:
        display(pil_img)
    else:
        return pil_img
157
+
158
+
159
def aggregate_attention(attention_store: AttentionStore, res: int,
                        from_where: List[str], prompts: List[str],
                        is_cross: bool, select: int):
    """Average the stored attention maps of one resolution for one prompt.

    Maps are taken from ``attention_store.get_average_attention()`` under the
    keys ``'<location>_<cross|self>'`` for each location in ``from_where``.
    Only items whose second dimension equals ``res ** 2`` are used; each is
    reshaped to ``(len(prompts), -1, res, res, last_dim)`` and indexed with
    ``select``. The selected maps are concatenated on dim 0, averaged over
    that dim and returned on the CPU.
    """
    averaged = attention_store.get_average_attention()
    kind = 'cross' if is_cross else 'self'
    pixel_count = res**2
    selected = [
        item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
        for location in from_where
        for item in averaged[f'{location}_{kind}']
        if item.shape[1] == pixel_count
    ]
    stacked = torch.cat(selected, dim=0)
    mean_map = stacked.sum(0) / stacked.shape[0]
    return mean_map.cpu()
174
+
175
+
176
def show_cross_attention(attention_store: AttentionStore,
                         res: int,
                         from_where: List[str],
                         prompts: List[str],
                         tokenizer,
                         select: int = 0,
                         notebook=True):
    """Visualise the aggregated cross-attention map of every token in a prompt.

    For each token id of ``prompts[select]`` the matching channel of the
    aggregated map is scaled so its maximum is 255, expanded to three
    channels, resized to 256x256 and captioned with the decoded token. The
    panels are handed to ``view_images``; its result is returned only when
    ``notebook`` is not ``True``.
    """
    token_ids = tokenizer.encode(prompts[select])
    to_text = tokenizer.decode
    maps = aggregate_attention(attention_store, res, from_where, prompts, True, select)

    panels = []
    for idx, token_id in enumerate(token_ids):
        heat = maps[:, :, idx]
        heat = 255 * heat / heat.max()
        heat = heat.unsqueeze(-1).expand(*heat.shape, 3)
        pixels = heat.numpy().astype(np.uint8)
        pixels = np.array(Image.fromarray(pixels).resize((256, 256)))
        panels.append(text_under_image(pixels, to_text(int(token_id))))

    stacked = np.stack(panels, axis=0)
    if notebook is True:
        view_images(stacked)
        return None
    return view_images(stacked, notebook=False)
mixofshow/utils/registry.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2
+
3
+
4
class Registry():
    """
    A name -> object table, so that third-party users can plug in their own
    modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """
    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._obj_map = {}
        self._name = name

    def _do_register(self, name, obj):
        # Each name may be registered only once.
        already_known = name in self._obj_map
        assert not already_known, (
            f"An object named '{name}' was already registered "
            f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the name `obj.__name__`.
        Works both as a plain call and as a decorator factory; see the class
        docstring for usage. The plain-call form returns None.
        """
        if obj is not None:
            # used as a function call
            self._do_register(obj.__name__, obj)
            return None

        # used as a decorator
        def deco(func_or_class):
            self._do_register(func_or_class.__name__, func_or_class)
            return func_or_class

        return deco

    def get(self, name):
        """Return the object registered as `name`; raise KeyError if the lookup yields None."""
        found = self._obj_map.get(name)
        if found is not None:
            return found
        raise KeyError(
            f"No object named '{name}' found in '{self._name}' registry!")

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        # Iterates over (name, object) pairs.
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


TRANSFORM_REGISTRY = Registry('transform')
mixofshow/utils/util.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import os
4
+ import os.path
5
+ import os.path as osp
6
+ import time
7
+ from collections import OrderedDict
8
+
9
+ import PIL
10
+ import torch
11
+ from accelerate.logging import get_logger
12
+ from accelerate.state import PartialState
13
+ from PIL import Image, ImageDraw, ImageFont
14
+ from torchvision.transforms.transforms import ToTensor
15
+ from torchvision.utils import make_grid
16
+
17
# Comma-separated list of undesired image traits.
# NOTE(review): presumably supplied as the negative prompt at sampling time;
# its consumers are outside this file -- confirm against callers.
NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
18
+
19
+
20
+ # ----------- file/logger util ----------
21
def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
23
+
24
+
25
def mkdir_and_rename(path):
    """Create ``path`` as a fresh directory, archiving any existing one.

    If ``path`` already exists it is first renamed to
    ``<path>_archived_<timestamp>``; then ``path`` is (re)created.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        new_name = '_archived_'.join((path, get_time_str()))
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)
36
+
37
+
38
def make_exp_dirs(opt):
    """Create the directories listed under ``opt['path']``.

    The root entry (``experiments_root`` when ``opt['is_train']`` is truthy,
    ``results_root`` otherwise) goes through ``mkdir_and_rename``. Every
    other entry is created with ``os.makedirs`` unless its key contains one
    of the markers below, in which case it is skipped.
    """
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))

    # Keys containing any of these substrings are not created as directories.
    skip_markers = ('strict_load', 'pretrain_network', 'resume', 'param_key', 'lora_path')
    for key, path in remaining.items():
        if any(marker in key for marker in skip_markers):
            continue
        os.makedirs(path, exist_ok=True)
51
+
52
+
53
def copy_opt_file(opt_file, experiments_root):
    """Copy ``opt_file`` into ``experiments_root`` and prepend a provenance header.

    The header is a comment block holding the current time and the command
    line (``sys.argv`` joined by spaces), followed by a blank line.
    """
    import sys
    import time
    from shutil import copyfile

    command = ' '.join(sys.argv)
    target = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, target)

    with open(target, 'r+') as handle:
        body = handle.read()
        header = f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {command}\n\n'
        # Rewind and overwrite; the new content is never shorter than the old.
        handle.seek(0)
        handle.write(header + body)
68
+
69
+
70
def set_path_logger(accelerator, root_path, config_path, opt, is_train=True):
    """Fill in the output paths of ``opt``, create them and start logging.

    Paths live under ``<root_path>/experiments/<name>`` for training and
    ``<root_path>/results/<name>`` otherwise. Directory creation happens on
    the main process only; all processes then synchronise, copy the config
    file into the root directory and configure logging to a timestamped file.
    """
    opt['is_train'] = is_train
    paths = opt['path']

    if is_train:
        root_dir = osp.join(root_path, 'experiments', opt['name'])
        paths['experiments_root'] = root_dir
        paths['models'] = osp.join(root_dir, 'models')
        paths['log'] = root_dir
        paths['visualization'] = osp.join(root_dir, 'visualization')
    else:
        root_dir = osp.join(root_path, 'results', opt['name'])
        paths['results_root'] = root_dir
        paths['log'] = root_dir
        paths['visualization'] = osp.join(root_dir, 'visualization')

    # Handle the output folder creation
    if accelerator.is_main_process:
        make_exp_dirs(opt)

    accelerator.wait_for_everyone()

    root_key, phase = ('experiments_root', 'train') if is_train else ('results_root', 'test')
    copy_opt_file(config_path, opt['path'][root_key])
    log_file = osp.join(opt['path']['log'],
                        f"{phase}_{opt['name']}_{get_time_str()}.log")
    set_logger(log_file)
102
+
103
+
104
def set_logger(log_file=None):
    """Configure the root logger with a stream handler and an optional file handler.

    Args:
        log_file (str | None): Path of the log file, opened in write mode.
            When ``None`` only the stream handler is installed; previously
            the default value crashed inside ``logging.FileHandler``.
    """
    # Make one log on every process with the configuration for debugging.
    format_str = '%(asctime)s %(levelname)s: %(message)s'
    log_level = logging.INFO
    handlers = []

    if log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        handlers.append(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    handlers.append(stream_handler)

    # basicConfig is a no-op when the root logger already has handlers.
    logging.basicConfig(handlers=handlers, level=log_level)
120
+
121
+
122
def dict2str(opt, indent_level=1):
    """Render a (possibly nested) option dict as an indented string.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level; each level is two spaces. Default: 1.

    Return:
        (str): Text starting with a newline, one ``key: value`` line per
            entry; nested dicts are wrapped as ``key:[ ... ]``.
    """
    pad = ' ' * (indent_level * 2)
    pieces = ['\n']
    for key, value in opt.items():
        if isinstance(value, dict):
            pieces.append(pad + key + ':[' + dict2str(value, indent_level + 1) + pad + ']\n')
        else:
            pieces.append(pad + key + ': ' + str(value) + '\n')
    return ''.join(pieces)
141
+
142
+
143
class MessageLogger():
    """Formats and emits one progress line per call.

    Args:
        opt (dict): Config. The following keys are read:
            name (str): Exp name.
            logger (dict): Its 'print_freq' entry is stored as ``interval``.
            train (dict): Its 'total_iter' entry is the total iteration count.
        start_iter (int): Start iter. Default: 1.
    """
    def __init__(self, opt, start_iter=1):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.start_time = time.time()
        self.logger = get_logger('mixofshow', log_level='INFO')

    def reset_start_time(self):
        """Restart the clock used for the ETA estimate."""
        self.start_time = time.time()

    def __call__(self, log_vars):
        """Format logging message.

        Args:
            log_vars (dict): 'iter' (int, current iter) and 'lrs' (list of
                learning rates) are popped from it; every remaining entry
                is printed as ``key: value`` in scientific notation.
        """
        # iteration and learning rates
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        parts = [f'[{self.exp_name[:5]}..][Iter:{current_iter:8,d}, lr:(']
        parts.extend(f'{rate:.3e},' for rate in lrs)
        parts.append(')] ')

        # estimated remaining time from the average time per iteration so far
        elapsed = time.time() - self.start_time
        sec_per_iter = elapsed / (current_iter - self.start_iter + 1)
        remaining_sec = sec_per_iter * (self.max_iters - current_iter - 1)
        eta_str = str(datetime.timedelta(seconds=int(remaining_sec)))
        parts.append(f'[eta: {eta_str}] ')

        # other items, especially losses
        parts.extend(f'{key}: {value:.4e} ' for key, value in log_vars.items())

        self.logger.info(''.join(parts))
201
+
202
+
203
def reduce_loss_dict(accelerator, loss_dict):
    """Reduce a loss dict across processes and convert it to plain floats.

    The loss values are stacked, passed through ``accelerator.reduce`` and
    divided by the number of processes; each entry is then collapsed with
    ``mean().item()``.

    Args:
        loss_dict (OrderedDict): Loss dict.

    Returns:
        OrderedDict: Same keys, scalar Python values.
    """
    with torch.no_grad():
        names = list(loss_dict.keys())
        stacked = torch.stack(list(loss_dict.values()), 0)
        stacked = accelerator.reduce(stacked)

        num_processes = PartialState().num_processes
        stacked /= num_processes

        reduced = dict(zip(names, stacked))

        log_dict = OrderedDict()
        for name, value in reduced.items():
            log_dict[name] = value.mean().item()

        return log_dict
230
+
231
+
232
def pil_imwrite(img, file_path, auto_mkdir=True):
    """Save a PIL image to ``file_path``.

    Args:
        img (PIL.Image.Image): Image to be written; any other type fails
            the assertion.
        file_path (str): Image file path.
        auto_mkdir (bool): Create the parent folder of `file_path` first
            when it does not exist.
    """
    is_pil_image = isinstance(img, PIL.Image.Image)
    assert is_pil_image, 'model should return a list of PIL images'
    if auto_mkdir:
        parent = os.path.dirname(file_path)
        os.makedirs(os.path.abspath(parent), exist_ok=True)
    img.save(file_path)
249
+
250
+
251
def draw_prompt(text, height, width, font_size=45):
    """Render ``text`` in black on a white ``width`` x ``height`` RGB image.

    The text is broken into lines of a fixed character count, found by
    growing a prefix of ``text`` until its measured width no longer fits
    between 10% side margins. The font is ``arial.ttf`` located next to this
    module.
    """
    canvas = Image.new('RGB', (width, height), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)
    font_path = osp.join(osp.dirname(osp.abspath(__file__)), 'arial.ttf')
    font = ImageFont.truetype(font_path, font_size)

    # Characters per line: extend while the prefix still fits the margins.
    chars_per_line = 0
    while (font.font.getsize(text[:chars_per_line])[0][0] + 0.1 * width < width - 0.1 * width
           and chars_per_line < len(text)):
        chars_per_line += 1

    wrapped = []
    for position, char in enumerate(text):
        if position % chars_per_line == 0:
            wrapped.append('\n')
            if char == ' ':
                char = ''  # drop a space that would start the new line
        wrapped.append(char)

    pen.text([int(0.1 * width), int(0.3 * height)],
             ''.join(wrapped),
             font=font,
             fill='black')
    return canvas
277
+
278
+
279
def compose_visualize(dir_path):
    """Compose every image in ``dir_path`` into one grid saved beside the folder.

    File stems must have the form ``prompt---sample_args---index---suffix``.
    The first time a prompt is seen, a rendered text tile for it is inserted
    ahead of its images. All files must share one ``sample_args`` and one
    ``suffix``; the grid is written to the parent of ``dir_path`` as
    ``<sample_args>---<suffix>.jpg``.
    """
    to_tensor = ToTensor()
    tiles = []
    seen_prompts, seen_args, seen_suffixes = set(), set(), set()

    for filename in sorted(os.listdir(dir_path)):
        stem = osp.splitext(osp.basename(filename))[0]
        prompt, sample_args, index, suffix = stem.split('---')

        tensor = to_tensor(Image.open(osp.join(dir_path, filename)))
        height, width = tensor.shape[1:]

        if prompt not in seen_prompts:
            # Caption tile, sized like the image it introduces.
            caption = draw_prompt(prompt, height=height, width=width, font_size=45)
            tiles.append(to_tensor(caption))
        seen_prompts.add(prompt)
        seen_args.add(sample_args)
        seen_suffixes.add(suffix)

        tiles.append(tensor)

    assert len(seen_args) == 1, 'compose dir should contain images form same sample args.'
    assert len(seen_suffixes) == 1, 'compose dir should contain images form same suffix.'

    grid = make_grid(tiles, nrow=len(tiles) // len(seen_prompts))
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    scaled = grid.mul(255).add_(0.5).clamp_(0, 255)
    ndarr = scaled.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    composed = Image.fromarray(ndarr)
    save_name = f"{seen_args.pop()}---{seen_suffixes.pop()}.jpg"
    composed.save(osp.join(osp.dirname(dir_path), save_name))
orthogonal_mats/1280.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cacbfe6dda3140404a86019e487349f7b693667cd7efe5848fe9fc04b1a3618
3
+ size 13107328
orthogonal_mats/320.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99b8a3b3cf101ac83f43eda74392447207bdc4745e8e77626219d994ea8f2ae9
3
+ size 819328
orthogonal_mats/640.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c3c38d378cf4c22f7ab4cd3fb734c846afd4500ca58f253cb63c44304c360aa
3
+ size 3276928
orthogonal_mats/768.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64bd95b552140a47665e46cb5736cd8562e40b5f27f0f262a4b7563d13061daf
3
+ size 4718720