yeq6x committed
Commit a5e97ce · 1 Parent(s): 738b300
Files changed (1)
  1. app.py +144 -16
app.py CHANGED
Old version (deletions marked with -):
@@ -9,6 +9,10 @@ import cv2
 import gradio as gr
 from torchvision import transforms
 from controlnet_aux import OpenposeDetector
 
 ratios_map = {
     0.5:{"width":704,"height":1408},
@@ -85,45 +89,170 @@ def resize_image_old(image):
 
 
 @spaces.GPU
-def generate_(prompt, negative_prompt, pose_image, input_image, num_steps, controlnet_conditioning_scale, seed):
-    generator = torch.Generator("cuda").manual_seed(seed)
     images = pipe(
-        prompt, negative_prompt=negative_prompt, image=pose_image, num_inference_steps=num_steps, controlnet_conditioning_scale=float(controlnet_conditioning_scale),
         generator=generator, height=input_image.size[1], width=input_image.size[0],
     ).images
     return images
 
 @spaces.GPU
-def process(input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed):
 
     # resize input_image to 1024x1024
     input_image = resize_image(input_image)
 
     pose_image = openpose(input_image, include_body=True, include_hand=True, include_face=True)
 
-    images = generate_(prompt, negative_prompt, pose_image, input_image, num_steps, controlnet_conditioning_scale, seed)
 
     return [pose_image,images[0]]
 
 block = gr.Blocks().queue()
 
 with block:
     gr.Markdown("## BRIA 2.3 ControlNet Pose")
-    gr.HTML('''
-    <p style="margin-bottom: 10px; font-size: 94%">
-    This is a demo for ControlNet Pose that using
-    <a href="https://huggingface.co/briaai/BRIA-2.3" target="_blank">BRIA 2.3 text-to-image model</a> as backbone.
-    Trained on licensed data, BRIA 2.3 provide full legal liability coverage for copyright and privacy infringement.
-    </p>
-    ''')
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
             prompt = gr.Textbox(label="Prompt")
             negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
-            num_steps = gr.Slider(label="Number of steps", minimum=25, maximum=100, value=50, step=1)
             controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
-            seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True,)
             run_button = gr.Button(value="Run")
 
         with gr.Column():
@@ -131,8 +260,7 @@ with block:
             pose_image_output = gr.Image(label="Pose Image", type="pil", interactive=False)
             generated_image_output = gr.Image(label="Generated Image", type="pil", interactive=False)
 
-    ips = [input_image, prompt, negative_prompt, num_steps, controlnet_conditioning_scale, seed]
-    run_button.click(fn=process, inputs=ips, outputs=[pose_image_output, generated_image_output])
 
 
 block.launch(debug = True)
 
app.py after the commit (additions marked with +, unchanged stretches elided with ...):
 import gradio as gr
 from torchvision import transforms
 from controlnet_aux import OpenposeDetector
+import random
+import open3d as o3d
+from collections import Counter
+import trimesh
 
 ratios_map = {
     0.5:{"width":704,"height":1408},
...
 
 
 @spaces.GPU
+def generate_(prompt, negative_prompt, pose_image, input_image, controlnet_conditioning_scale):
+    generator = torch.Generator()
+    generator.manual_seed(random.randint(0, 2147483647))
     images = pipe(
+        prompt, negative_prompt=negative_prompt, image=pose_image, num_inference_steps=20, controlnet_conditioning_scale=float(controlnet_conditioning_scale),
         generator=generator, height=input_image.size[1], width=input_image.size[0],
     ).images
     return images
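The new `generate_` draws a fresh random seed on every call, so results can no longer be reproduced from the UI. A minimal sketch of one way to keep reproducibility under this design; the seed logging is an assumption, not part of the commit:

```python
import random

import torch

# Draw a random seed the way generate_ does, but record it so a good
# result can be regenerated later (hypothetical; app.py does not log it).
seed = random.randint(0, 2147483647)
print(f"seed: {seed}")
generator = torch.Generator().manual_seed(seed)
```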
 
 @spaces.GPU
+def process(input_image, prompt, negative_prompt, controlnet_conditioning_scale):
 
     # resize input_image to 1024x1024
     input_image = resize_image(input_image)
 
     pose_image = openpose(input_image, include_body=True, include_hand=True, include_face=True)
 
+    images = generate_(prompt, negative_prompt, pose_image, input_image, controlnet_conditioning_scale)
 
     return [pose_image,images[0]]
+
+@spaces.GPU
+def predict_image(cond_image, prompt, negative_prompt, controlnet_conditioning_scale):
+    print("predict position map")
+    global pipe
+    generator = torch.Generator()
+    generator.manual_seed(random.randint(0, 2147483647))
+    image = pipe(
+        prompt,
+        negative_prompt=negative_prompt,
+        image=cond_image,
+        width=1024,
+        height=1024,
+        guidance_scale=8,
+        num_inference_steps=20,
+        generator=generator,
+        guess_mode=True,
+        controlnet_conditioning_scale=controlnet_conditioning_scale
+    ).images[0]
+
+    return image
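A usage sketch for `predict_image`; the file name and prompts are hypothetical, and it assumes the global `pipe` is the ControlNet pipeline loaded earlier in app.py:

```python
from PIL import Image

# Hypothetical conditioning image (e.g., a rendered pose condition).
cond = Image.open("condition.png").convert("RGB").resize((1024, 1024))

position_map = predict_image(
    cond,
    prompt="a position map of a standing person",  # hypothetical prompt
    negative_prompt="blurry, low quality",         # hypothetical prompt
    controlnet_conditioning_scale=1.0,
)
```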
+
+def convert_pil_to_opencv(pil_image):
+    return np.array(pil_image)
+
+def inv_func(y,
+             c=-712.380100,
+             a=137.375240,
+             b=192.435866):
+    return (np.exp((y - c) / a) - np.exp(-c / a)) / 964.8468371292845
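`inv_func` reads as the closed-form inverse of a logarithmic encoding; note that its `b` parameter is accepted but never used. A sketch of the forward map it appears to invert, derived by solving the formula for `y` (the name `fwd_func` is mine, not in the commit):

```python
import numpy as np

C, A = -712.380100, 137.375240
S = 964.8468371292845

def fwd_func(x):
    # Solving inv_func for y gives: y = c + a * ln(S*x + exp(-c/a))
    return C + A * np.log(S * x + np.exp(-C / A))

def inv_func(y):  # as in app.py, without the unused b parameter
    return (np.exp((y - C) / A) - np.exp(-C / A)) / S

x = np.linspace(0.0, 1.0, 5)
assert np.allclose(inv_func(fwd_func(x)), x)  # round-trips up to float error
```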
+
+def create_point_cloud(img1, img2):
+    if img1.shape != img2.shape:
+        raise ValueError("Both images must have the same dimensions.")
+
+    h, w, _ = img1.shape
+    points = []
+    colors = []
+    for y in range(h):
+        for x in range(w):
+            # Read the RGB at pixel (x, y) and use it as XYZ
+            r, g, b = img1[y, x]
+            r = inv_func(r) * 0.9
+            g = inv_func(g) / 1.7 * 0.6
+            b = inv_func(b)
+            r *= 150
+            g *= 150
+            b *= 150
+            points.append([g, b, r])  # X, Y, Z
+            # Take the color of image 2 at the corresponding pixel
+            colors.append(img2[y, x] / 255.0)  # scale colors to the 0-1 range
+
+    return np.array(points), np.array(colors)
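The per-pixel Python loop in `create_point_cloud` runs over a million iterations on a 1024x1024 image. A sketch of an equivalent vectorized version, assuming `inv_func` from above and `uint8` RGB arrays:

```python
import numpy as np

def create_point_cloud_fast(img1, img2):
    # Same mapping as create_point_cloud, but in bulk NumPy operations.
    if img1.shape != img2.shape:
        raise ValueError("Both images must have the same dimensions.")
    rgb = img1.reshape(-1, 3).astype(np.float64)
    r = inv_func(rgb[:, 0]) * 0.9 * 150
    g = inv_func(rgb[:, 1]) / 1.7 * 0.6 * 150
    b = inv_func(rgb[:, 2]) * 150
    points = np.stack([g, b, r], axis=1)  # keeps the (g, b, r) axis order
    colors = img2.reshape(-1, 3) / 255.0
    return points, colors
```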
165
+
166
+ def point_cloud_to_glb(points, colors):
167
+ # Open3Dでポイントクラウドを作成
168
+ pc = o3d.geometry.PointCloud()
169
+ pc.points = o3d.utility.Vector3dVector(points)
170
+ pc.colors = o3d.utility.Vector3dVector(colors)
171
+
172
+ # 一時的にPLY形式で保存
173
+ temp_ply_file = "temp_output.ply"
174
+ o3d.io.write_point_cloud(temp_ply_file, pc)
175
+
176
+ # PLYをGLBに変換
177
+ mesh = trimesh.load(temp_ply_file)
178
+ glb_file = "output.glb"
179
+ mesh.export(glb_file)
180
+
181
+ return glb_file
182
+
183
+ def visualize_3d(image1, image2):
184
+ print("Processing...")
185
+ # PIL画像をOpenCV形式に変換
186
+ img1 = convert_pil_to_opencv(image1)
187
+ img2 = convert_pil_to_opencv(image2)
188
+
189
+ # ポイントクラウド生成
190
+ points, colors = create_point_cloud(img1, img2)
191
+
192
+ # GLB形式に変換
193
+ glb_file = point_cloud_to_glb(points, colors)
194
+
195
+ return glb_file
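A usage sketch for the point-cloud path; the input file names are hypothetical, and both images must have identical dimensions (enforced by `create_point_cloud`):

```python
from PIL import Image

position_map = Image.open("position_map.png").convert("RGB")  # hypothetical file
texture = Image.open("texture.png").convert("RGB")            # hypothetical file

glb_path = visualize_3d(position_map, texture)  # writes and returns "output.glb"
```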
196
+
197
+ def scale_image(original_image):
198
+ aspect_ratio = original_image.width / original_image.height
199
+
200
+ if original_image.width > original_image.height:
201
+ new_width = 1024
202
+ new_height = round(new_width / aspect_ratio)
203
+ else:
204
+ new_height = 1024
205
+ new_width = round(new_height * aspect_ratio)
206
+
207
+ resized_original = original_image.resize((new_width, new_height), Image.LANCZOS)
208
+
209
+ return resized_original
210
+
211
+ def get_edge_mode_color(img, edge_width=10):
212
+ # 外周の10ピクセル領域を取得
213
+ left = img.crop((0, 0, edge_width, img.height)) # 左端
214
+ right = img.crop((img.width - edge_width, 0, img.width, img.height)) # 右端
215
+ top = img.crop((0, 0, img.width, edge_width)) # 上端
216
+ bottom = img.crop((0, img.height - edge_width, img.width, img.height)) # 下端
217
+
218
+ # 各領域のピクセルデータを取得して結合
219
+ colors = list(left.getdata()) + list(right.getdata()) + list(top.getdata()) + list(bottom.getdata())
220
+
221
+ # 最頻値(mode)を計算
222
+ mode_color = Counter(colors).most_common(1)[0][0] # 最も頻繁に出現する色を取得
223
+
224
+ return mode_color
225
+
226
+ def paste_image(resized_img):
227
+ # 外周10pxの最頻値を背景色に設定
228
+ mode_color = get_edge_mode_color(resized_img, edge_width=10)
229
+ mode_background = Image.new("RGBA", (1024, 1024), mode_color)
230
+ mode_background = mode_background.convert('RGB')
231
+
232
+ x = (1024 - resized_img.width) // 2
233
+ y = (1024 - resized_img.height) // 2
234
+ mode_background.paste(resized_img, (x, y))
235
+
236
+ return mode_background
237
+
238
+ def outpaint_image(image):
239
+ if type(image) == type(None):
240
+ return None
241
+ resized_img = scale_image(image)
242
+ image = paste_image(resized_img)
243
+
244
+ return image
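A usage sketch for the padding helpers; `portrait.jpg` is a hypothetical input of arbitrary aspect ratio:

```python
from PIL import Image

img = Image.open("portrait.jpg").convert("RGB")  # hypothetical file

# Scales the long side to 1024, then centers the result on a 1024x1024
# canvas filled with the most common border color.
padded = outpaint_image(img)
padded.save("padded.png")
```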
 
 block = gr.Blocks().queue()
 
 with block:
     gr.Markdown("## BRIA 2.3 ControlNet Pose")
     with gr.Row():
         with gr.Column():
             input_image = gr.Image(sources=None, type="pil") # None for upload, ctrl+v and webcam
             prompt = gr.Textbox(label="Prompt")
             negative_prompt = gr.Textbox(label="Negative prompt", value="Logo,Watermark,Text,Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra limbs,Gross proportions,Missing arms,Mutated hands,Long neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad anatomy,Cloned face,Malformed limbs,Missing legs,Too many fingers")
             controlnet_conditioning_scale = gr.Slider(label="ControlNet conditioning scale", minimum=0.1, maximum=2.0, value=1.0, step=0.05)
             run_button = gr.Button(value="Run")
 
         with gr.Column():
...
             pose_image_output = gr.Image(label="Pose Image", type="pil", interactive=False)
             generated_image_output = gr.Image(label="Generated Image", type="pil", interactive=False)
 
+    run_button.click(fn=process, inputs=[input_image, prompt, negative_prompt, controlnet_conditioning_scale], outputs=[pose_image_output, generated_image_output])
 
 
 block.launch(debug = True)