liuyizhang commited on
Commit
2262ea9
·
1 Parent(s): 2ef873f

update app.py and delete api_client.py

Browse files
Files changed (2) hide show
  1. api_client.py +0 -85
  2. app.py +2 -133
api_client.py DELETED
@@ -1,85 +0,0 @@
1
- import requests, json
2
- from PIL import Image
3
- import matplotlib.pyplot as plt
4
- import numpy as np
5
- import base64
6
- import io
7
-
8
- def request_post(url, data, timeout=600, headers = None):
9
- if headers is None:
10
- headers = {
11
- # 'content-type': 'application/json'
12
- # 'Connection': 'keep-alive',
13
- 'Accept': '*/*', # 接受任何类型的返回数据
14
- 'Content-Type': 'application/json;charset=UTF-8', # 发送数据为json
15
- # 'Content-Length': '156',
16
- # 'Accept-Encoding': 'gzip, deflate',
17
- # 'Accept-Language': 'zh-CN,zh;q=0.9',
18
- # 'User-Agent': 'SamClub/5.0.45 (iPhone; iOS 15.4; Scale/3.00)',
19
- # 'device-name': 'iPhone14,3',
20
- # 'device-os-version': '15.4',
21
- # 'device-type': 'ios',
22
- # 'auth-token': authtoken,
23
- # 'app-version': '5.0.45.1'
24
- }
25
- try:
26
- response = requests.post(url=url, headers=headers, data=json.dumps(data), timeout=timeout)
27
- response_data = response.json()
28
- return response_data
29
- except Exception as e:
30
- print(f'request_post[Error]:' + str(e))
31
- print(f'url: {url}')
32
- print(f'data: {data}')
33
- print(f'response: {response}')
34
- return None
35
-
36
- url = "http://127.0.0.1:7860/imgCLeaner"
37
-
38
- def imgFile_to_base64(image_file):
39
- with open(image_file, "rb") as f:
40
- im_bytes = f.read()
41
- im_b64_encode = base64.b64encode(im_bytes)
42
- im_b64 = im_b64_encode.decode("utf8")
43
- return im_b64
44
-
45
- def base64_to_bytes(im_b64):
46
- im_b64_encode = im_b64.encode("utf-8")
47
- im_bytes = base64.b64decode(im_b64_encode)
48
- return im_bytes
49
-
50
- def base64_to_PILImage(im_b64):
51
- im_bytes = base64_to_bytes(im_b64)
52
- pil_img = Image.open(io.BytesIO(im_bytes))
53
- return pil_img
54
-
55
- def cleaner_img(image_file, remove_texts, mask_extend=20, disp_debug=True):
56
- data = {'remove_texts': remove_texts,
57
- 'mask_extend': mask_extend,
58
- 'img': imgFile_to_base64(image_file),
59
- }
60
- ret = request_post(url, data, timeout=600, headers = None)
61
- if ret['code'] == 0:
62
- if disp_debug:
63
- for img in ret['result']['imgs']:
64
- pilImage = base64_to_PILImage(img)
65
- plt.imshow(pilImage)
66
- plt.show()
67
- plt.clf()
68
- plt.close('all')
69
- img_len = len(ret['result']['imgs'])
70
- pilImage = base64_to_PILImage(ret['result']['imgs'][img_len-1])
71
- else:
72
- pilImage = None
73
- return pilImage, ret
74
-
75
- image_file = 'dog.png'
76
- remove_texts = "小狗 . 椅子"
77
-
78
- mask_extend = 20
79
- pil_image, ret = cleaner_img(image_file, remove_texts, mask_extend, disp_debug=False)
80
-
81
- plt.imshow(pil_image)
82
- plt.show()
83
- plt.clf()
84
- plt.close()
85
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -4,15 +4,7 @@ warnings.filterwarnings('ignore')
4
 
5
  import subprocess, io, os, sys, time
6
 
7
- run_gradio = False
8
- if os.environ.get('IS_MY_DEBUG') is None:
9
- run_gradio = True
10
- else:
11
- run_gradio = False
12
- # run_gradio = True
13
-
14
- if run_gradio:
15
- os.system("pip install gradio==3.50.2")
16
 
17
  import gradio as gr
18
  from loguru import logger
@@ -982,124 +974,6 @@ def main_gradio(args):
982
  computer_info()
983
  block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
984
 
985
- import signal
986
- import json
987
- from datetime import date, datetime, timedelta
988
- from gevent import pywsgi
989
- import base64
990
-
991
- def imgFile_to_base64(image_file):
992
- with open(image_file, "rb") as f:
993
- im_bytes = f.read()
994
- im_b64_encode = base64.b64encode(im_bytes)
995
- im_b64 = im_b64_encode.decode("utf8")
996
- return im_b64
997
-
998
- def base64_to_bytes(im_b64):
999
- im_b64_encode = im_b64.encode("utf-8")
1000
- im_bytes = base64.b64decode(im_b64_encode)
1001
- return im_bytes
1002
-
1003
- def base64_to_PILImage(im_b64):
1004
- im_bytes = base64_to_bytes(im_b64)
1005
- pil_img = Image.open(io.BytesIO(im_bytes))
1006
- return pil_img
1007
-
1008
- class API_Starter:
1009
- def __init__(self):
1010
- from flask import Flask, request, jsonify, make_response
1011
- from flask_cors import CORS, cross_origin
1012
- import logging
1013
-
1014
- app = Flask(__name__)
1015
- app.logger.setLevel(logging.ERROR)
1016
- CORS(app, supports_credentials=True, resources={r"/*": {"origins": "*"}})
1017
-
1018
- @app.route('/imgCLeaner', methods=['GET', 'POST'])
1019
- @cross_origin()
1020
- def processAssist():
1021
- if request.method == 'GET':
1022
- ret_json = {'code': -1, 'reason':'no support to get'}
1023
- elif request.method == 'POST':
1024
- request_data = request.data.decode('utf-8')
1025
- data = json.loads(request_data)
1026
- result = self.handle_data(data)
1027
- if result is None:
1028
- ret_json = {'code': -2, 'reason':'handle error'}
1029
- else:
1030
- ret_json = {'code': 0, 'result':result}
1031
- return jsonify(ret_json)
1032
-
1033
- self.app = app
1034
- now_time = datetime.now().strftime('%Y%m%d_%H%M%S')
1035
- logger.add(f'./logs/logger_[{args.port}]_{now_time}.log')
1036
- signal.signal(signal.SIGINT, self.signal_handler)
1037
-
1038
- def handle_data(self, data):
1039
- im_b64 = data['img']
1040
- img = base64_to_PILImage(im_b64)
1041
- remove_texts = data['remove_texts']
1042
- remove_mask_extend = data['mask_extend']
1043
- results = run_anything_task(input_image = img,
1044
- text_prompt = f"{remove_texts}",
1045
- task_type = 'remove',
1046
- inpaint_prompt = '',
1047
- box_threshold = 0.3,
1048
- text_threshold = 0.25,
1049
- iou_threshold = 0.8,
1050
- inpaint_mode = "merge",
1051
- mask_source_radio = "type what to detect below",
1052
- remove_mode = "rectangle", # ["segment", "rectangle"]
1053
- remove_mask_extend = f"{remove_mask_extend}",
1054
- num_relation = 5,
1055
- kosmos_input = None,
1056
- cleaner_size_limit = -1,
1057
- )
1058
- output_images = results[0]
1059
- if output_images is None:
1060
- return None
1061
- ret_json_images = []
1062
- file_temp = int(time.time())
1063
- count = 0
1064
- output_images = output_images[-1:]
1065
- for image_pil in output_images:
1066
- try:
1067
- img_format = image_pil.format.lower()
1068
- except Exception as e:
1069
- img_format = 'png'
1070
- image_path = os.path.join(output_dir, f"api_images_{file_temp}_{count}.{img_format}")
1071
- count += 1
1072
- try:
1073
- image_pil.save(image_path)
1074
- except Exception as e:
1075
- Image.fromarray(image_pil).save(image_path)
1076
- im_b64 = imgFile_to_base64(image_path)
1077
- ret_json_images.append(im_b64)
1078
- os.remove(image_path)
1079
- data = {
1080
- 'imgs': ret_json_images,
1081
- }
1082
- return data
1083
-
1084
- def signal_handler(self, signal, frame):
1085
- print('\nSignal Catched! You have just type Ctrl+C!')
1086
- sys.exit(0)
1087
-
1088
- def run(self):
1089
- from gevent import pywsgi
1090
- logger.info(f'\nargs={args}\n')
1091
- computer_info()
1092
- print(f"Start a api server: http://0.0.0.0:{args.port}/imgCLeaner")
1093
- server = pywsgi.WSGIServer(('0.0.0.0', args.port), self.app)
1094
- server.serve_forever()
1095
-
1096
- def main_api(args):
1097
- if args.port == 0:
1098
- print('Please give valid port!')
1099
- else:
1100
- api_starter = API_Starter()
1101
- api_starter.run()
1102
-
1103
  if __name__ == "__main__":
1104
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
1105
  parser.add_argument("--debug", action="store_true", help="using debug mode")
@@ -1137,12 +1011,7 @@ if __name__ == "__main__":
1137
  if os.environ.get('IS_MY_DEBUG') is None:
1138
  os.system("pip list")
1139
 
1140
- if run_gradio:
1141
- # Provide gradio services
1142
- main_gradio(args)
1143
- else:
1144
- # Provide API services
1145
- main_api(args)
1146
 
1147
 
1148
 
 
4
 
5
  import subprocess, io, os, sys, time
6
 
7
+ os.system("pip install gradio==3.50.2")
 
 
 
 
 
 
 
 
8
 
9
  import gradio as gr
10
  from loguru import logger
 
974
  computer_info()
975
  block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
976
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
977
  if __name__ == "__main__":
978
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
979
  parser.add_argument("--debug", action="store_true", help="using debug mode")
 
1011
  if os.environ.get('IS_MY_DEBUG') is None:
1012
  os.system("pip list")
1013
 
1014
+ main_gradio(args)
 
 
 
 
 
1015
 
1016
 
1017