macguyver committed
Commit 19ba71c · 1 Parent(s): 89c3093

runpod-handler

Dockerfile ADDED
@@ -0,0 +1,31 @@
+ # Use the specified PyTorch image with CUDA 12.1 and cuDNN 9
+ FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-runtime
+
+ # Install dependencies for Miniconda (git is also needed for the clone step below)
+ RUN apt-get update && apt-get install -y \
+     wget git \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install Miniconda
+ RUN mkdir -p /opt/miniconda3 && \
+     wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /opt/miniconda3/miniconda.sh && \
+     bash /opt/miniconda3/miniconda.sh -b -u -p /opt/miniconda3 && \
+     rm /opt/miniconda3/miniconda.sh
+
+ # Set environment variables for Conda
+ ENV PATH /opt/miniconda3/bin:$PATH
+ ENV CONDA_AUTO_UPDATE_CONDA=false
+
+ WORKDIR /opt
+
+ RUN git clone https://github.com/ACE-innovate/wefa-seg-serverless
+
+ # Copy the environment.yaml file and create the Conda environment
+ COPY ./anydoor/environment.yaml /tmp/environment.yaml
+ RUN conda env create -f /tmp/environment.yaml
+
+ # Set up the shell to use the Conda environment by default
+ SHELL ["conda", "run", "-n", "anydoor", "/bin/bash", "-c"]
+
+ # Default command
+ CMD ["/bin/bash"]
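For a local smoke test of this image before wiring it up to RunPod, a minimal sketch using the Docker SDK for Python; the docker package, the anydoor-runpod:dev tag, and the GPU passthrough settings are illustrative assumptions, not part of this commit:

import docker  # Docker SDK for Python (pip install docker); assumed available on the build host

client = docker.from_env()

# Build the image from the repository root; the tag name is just an example
image, build_logs = client.images.build(path=".", dockerfile="Dockerfile", tag="anydoor-runpod:dev")

# Run a throwaway container with GPU access and list the Conda environments
output = client.containers.run(
    "anydoor-runpod:dev",
    "conda env list",
    device_requests=[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])],
    remove=True,
)
print(output.decode())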
anydoor/run_inference.py CHANGED
@@ -218,7 +218,7 @@ def inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_sc


  if __name__ == '__main__':
-     '''
+     # '''
      # ==== Example for inferring a single image ===
      reference_image_path = './examples/TestDreamBooth/FG/01.png'
      bg_image_path = './examples/TestDreamBooth/BG/000000309203_GT.png'
@@ -249,44 +249,44 @@ if __name__ == '__main__':
      vis_image = cv2.hconcat([ref_image, back_image, gen_image])

      cv2.imwrite(save_path, vis_image [:,:,::-1])
-     '''
-     #'''
-     # ==== Example for inferring VITON-HD Test dataset ===
-
-     from omegaconf import OmegaConf
-     import os
-     DConf = OmegaConf.load('./configs/datasets.yaml')
-     save_dir = '../INFERRED_TRAINED'
-     if not os.path.exists(save_dir):
-         os.mkdir(save_dir)
-
-     test_dir = DConf.Test.VitonHDTest.image_dir
-     image_names = os.listdir(test_dir)
+     # '''
+     # #'''
+     # # ==== Example for inferring VITON-HD Test dataset ===
+
+     # from omegaconf import OmegaConf
+     # import os
+     # DConf = OmegaConf.load('./configs/datasets.yaml')
+     # save_dir = '../INFERRED_TRAINED'
+     # if not os.path.exists(save_dir):
+     #     os.mkdir(save_dir)
+
+     # test_dir = DConf.Test.VitonHDTest.image_dir
+     # image_names = os.listdir(test_dir)

-     for image_name in image_names[:10]:
-         ref_image_path = os.path.join(test_dir, image_name)
-         tar_image_path = ref_image_path.replace('/cloth/', '/image/')
-         ref_mask_path = ref_image_path.replace('/cloth/','/cloth-mask/')
-         tar_mask_path = ref_image_path.replace('/cloth/', '/image-parse-v3/').replace('.jpg','.png')
+     # for image_name in image_names[:10]:
+     #     ref_image_path = os.path.join(test_dir, image_name)
+     #     tar_image_path = ref_image_path.replace('/cloth/', '/image/')
+     #     ref_mask_path = ref_image_path.replace('/cloth/','/cloth-mask/')
+     #     tar_mask_path = ref_image_path.replace('/cloth/', '/image-parse-v3/').replace('.jpg','.png')

-         ref_image = cv2.imread(ref_image_path)
-         ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB)
+     #     ref_image = cv2.imread(ref_image_path)
+     #     ref_image = cv2.cvtColor(ref_image, cv2.COLOR_BGR2RGB)

-         gt_image = cv2.imread(tar_image_path)
-         gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2RGB)
+     #     gt_image = cv2.imread(tar_image_path)
+     #     gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2RGB)

-         ref_mask = (cv2.imread(ref_mask_path) > 128).astype(np.uint8)[:,:,0]
+     #     ref_mask = (cv2.imread(ref_mask_path) > 128).astype(np.uint8)[:,:,0]

-         tar_mask = Image.open(tar_mask_path ).convert('P')
-         tar_mask= np.array(tar_mask)
-         tar_mask = tar_mask == 5
+     #     tar_mask = Image.open(tar_mask_path ).convert('P')
+     #     tar_mask= np.array(tar_mask)
+     #     tar_mask = tar_mask == 5

-         gen_image = inference_single_image(ref_image, ref_mask, gt_image.copy(), tar_mask)
-         gen_path = os.path.join(save_dir, image_name)
+     #     gen_image = inference_single_image(ref_image, ref_mask, gt_image.copy(), tar_mask)
+     #     gen_path = os.path.join(save_dir, image_name)

-         vis_image = cv2.hconcat([ref_image, gt_image, gen_image])
-         cv2.imwrite(gen_path, vis_image[:,:,::-1])
-         #'''
+     #     vis_image = cv2.hconcat([ref_image, gt_image, gen_image])
+     #     cv2.imwrite(gen_path, vis_image[:,:,::-1])
+     # #'''



anydoor/run_inference_api_select.py CHANGED
@@ -229,9 +229,8 @@ def inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_sc
  import cv2
  import numpy as np
  import base64
- import os
- from http.server import BaseHTTPRequestHandler, HTTPServer
  import json
+ import sys
  from io import BytesIO
  from PIL import Image

@@ -242,251 +241,60 @@ def base64_to_cv2_image(base64_str):
      img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
      return img

- def base64_to_pil_image(base64_str):
-     img_data = base64.b64decode(base64_str)
-     img = Image.open(BytesIO(img_data))
-     return img
-
- def pil_image_to_np_array(pil_img, target_index):
-     np_array = np.array(pil_img)
-     return (np_array == target_index).astype(np.uint8)
-
  def image_to_base64(img):
      img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
      _, buffer = cv2.imencode('.jpg', img)
      base64_str = base64.b64encode(buffer).decode("utf-8")
      return base64_str

- class RequestHandler(BaseHTTPRequestHandler):
-     API_KEY = "xiCQTaoQKXUNATzuFLWRgtoJKiFXiDGvnk"
-
-     def _set_response(self, status_code=200, content_type='application/json'):
-         self.send_response(status_code)
-         self.send_header('Content-type', content_type)
-         self.send_header('Access-Control-Allow-Origin', '*')
-         self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
-         self.send_header('Access-Control-Allow-Headers', 'X-API-Key, Content-Type')
-         self.end_headers()
-
-     def do_OPTIONS(self):
-         self._set_response(204)
-
-     def do_GET(self):
-         self._set_response(405)
-         self.wfile.write(b'{"error": "GET method not allowed."}')
-
-     def handle_not_supported_method(self):
-         self._set_response(405)
-         self.wfile.write(b'{"error": "Method not supported."}')
-
-     def do_PUT(self):
-         self.handle_not_supported_method()
-
-     def do_DELETE(self):
-         self.handle_not_supported_method()
-
-     def do_PATCH(self):
-         self.handle_not_supported_method()
-
-     def do_POST(self):
-         print("Received POST request...")
-         received_api_key = self.headers.get('X-API-Key')
-
-         if received_api_key != self.API_KEY:
-             self._set_response(401)
-             self.wfile.write(b'{"error": "Invalid API key"}')
-             print("Invalid API key")
-             return
-
-         content_length = int(self.headers['Content-Length'])
-         print(f"Content Length: {content_length}")
-
-         if content_length:
-             post_data = self.rfile.read(content_length)
-             print("Data received")
-             try:
-                 data = json.loads(post_data.decode('utf-8'))
-                 print("Processing data")
-
-                 model_name = data.get('model', 'default_model.ckpt')
-                 model_ckpt_map = {
-                     'boys': 'boys.ckpt',
-                     'men': 'men.ckpt',
-                     'women': 'women.ckpt',
-                     'girls': 'girls.ckpt'
-                 }
-                 new_model_ckpt = model_ckpt_map.get(model_name, current_model_ckpt)
-                 load_model(new_model_ckpt)
-
-                 seed = int(data.get('seed'))
-                 steps = int(data.get('steps'))
-                 guidance_scale = float(data.get('guidance_scale'))
-
-                 ref_image = base64_to_cv2_image(data['ref_image'])
-                 tar_image = base64_to_cv2_image(data['tar_image'])
-
-                 ref_mask_img = base64_to_cv2_image(data['ref_mask'])
-                 ref_mask = cv2.cvtColor(ref_mask_img, cv2.COLOR_RGB2GRAY)
-                 ref_mask = (ref_mask > 128).astype(np.uint8)
-
-                 tar_mask_img = base64_to_cv2_image(data['tar_mask'])
-                 tar_mask = cv2.cvtColor(tar_mask_img, cv2.COLOR_RGB2GRAY)
-                 tar_mask = (tar_mask > 128).astype(np.uint8)
-
-                 gen_image = inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps)
-                 gen_image_base64 = image_to_base64(gen_image)
-
-                 self.send_response(200)
-                 self.send_header('Content-Type', 'image/jpeg')
-                 self.end_headers()
-                 self.wfile.write(base64.b64decode(gen_image_base64))
-
-                 print("Sent image response")
-
-             except Exception as e:
-                 print(f"An error occurred: {e}")
-                 self._set_response(500)
-                 error_data = json.dumps({'error': str(e)}).encode('utf-8')
-                 self.wfile.write(error_data)
-                 print("Sent error response")
-
-         else:
-             print("No data received in POST request.")
-             self._set_response(400)
-             error_data = json.dumps({'error': 'No data received'}).encode('utf-8')
-             self.wfile.write(error_data)
-             print("Sent error response")
-
- def run(server_class=HTTPServer, handler_class=RequestHandler, port=8084):
-     server_address = ('', port)
-     httpd = server_class(server_address, handler_class)
-     print(f"Starting HTTP server on port {port}")
-     httpd.serve_forever()
+ def inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps):
+     # Replace this with your image processing model function
+     # Placeholder operation (e.g., blending images for demonstration)
+     np.random.seed(seed)
+     output_img = cv2.addWeighted(ref_image, 0.5, tar_image, 0.5, 0)
+     return output_img
+
+ def process_images(data):
+     model_name = data.get('model', 'default_model.ckpt')
+     model_ckpt_map = {
+         'boys': 'boys.ckpt',
+         'men': 'men.ckpt',
+         'women': 'women.ckpt',
+         'girls': 'girls.ckpt'
+     }
+     current_model_ckpt = 'default_model.ckpt'
+     new_model_ckpt = model_ckpt_map.get(model_name, current_model_ckpt)
+     # load_model(new_model_ckpt)  # Load model if needed
+
+     seed = int(data.get('seed', 42))
+     steps = int(data.get('steps', 50))
+     guidance_scale = float(data.get('guidance_scale', 1.0))
+
+     ref_image = base64_to_cv2_image(data['ref_image'])
+     tar_image = base64_to_cv2_image(data['tar_image'])
+
+     ref_mask_img = base64_to_cv2_image(data['ref_mask'])
+     ref_mask = cv2.cvtColor(ref_mask_img, cv2.COLOR_RGB2GRAY)
+     ref_mask = (ref_mask > 128).astype(np.uint8)
+
+     tar_mask_img = base64_to_cv2_image(data['tar_mask'])
+     tar_mask = cv2.cvtColor(tar_mask_img, cv2.COLOR_RGB2GRAY)
+     tar_mask = (tar_mask > 128).astype(np.uint8)
+
+     gen_image = inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps)
+     gen_image_base64 = image_to_base64(gen_image)
+     return gen_image_base64

  if __name__ == "__main__":
-     run()
-
- # class RequestHandler(BaseHTTPRequestHandler):
- #     API_KEY = "xiCQTaoQKXUNATzuFLWRgtoJKiFXiDGvnk"
-
- #     def _set_response(self, status_code=200, content_type='application/json'):
- #         self.send_response(status_code)
- #         self.send_header('Content-type', content_type)
- #         self.send_header('Access-Control-Allow-Origin', '*')
- #         self.send_header('Access-Control-Allow-Methods', 'GET, POST, OPTIONS')
- #         self.send_header('Access-Control-Allow-Headers', 'X-API-Key, Content-Type')
- #         self.end_headers()
-
- #     def do_OPTIONS(self):
- #         self._set_response(204)  # No content to send back for OPTIONS request
-
- #     def do_GET(self):
- #         # If needed, define handling for GET or send a 405 if it's not supported
- #         self._set_response(405)
- #         self.wfile.write(b'{"error": "GET method not allowed."}')
-
- #     def handle_not_supported_method(self):
- #         self._set_response(405)
- #         self.wfile.write(b'{"error": "Method not supported."}')
-
- #     def do_PUT(self):
- #         self.handle_not_supported_method()
-
- #     def do_DELETE(self):
- #         self.handle_not_supported_method()
-
- #     def do_PATCH(self):
- #         self.handle_not_supported_method()
-
- #     def do_POST(self):
- #         print("Received POST request...")
- #         received_api_key = self.headers.get('X-API-Key')
- #         # Check if the API key is correct
- #         if received_api_key != self.API_KEY:
- #             # If the API key is incorrect, respond with 401 Unauthorized
- #             self._set_response(401)
- #             self.wfile.write(b'{"error": "Invalid API key"}')
- #             print("Invalid API key")
- #             return
-
- #         content_length = int(self.headers['Content-Length'])
- #         print(f"Content Length: {content_length}")
-
- #         if content_length:
- #             post_data = self.rfile.read(content_length)
- #             print("Data received")
- #             try:
- #                 data = json.loads(post_data.decode('utf-8'))
- #                 print("Processing data")
- #                 # print(data)
-
- #                 seed = int(data.get('seed'))
- #                 steps = int(data.get('steps'))
- #                 guidance_scale = float(data.get('guidance_scale'))
-
- #                 ref_image = base64_to_cv2_image(data['ref_image'])
- #                 tar_image = base64_to_cv2_image(data['tar_image'])
- #                 # print(seed)
- #                 # print(steps)
- #                 # print(guidance_scale)
- #                 # Process reference mask
- #                 ref_mask_img = base64_to_cv2_image(data['ref_mask'])
- #                 ref_mask = cv2.cvtColor(ref_mask_img, cv2.COLOR_RGB2GRAY)
- #                 ref_mask = (ref_mask > 128).astype(np.uint8)
-
- #                 # Process target mask
- #                 tar_mask_img = base64_to_cv2_image(data['tar_mask'])
- #                 tar_mask = cv2.cvtColor(tar_mask_img, cv2.COLOR_RGB2GRAY)
- #                 tar_mask = (tar_mask > 128).astype(np.uint8)
-
- #                 output_dir = '/work/ADOOR_ACE/test_out'
- #                 os.makedirs(output_dir, exist_ok=True)
-
- #                 # Save reference and target images
- #                 cv2.imwrite(os.path.join(output_dir, 'out_ref_image.jpg'), cv2.cvtColor(ref_image, cv2.COLOR_RGB2BGR))
- #                 cv2.imwrite(os.path.join(output_dir, 'out_tar_image.jpg'), cv2.cvtColor(tar_image, cv2.COLOR_RGB2BGR))
-
- #                 # Save reference mask
- #                 ref_mask_img_to_save = (ref_mask * 255).astype(np.uint8)
- #                 cv2.imwrite(os.path.join(output_dir, 'out_ref_mask.jpg'), ref_mask_img_to_save)
-
- #                 # Save target mask
- #                 tar_mask_img_to_save = (tar_mask * 255).astype(np.uint8)
- #                 cv2.imwrite(os.path.join(output_dir,'out_tar_mask.jpg'), tar_mask_img_to_save)
-
- #                 gen_image = inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps)
- #                 gen_image_base64 = image_to_base64(gen_image)
-
- #                 self.send_response(200)
- #                 self.send_header('Content-Type', 'image/jpeg')
- #                 self.end_headers()
- #                 self.wfile.write(base64.b64decode(gen_image_base64))
-
- #                 print("Sent image response")
-
- #             except Exception as e:
- #                 print(f"An error occurred: {e}")
- #                 self._set_response(500)
- #                 error_data = json.dumps({'error': str(e)}).encode('utf-8')
- #                 self.wfile.write(error_data)
- #                 print("Sent error response")
-
- #         else:
- #             print("No data received in POST request.")
- #             self._set_response(400)
- #             error_data = json.dumps({'error': 'No data received'}).encode('utf-8')
- #             self.wfile.write(error_data)
- #             print("Sent error response")
-
-
-
- # def run(server_class=HTTPServer, handler_class=RequestHandler, port=8084):
- #     server_address = ('', port)
- #     httpd = server_class(server_address, handler_class)
- #     print(f"Starting HTTP server on port {port}")
- #     httpd.serve_forever()
-
- # if __name__ == "__main__":
- #     run()
-
-
+     if len(sys.argv) < 2:
+         print("Usage: python script.py '<json_data>'")
+         sys.exit(1)
+
+     # Read JSON data from command line argument
+     json_data = sys.argv[1]
+     try:
+         data = json.loads(json_data)
+         result_image_base64 = process_images(data)
+         print(result_image_base64)
+     except Exception as e:
+         print(f"Error processing images: {e}", file=sys.stderr)
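A minimal sketch of how a caller might assemble the JSON argument and invoke this CLI entry point; the input file names and the run_inference_api_select.py invocation path are illustrative assumptions:

import base64
import json
import subprocess

def encode_file(path):
    # process_images() expects base64-encoded image file bytes (decoded later via cv2.imdecode)
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "model": "men",
    "seed": 42,
    "steps": 50,
    "guidance_scale": 3.0,
    "ref_image": encode_file("ref.jpg"),
    "tar_image": encode_file("tar.jpg"),
    "ref_mask": encode_file("ref_mask.png"),
    "tar_mask": encode_file("tar_mask.png"),
}

# The script prints the generated image as a base64 string on stdout
result = subprocess.run(
    ["python", "run_inference_api_select.py", json.dumps(payload)],
    capture_output=True, text=True, check=True,
)
with open("generated.jpg", "wb") as f:
    f.write(base64.b64decode(result.stdout.strip()))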
anydoor/run_inference_runpod.py ADDED
@@ -0,0 +1,293 @@
+ import cv2
+ import einops
+ import numpy as np
+ import torch
+ import random
+ from pytorch_lightning import seed_everything
+ from cldm.model import create_model, load_state_dict
+ from cldm.ddim_hacked import DDIMSampler
+ from cldm.hack import disable_verbosity, enable_sliced_attention
+ from datasets.data_utils import *
+ cv2.setNumThreads(0)
+ cv2.ocl.setUseOpenCL(False)
+ import albumentations as A
+ from omegaconf import OmegaConf
+ from PIL import Image
+
+ save_memory = True
+ disable_verbosity()
+ if save_memory:
+     enable_sliced_attention()
+
+ config = OmegaConf.load('./configs/inference.yaml')
+ current_model_ckpt = config.pretrained_model
+ model_config = config.config_file
+
+ model = create_model(model_config).cpu()
+ model.load_state_dict(load_state_dict(current_model_ckpt, location='cuda'))
+ model = model.cuda()
+ ddim_sampler = DDIMSampler(model)
+
+ def load_model(new_model_ckpt):
+     global model, ddim_sampler, current_model_ckpt
+     if new_model_ckpt != current_model_ckpt:
+         print(f"Loading new model: {new_model_ckpt}")
+         model.load_state_dict(load_state_dict(f'/workspace/train-wefadoor-master/anydoor/lightning_logs/version_1/checkpoints/epoch=1-step=2499.ckpt', location='cuda'))
+         # model.load_state_dict(load_state_dict(f'/workspace/300k_wefa_boys_slim/lightning_logs/version_0/checkpoints/{new_model_ckpt}', location='cuda'))
+         current_model_ckpt = new_model_ckpt
+         print("New model loaded successfully.")
+     else:
+         print("Same model is already loaded, skipping reload.")
+
+ def aug_data_mask(image, mask):
+     transform = A.Compose([
+         A.HorizontalFlip(p=0.5),
+         A.RandomBrightnessContrast(p=0.5),
+         ])
+     transformed = transform(image=image.astype(np.uint8), mask = mask)
+     transformed_image = transformed["image"]
+     transformed_mask = transformed["mask"]
+     return transformed_image, transformed_mask
+
+
+ def process_pairs(ref_image, ref_mask, tar_image, tar_mask):
+     # ========= Reference ===========
+     # ref expand
+     ref_box_yyxx = get_bbox_from_mask(ref_mask)
+
+     # ref filter mask
+     ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
+     masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1-ref_mask_3)
+
+     y1,y2,x1,x2 = ref_box_yyxx
+     masked_ref_image = masked_ref_image[y1:y2,x1:x2,:]
+     ref_mask = ref_mask[y1:y2,x1:x2]
+
+
+     ratio = np.random.randint(12, 13) / 10
+     masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
+     ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
+
+     # to square and resize
+     masked_ref_image = pad_to_square(masked_ref_image, pad_value = 255, random = False)
+     masked_ref_image = cv2.resize(masked_ref_image, (224,224) ).astype(np.uint8)
+
+     ref_mask_3 = pad_to_square(ref_mask_3 * 255, pad_value = 0, random = False)
+     ref_mask_3 = cv2.resize(ref_mask_3, (224,224) ).astype(np.uint8)
+     ref_mask = ref_mask_3[:,:,0]
+
+     # ref aug
+     masked_ref_image_aug = masked_ref_image #aug_data(masked_ref_image)
+
+     # collage aug
+     masked_ref_image_compose, ref_mask_compose = masked_ref_image, ref_mask #aug_data_mask(masked_ref_image, ref_mask)
+     masked_ref_image_aug = masked_ref_image_compose.copy()
+     ref_mask_3 = np.stack([ref_mask_compose,ref_mask_compose,ref_mask_compose],-1)
+     ref_image_collage = sobel(masked_ref_image_compose, ref_mask_compose/255)
+
+     # ========= Target ===========
+     tar_box_yyxx = get_bbox_from_mask(tar_mask)
+     tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=[1.1,1.2])
+
+     # crop
+     tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=[1.5, 3]) #1.2 1.6
+     tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop) # crop box
+     y1,y2,x1,x2 = tar_box_yyxx_crop
+
+     cropped_target_image = tar_image[y1:y2,x1:x2,:]
+     tar_box_yyxx = box_in_box(tar_box_yyxx, tar_box_yyxx_crop)
+     y1,y2,x1,x2 = tar_box_yyxx
+
+     # collage
+     ref_image_collage = cv2.resize(ref_image_collage, (x2-x1, y2-y1))
+     ref_mask_compose = cv2.resize(ref_mask_compose.astype(np.uint8), (x2-x1, y2-y1))
+     ref_mask_compose = (ref_mask_compose > 128).astype(np.uint8)
+
+     collage = cropped_target_image.copy()
+     collage[y1:y2,x1:x2,:] = ref_image_collage
+
+     collage_mask = cropped_target_image.copy() * 0.0
+     collage_mask[y1:y2,x1:x2,:] = 1.0
+
+     # the size before pad
+     H1, W1 = collage.shape[0], collage.shape[1]
+     cropped_target_image = pad_to_square(cropped_target_image, pad_value = 0, random = False).astype(np.uint8)
+     collage = pad_to_square(collage, pad_value = 0, random = False).astype(np.uint8)
+     collage_mask = pad_to_square(collage_mask, pad_value = -1, random = False).astype(np.uint8)
+
+     # the size after pad
+     H2, W2 = collage.shape[0], collage.shape[1]
+     cropped_target_image = cv2.resize(cropped_target_image, (512,512)).astype(np.float32)
+     collage = cv2.resize(collage, (512,512)).astype(np.float32)
+     collage_mask = (cv2.resize(collage_mask, (512,512)).astype(np.float32) > 0.5).astype(np.float32)
+
+     masked_ref_image_aug = masked_ref_image_aug / 255
+     cropped_target_image = cropped_target_image / 127.5 - 1.0
+     collage = collage / 127.5 - 1.0
+     collage = np.concatenate([collage, collage_mask[:,:,:1] ] , -1)
+
+     item = dict(ref=masked_ref_image_aug.copy(), jpg=cropped_target_image.copy(), hint=collage.copy(), extra_sizes=np.array([H1, W1, H2, W2]), tar_box_yyxx_crop=np.array( tar_box_yyxx_crop ) )
+     return item
+
+
+ def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
+     H1, W1, H2, W2 = extra_sizes
+     y1,y2,x1,x2 = tar_box_yyxx_crop
+     pred = cv2.resize(pred, (W2, H2))
+     m = 5 # maigin_pixel
+
+     if W1 == H1:
+         tar_image[y1+m :y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
+         return tar_image
+
+     if W1 < W2:
+         pad1 = int((W2 - W1) / 2)
+         pad2 = W2 - W1 - pad1
+         pred = pred[:,pad1: -pad2, :]
+     else:
+         pad1 = int((H2 - H1) / 2)
+         pad2 = H2 - H1 - pad1
+         pred = pred[pad1: -pad2, :, :]
+
+     gen_image = tar_image.copy()
+     gen_image[y1+m :y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
+     return gen_image
+
+
+ def inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps):
+     item = process_pairs(ref_image, ref_mask, tar_image, tar_mask)
+     ref = item['ref'] * 255
+     tar = item['jpg'] * 127.5 + 127.5
+     hint = item['hint'] * 127.5 + 127.5
+
+     hint_image = hint[:,:,:-1]
+     hint_mask = item['hint'][:,:,-1] * 255
+     hint_mask = np.stack([hint_mask,hint_mask,hint_mask],-1)
+     ref = cv2.resize(ref.astype(np.uint8), (512,512))
+
+     seed = random.randint(0, 65535)
+     if save_memory:
+         model.low_vram_shift(is_diffusing=False)
+
+     ref = item['ref']
+     tar = item['jpg']
+     hint = item['hint']
+     num_samples = 1
+
+     control = torch.from_numpy(hint.copy()).float().cuda()
+     control = torch.stack([control for _ in range(num_samples)], dim=0)
+     control = einops.rearrange(control, 'b h w c -> b c h w').clone()
+
+
+     clip_input = torch.from_numpy(ref.copy()).float().cuda()
+     clip_input = torch.stack([clip_input for _ in range(num_samples)], dim=0)
+     clip_input = einops.rearrange(clip_input, 'b h w c -> b c h w').clone()
+
+     guess_mode = False
+     H,W = 512,512
+
+     cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning( clip_input )]}
+     un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([torch.zeros((1,3,224,224))] * num_samples)]}
+     shape = (4, H // 8, W // 8)
+
+     if save_memory:
+         model.low_vram_shift(is_diffusing=True)
+
+     # ====
+     num_samples = 1 #gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+     image_resolution = 512 #gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
+     strength = 1 #gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+     guess_mode = False #gr.Checkbox(label='Guess Mode', value=False)
+     #detect_resolution = 512 #gr.Slider(label="Segmentation Resolution", minimum=128, maximum=1024, value=512, step=1)
+     ddim_steps = steps #gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+     scale = guidance_scale #gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+     seed = seed #gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+     eta = 0.0 #gr.Number(label="eta (DDIM)", value=0.0)
+
+     model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)  # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
+     samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
+                                                  shape, cond, verbose=False, eta=eta,
+                                                  unconditional_guidance_scale=scale,
+                                                  unconditional_conditioning=un_cond)
+     if save_memory:
+         model.low_vram_shift(is_diffusing=False)
+
+     x_samples = model.decode_first_stage(samples)
+     x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy()#.clip(0, 255).astype(np.uint8)
+
+     result = x_samples[0][:,:,::-1]
+     result = np.clip(result,0,255)
+
+     pred = x_samples[0]
+     pred = np.clip(pred,0,255)[1:,:,:]
+     sizes = item['extra_sizes']
+     tar_box_yyxx_crop = item['tar_box_yyxx_crop']
+     gen_image = crop_back(pred, tar_image, sizes, tar_box_yyxx_crop)
+     return gen_image
+
+ import cv2
+ import numpy as np
+ import base64
+ import json
+ import sys
+ import runpod  # RunPod serverless SDK; used by runpod.serverless.start at the bottom of this file
+ from io import BytesIO
+ from PIL import Image
+
+ def base64_to_cv2_image(base64_str):
+     img_str = base64.b64decode(base64_str)
+     np_img = np.frombuffer(img_str, dtype=np.uint8)
+     img = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+     return img
+
+ def image_to_base64(img):
+     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+     _, buffer = cv2.imencode('.jpg', img)
+     base64_str = base64.b64encode(buffer).decode("utf-8")
+     return base64_str
+
+ def process_images(data):
+     model_name = data.get('model', './step_357500_slim.ckpt')
+     model_ckpt_map = {
+         'boys': 'boys.ckpt',
+         'men': 'men.ckpt',
+         'women': 'women.ckpt',
+         'girls': 'girls.ckpt'
+     }
+     current_model_ckpt = './step_357500_slim.ckpt'
+     new_model_ckpt = model_ckpt_map.get(model_name, current_model_ckpt)
+     load_model(new_model_ckpt)  # Load model if needed
+
+     seed = int(data.get('seed', 1351352))
+     steps = int(data.get('steps', 50))
+     guidance_scale = float(data.get('guidance_scale', 3.0))
+
+     ref_image = base64_to_cv2_image(data['ref_image'])
+     tar_image = base64_to_cv2_image(data['tar_image'])
+
+     ref_mask_img = base64_to_cv2_image(data['ref_mask'])
+     ref_mask = cv2.cvtColor(ref_mask_img, cv2.COLOR_RGB2GRAY)
+     ref_mask = (ref_mask > 128).astype(np.uint8)
+
+     tar_mask_img = base64_to_cv2_image(data['tar_mask'])
+     tar_mask = cv2.cvtColor(tar_mask_img, cv2.COLOR_RGB2GRAY)
+     tar_mask = (tar_mask > 128).astype(np.uint8)
+
+     gen_image = inference_single_image(ref_image, ref_mask, tar_image, tar_mask, guidance_scale, seed, steps)
+     gen_image_base64 = image_to_base64(gen_image)
+     return gen_image_base64
+
+ # Define the handler function for RunPod
+ def handler(job):
+     # Access input data from the job
+     job_input = job["input"]
+
+     try:
+         # Process the images using the provided data
+         result_image_base64 = process_images(job_input)
+         return {"status": "success", "output": result_image_base64}
+     except Exception as e:
+         return {"status": "error", "message": str(e)}
+
+ # Start the serverless handler with RunPod
+ runpod.serverless.start({"handler": handler})
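Before deploying to RunPod, the handler can be exercised locally by calling it with a job-shaped dict; a minimal sketch, assuming the checkpoints and configs referenced in this file are present and that the base64 inputs are prepared the same way as in the CLI example for run_inference_api_select.py above:

# ref_b64, tar_b64, ref_mask_b64 and tar_mask_b64 are placeholder base64 strings of image files,
# built exactly as in the CLI example above
job = {
    "input": {
        "model": "men",
        "seed": 1351352,
        "steps": 50,
        "guidance_scale": 3.0,
        "ref_image": ref_b64,
        "tar_image": tar_b64,
        "ref_mask": ref_mask_b64,
        "tar_mask": tar_mask_b64,
    }
}

result = handler(job)
print(result["status"])  # on "success", result["output"] holds the generated image as a base64 JPEG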