qnguyen3 committed on
Commit
7b79735
·
1 Parent(s): 6af2451
app.py CHANGED
@@ -1,86 +1,555 @@
 
 
 
 
 
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
- from threading import Thread
5
- import re
6
- import time
7
- from PIL import Image
8
- import torch
9
  import spaces
 
 
 
 
 
 
10
  import subprocess
11
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
12
 
13
- torch.set_default_device('cuda')
 
 
14
 
15
- tokenizer = AutoTokenizer.from_pretrained(
16
- 'qnguyen3/nanoLLaVA',
17
- trust_remote_code=True)
18
 
19
- model = AutoModelForCausalLM.from_pretrained(
20
- 'qnguyen3/nanoLLaVA',
21
- torch_dtype=torch.float16,
22
- device_map='auto',
23
- trust_remote_code=True)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
@spaces.GPU
def bot_streaming(message, history):
    """Stream a model reply for one multimodal chat turn.

    Args:
        message: dict with a ``"text"`` entry and a ``"files"`` list; each file
            entry is read through its ``"path"`` key.
        history: sequence of ``(user, assistant)`` pairs; a pair whose first
            element is a tuple is treated as an image turn.

    Yields:
        The text generated so far, with the leading user-prompt span cut off.

    Raises:
        gr.Error: if neither the message nor the history provides an image.
    """
    # Newest upload wins; otherwise fall back to the last image turn in history.
    # `image` must be initialised: previously it was unbound when there was no
    # upload and no image turn, which crashed with UnboundLocalError.
    image = None
    if message["files"]:
        image = message["files"][-1]["path"]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]

    if image is None:
        # The prompt assembly below needs exactly one <image> placeholder.
        raise gr.Error("You need to upload an image for LLaVA to work.")

    messages = []
    if len(history) > 0:
        # NOTE(review): assumes history[0] is the image-only turn and history[1]
        # holds the first text exchange -- confirm against the ChatInterface format.
        messages.append({"role": "user", "content": f'<image>\n{history[1][0]}'})
        messages.append({"role": "assistant", "content": history[1][1]})
        for human, assistant in history[2:]:
            messages.append({"role": "user", "content": human})
            messages.append({"role": "assistant", "content": assistant})
        messages.append({"role": "user", "content": message['text']})
    else:
        messages.append({"role": "user", "content": f"<image>\n{message['text']}"})

    image = Image.open(image).convert("RGB")
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True)
    # Tokenise the text on either side of <image> and splice -200 in between.
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]
    input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)

    image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)
    generation_kwargs = dict(input_ids=input_ids, images=image_tensor, streamer=streamer, max_new_tokens=100)
    # Generation runs in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    text_prompt = f"<|im_start|>user\n{message['text']}<|im_end|>"

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        generated_text_without_prompt = buffer[len(text_prompt):]
        time.sleep(0.04)
        yield generated_text_without_prompt
80
-
81
-
82
- demo = gr.ChatInterface(fn=bot_streaming, title="LLaVA NeXT", examples=[{"text": "What is on the flower?", "files":["./bee.jpg"]},
83
- {"text": "How to make this pastry?", "files":["./baklava.png"]}],
84
- description="Try [LLaVA NeXT](https://huggingface.co/docs/transformers/main/en/model_doc/llava_next) in this demo (more specifically, the [Mistral-7B variant](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
85
- stop_btn="Stop Generation", multimodal=True)
86
- demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
  import gradio as gr
7
+ import requests
8
+ import hashlib
9
+ import pypandoc
10
+ import base64
11
+ import sys
 
 
12
  import spaces
13
+
14
+ from io import BytesIO
15
+
16
+ from serve.conversation import (default_conversation, conv_templates, SeparatorStyle)
17
+ from serve.constants import LOGDIR
18
+ from serve.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
19
  import subprocess
 
20
 
21
# Install/upgrade flash-attn with the current interpreter every time this module
# is executed; check_call raises CalledProcessError if pip exits non-zero.
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])

# Logger shared by the handlers in this file.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# HTTP headers sent with the streaming request to the model worker.
headers = {"User-Agent": "Bunny Client"}

# Reusable Gradio updates returned by the handlers to set button state.
no_change_btn = gr.update()
enable_btn = gr.update(interactive=True)
disable_btn = gr.update(interactive=False)

# Sort-key overrides for get_model_list(): a listed model sorts by the mapped
# string instead of its own name.
priority = {
    "Bunny": "aaaaaaa",
}
34
+
35
def start_controller():
    """Spawn ``serve/controller.py`` on 0.0.0.0:10000 and return its ``Popen`` handle."""
    print("Starting the controller")
    script = "serve/controller.py"
    cli_flags = ["--host", "0.0.0.0", "--port", "10000"]
    controller_command = [sys.executable, script] + cli_flags
    print(controller_command)
    return subprocess.Popen(controller_command)
47
 
48
@spaces.GPU
def start_worker(model_path: str):
    """Spawn ``serve/model_worker.py`` for ``model_path`` and return its ``Popen`` handle.

    The worker is started on port 40000 and pointed at the controller on
    ``http://localhost:10000`` (the port used by ``start_controller``).

    Args:
        model_path: model identifier forwarded as ``--model-path``. It is now
            honoured; previously it was overwritten with a hard-coded value
            after being logged.
    """
    print(f"Starting the model worker for the model {model_path}")
    worker_command = [
        sys.executable,
        "serve/model_worker.py",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--port",
        "40000",
        # Was `"worker", "http://localhost:40000":` -- a stray colon made the
        # list a syntax error and the flag lacked its leading dashes.
        # NOTE(review): "--worker" mirrors the abbreviated "--controller" flag
        # above; confirm the option name against serve/model_worker.py.
        "--worker",
        "http://localhost:40000",
        "--model-path",
        model_path,
        "--model-type",
        "qwen1.5-0.5b",
        "--use-flash-attn",
    ]
    print(worker_command)
    return subprocess.Popen(worker_command)
71
+
72
+
73
def get_conv_log_filename():
    """Return the path of today's conversation log, ``<LOGDIR>/<Y>-<MM>-<DD>-conv.json``."""
    now = datetime.datetime.now()
    date_stamp = f"{now.year}-{now.month:02d}-{now.day:02d}"
    return os.path.join(LOGDIR, f"{date_stamp}-conv.json")
77
+
78
+
79
def get_model_list():
    """Ask the controller to refresh its workers, then return the sorted model names.

    Names are ordered by ``priority.get(name, name)``. An ``AssertionError`` is
    raised when the refresh call does not answer with HTTP 200.
    """
    base_url = args.controller_url

    refresh_resp = requests.post(base_url + "/refresh_all_workers")
    assert refresh_resp.status_code == 200

    list_resp = requests.post(base_url + "/list_models")
    models = list_resp.json()["models"]
    models.sort(key=lambda name: priority.get(name, name))
    logger.info(f"Models: {models}")
    return models
87
+
88
+
89
# JavaScript snippet passed to demo.load(): it turns the page's query string
# into a plain object, logs it to the browser console and returns it.
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
97
+
98
+
99
def load_demo(url_params, request: gr.Request):
    """Build the initial state and model-dropdown update when the page loads.

    If ``url_params`` carries a ``"model"`` entry that is one of the known
    ``models``, that model is preselected; otherwise the dropdown is only made
    visible.

    Returns:
        ``(state, dropdown_update)`` where ``state`` is a fresh copy of
        ``default_conversation``.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    if "model" in url_params and url_params["model"] in models:
        dropdown_update = gr.update(value=url_params["model"], visible=True)
    else:
        dropdown_update = gr.update(visible=True)

    return default_conversation.copy(), dropdown_update
111
+
112
+
113
def load_demo_refresh_model_list(request: gr.Request):
    """Re-query the model list on page load and reset the conversation state.

    Returns:
        ``(state, dropdown_update)``; the dropdown gets the refreshed choices
        and selects the first one, or ``""`` when the list is empty.
    """
    logger.info(f"load_demo. ip: {request.client.host}")
    available = get_model_list()
    fresh_state = default_conversation.copy()
    preselected = available[0] if available else ""
    return fresh_state, gr.update(choices=available, value=preselected)
122
+
123
+
124
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one JSON line describing a vote to today's conversation log.

    The record holds the timestamp, vote type, selected model, the serialised
    conversation state and the client's IP address.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=request.client.host,
        )
        log_file.write(json.dumps(record) + "\n")
134
+
135
+
136
def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote for the last response and lock the vote buttons.

    Returns:
        ``(textbox_value, upvote_update, downvote_update)`` -- exactly the three
        outputs this handler is wired to in ``build_demo``. It used to return a
        fourth button update that has no matching output component.
    """
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("",) + (disable_btn,) * 2
140
+
141
+
142
def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote for the last response and lock the vote buttons.

    Returns:
        ``(textbox_value, upvote_update, downvote_update)`` -- exactly the three
        outputs this handler is wired to in ``build_demo``. It used to return a
        fourth button update that has no matching output component.
    """
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("",) + (disable_btn,) * 2
146
+
147
+
148
def flag_last_response(state, model_selector, request: gr.Request):
    """Log a "flag" vote for the last response.

    Returns an empty textbox value followed by three ``disable_btn`` updates.
    """
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    locked_buttons = (disable_btn, disable_btn, disable_btn)
    return ("",) + locked_buttons
152
+
153
+
154
def regenerate(state, image_process_mode, request: gr.Request):
    """Clear the last assistant reply so it can be generated again.

    When the preceding user message is a tuple/list payload, its third element
    is replaced by the current ``image_process_mode``.

    Returns:
        ``(state, chatbot, "", None)`` followed by five ``disable_btn`` updates.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None

    user_turn = state.messages[-2]
    payload = user_turn[1]
    if type(payload) in (tuple, list):
        user_turn[1] = (*payload[:2], image_process_mode)

    state.skip_next = False
    locked_buttons = (disable_btn,) * 5
    return (state, state.to_gradio_chatbot(), "", None) + locked_buttons
162
+
163
+
164
def clear_history(request: gr.Request):
    """Reset the conversation to a fresh copy of ``default_conversation``.

    Returns:
        ``(state, chatbot, "", None)`` followed by five ``disable_btn`` updates.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    locked_buttons = (disable_btn,) * 5
    return (fresh_state, fresh_state.to_gradio_chatbot(), "", None) + locked_buttons
168
+
169
+
170
def save_conversation(conversation):
    """Export the conversation to ``./conversation.docx`` and return that path.

    String messages become a paragraph; tuple messages ``(text, image, _)``
    become an optional paragraph plus the image embedded as base64 PNG. Any
    other message value is skipped. The HTML is converted with pypandoc.
    """
    print("save_conversation_wrapper is called")
    fragments = ["<html><body>"]

    for role, message in conversation.messages:
        if isinstance(message, str):  # text only
            fragments.append(f"<p><b>{role}</b>: {message}</p>")
            continue
        if not isinstance(message, tuple):
            continue

        # text + image
        text, image_obj, _ = message
        if text:
            fragments.append(f"<p><b>{role}</b>: {text}</p>")

        png_buffer = BytesIO()
        image_obj.save(png_buffer, format="PNG")
        encoded_image = base64.b64encode(png_buffer.getvalue()).decode()
        fragments.append(f'<img src="data:image/png;base64,{encoded_image}" /><br>')

    fragments.append("</body></html>")
    html_content = "".join(fragments)

    doc_path = "./conversation.docx"
    pypandoc.convert_text(html_content, 'docx', format='html', outputfile=doc_path,
                          extra_args=["-M2GB", "+RTS", "-K64m", "-RTS"])
    return doc_path
196
+
197
+
198
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Validate the user's input and append it to the conversation.

    Empty input, or text rejected by moderation when ``args.moderate`` is set,
    marks the state with ``skip_next = True`` and leaves the buttons unchanged.
    Otherwise the text is truncated (1536 chars, 1200 with an image), an
    ``<image>`` tag is appended if missing, and a user message plus an empty
    assistant slot are added. A new image restarts the conversation when the
    current one already contains images.

    Returns:
        ``(state, chatbot, textbox_value, None)`` followed by five button updates.
    """
    def _outputs(conv, textbox_value, button_update):
        # Common shape of every return value of this handler.
        return (conv, conv.to_gradio_chatbot(), textbox_value, None) + (button_update,) * 5

    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    if len(text) <= 0 and image is None:
        state.skip_next = True
        return _outputs(state, "", no_change_btn)

    if args.moderate and violates_moderation(text):
        state.skip_next = True
        return _outputs(state, moderation_msg, no_change_btn)

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if '<image>' not in text:
            text = text + '\n<image>'
        text = (text, image, image_process_mode)
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()

    logger.info(f"Input Text: {text}")
    user_role, assistant_role = state.roles[0], state.roles[1]
    state.append_message(user_role, text)
    state.append_message(assistant_role, None)
    state.skip_next = False
    return _outputs(state, "", disable_btn)
224
+
225
+
226
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the model's answer for the pending assistant message.

    Resolves a worker through the controller, posts the prompt and images to
    the worker's ``/worker_generate_stream`` endpoint and yields UI updates as
    chunks arrive. On success the finished exchange is appended to today's
    conversation log.

    Yields:
        ``(state, chatbot)`` followed by five button updates, matching the
        seven outputs (``[state, chatbot] + btn_list``) this handler is wired
        to. The error paths used to yield only three button updates.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    # Button state after a failure, in btn_list order
    # (upvote, downvote, regenerate, clear, save): voting stays disabled.
    error_btns = (disable_btn, disable_btn, enable_btn, enable_btn, enable_btn)

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of the conversation: restart from the "bunny" template.
        template_name = "bunny"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    logger.info(f"Processed Input Text: {state.messages[-2][1]}")
    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    # Construct prompt
    prompt = state.get_prompt()

    # Save each image once per day under LOGDIR/serve_images, named by MD5.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, image_hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": '<|im_end|>',  # state.sep if state.sep_style in [SeparatorStyle.PLAIN, ] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    # Logged with an image summary only; the real image payload is attached afterwards.
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()
    print('=========> get_images')
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
    print('=========> state', state.messages[-1][-1])

    # Initialised so the final log line cannot hit an unbound name when the
    # stream ends without delivering a chunk.
    output = ""
    try:
        # Stream output
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=1000)
        print("====> response ok")
        print("====> response dir", dir(response))
        print("====> response", response)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    # The worker echoes the prompt; keep only the new text.
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + error_btns
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    # Drop the trailing cursor character.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
330
+
331
+
332
# Page header; rendered by build_demo() unless embed mode is on.
title_markdown = ("""
# 🐰 Bunny: A family of lightweight multimodal models

[📖[Technical report](https://arxiv.org/abs/2402.11530)] | [🏠[Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗[Model](https://huggingface.co/BAAI/Bunny-v1_0-3B)]

""")

# Terms-of-use footer; rendered by build_demo() unless embed mode is on.
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
""")

# License footer; rendered after the terms of use.
learn_more_markdown = ("""
### License
This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses. The content of this project itself is licensed under the Apache license 2.0.
""")

# Stylesheet passed to gr.Blocks(css=...); targets the "buttons" row and the
# "file-downloader" component by element id.
block_css = """
.centered {
text-align: center;
}
#buttons button {
min-width: min(120px,100%);
}
#file-downloader {
min-width: min(120px,100%);
height: 50px;
}
"""
364
+
365
+
366
def trigger_download(doc_path):
    """Return ``doc_path`` unchanged."""
    download_target = doc_path
    return download_target
368
+
369
+
370
def build_demo(embed_mode):
    """Assemble the Gradio Blocks UI and wire up its event handlers.

    Reads the module-level ``models`` list (dropdown choices) and ``args``
    (``model_list_mode``), so both must be set before this is called.

    Args:
        embed_mode: when truthy, the title, terms-of-use and license markdown
            are not rendered.

    Returns:
        The constructed ``gr.Blocks`` instance.

    Raises:
        ValueError: if ``args.model_list_mode`` is neither "once" nor "reload".
    """
    # Created outside the Blocks context so it can be referenced by gr.Examples
    # before it is placed in the layout via textbox.render().
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="Bunny", theme=gr.themes.Default(primary_hue="blue", secondary_hue="lime"),
                   css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            # Left column: model choice, image input, examples, sampling parameters.
            with gr.Column(scale=4):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False,
                        allow_custom_value=True
                    )

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/example_1.png", "What is the astronaut holding in his hand?"],
                    [f"{cur_dir}/examples/example_2.png", "Why is the image funny?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
                                            label="Temperature", )
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P", )
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
                                                  label="Max output tokens", )

                file_output = gr.components.File(label="Download Document", visible=True, elem_id="file-downloader")
            # Right column: chat window, text entry and action buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="chatbot", label="Bunny Chatbot",
                                     avatar_images=[f"{cur_dir}/examples/user.png", f"{cur_dir}/examples/icon.jpg"],
                                     height=550)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")

                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔁 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🚮 Clear", interactive=False)
                    save_conversation_btn = gr.Button(value="🗃️ Save", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        # Handlers that end in "+ btn_list" must return one update per button, in this order.
        btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn, save_conversation_btn]

        # NOTE(review): the vote handlers return four values but only three
        # outputs are listed here -- confirm which side is intended.
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn]
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn]
        )

        # Regenerate: reset the last reply, then stream a new one.
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        save_conversation_btn.click(
            save_conversation,
            inputs=[state],
            outputs=file_output
        )

        # Pressing ENTER and clicking "Send" run the same two-step chain:
        # add_text records the input, then http_bot streams the answer.
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        # "once": use the model list fetched at startup and read the URL params.
        # "reload": re-query the controller on every page load.
        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params,
                queue=False
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
513
+
514
+
515
if __name__ == "__main__":
    # Entry point: parse CLI options, start the controller and the model worker
    # as child processes, then build and launch the Gradio UI. Both children
    # are killed when the UI stops or fails.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int)
    # Default now matches the port start_controller() binds to (10000); it was
    # 21001, where this script starts no controller.
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    model_path = 'qnguyen3/nanoLLaVA'

    controller_proc = start_controller()
    worker_proc = start_worker(model_path)
    # Give the child processes time to come up before querying the controller.
    time.sleep(10)
    exit_status = 0
    try:
        # Must run after the controller is up: it was previously called before
        # start_controller(), when nothing was listening yet.
        models = get_model_list()
        demo = build_demo(args.embed)
        demo.launch(
            server_name=args.host,
            server_port=args.port,
            share=args.share,
            debug=True,
            max_threads=10
        )
    except Exception as e:
        print(e)
        exit_status = 1
    finally:
        worker_proc.kill()
        controller_proc.kill()
    sys.exit(exit_status)
serve/builder.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ import torch
4
+
5
+ from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig, logging, AutoModelForCausalLM
6
+
7
+ logging.set_verbosity_error()
8
+
9
def load_pretrained_model(model_path, model_base, model_name, model_type, load_8bit=False, load_4bit=False,
                          device_map="auto", device="cuda", **kwargs):
    """Load tokenizer, model and image processor for a nanoLLaVA checkpoint.

    Three paths, chosen from ``model_name`` and ``model_base``:
    * "lora" in the name and a base given: load the base model, apply the
      non-LoRA trainables, then load and merge the LoRA weights;
    * a base given without "lora": load the base model and ``mm_projector.bin``;
    * no base: load everything from ``model_path``.

    Args:
        model_path: local directory or Hub repo id of the checkpoint.
        model_base: base model to load first, or ``None``.
        model_name: only inspected for the substring "lora" (case-insensitive).
        model_type: must be 'qwen1.5-1.8b' or 'qwen1.5-0.5b'.
        load_8bit / load_4bit: select quantised loading; otherwise float16.
        device_map: forwarded to ``from_pretrained``; replaced by
            ``{"": device}`` when ``device`` is not "cuda".
        device: device the vision tower is moved to.
        **kwargs: extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        ``(tokenizer, model, image_processor, context_len)``.

    Raises:
        ValueError: for an unknown ``model_type``.
    """
    if model_type not in {'qwen1.5-1.8b', 'qwen1.5-0.5b'}:
        raise ValueError(f"Unknown Model Type {model_type}")

    kwargs = {"device_map": device_map, **kwargs}

    if device != "cuda":
        kwargs['device_map'] = {"": device}

    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        # NOTE(review): both `load_in_4bit` and a `quantization_config` are
        # passed; confirm the installed transformers version accepts both.
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16

    if 'lora' in model_name.lower() and model_base is None:
        warnings.warn(
            'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.')
    if 'lora' in model_name.lower() and model_base is not None:
        lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)

        print('Loading nanoLLaVA from base model...')
        if model_type == 'qwen1.5-1.8b' or model_type == 'qwen1.5-0.5b':
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained,
                                                         **kwargs)

        # Re-allocate lm_head / embeddings if their row count disagrees with
        # lm_head.out_features.
        # NOTE(review): token_num is read from the same lm_head, so this
        # condition looks like it can rarely be true -- confirm.
        token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
        if model.lm_head.weight.shape[0] != token_num:
            model.lm_head.weight = torch.nn.Parameter(
                torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
            model.model.embed_tokens.weight = torch.nn.Parameter(
                torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))

        print('Loading additional nanoLLaVA weights...')
        if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
            non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
        else:
            # this is probably from HF Hub
            from huggingface_hub import hf_hub_download
            def load_from_hf(repo_id, filename, subfolder=None):
                cache_file = hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    subfolder=subfolder)
                return torch.load(cache_file, map_location='cpu')

            non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')

        # Strip the 'base_model.' prefix (11 chars), then a leading 'model.'
        # (6 chars) when keys are nested as 'model.model.*'.
        non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
                               non_lora_trainables.items()}
        if any(k.startswith('model.model.') for k in non_lora_trainables):
            non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in
                                   non_lora_trainables.items()}
        model.load_state_dict(non_lora_trainables, strict=False)

        from peft import PeftModel
        print('Loading LoRA weights...')
        model = PeftModel.from_pretrained(model, model_path)
        print('Merging LoRA weights...')
        model = model.merge_and_unload()
        print('Model is loaded...')
    elif model_base is not None:
        # this may be mm projector only
        print('Loading nanoLLaVA from base model...')

        cfg_pretrained = AutoConfig.from_pretrained(model_path)
        if model_type == 'qwen1.5-1.8b' or model_type == 'qwen1.5-0.5b':
            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
            model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained,
                                                         **kwargs)

        mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
        mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
        model.load_state_dict(mm_projector_weights, strict=False)
    else:
        if model_type == 'qwen1.5-1.8b' or model_type == 'qwen1.5-0.5b':
            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
            model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)

    model.resize_token_embeddings(len(tokenizer))

    # Make sure the vision tower is loaded, then move it to `device` in float16.
    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    vision_tower.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    # Context length falls back to 2048 when the config does not define it.
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    # Use the EOS token for padding when no pad token is configured.
    if model.generation_config.pad_token_id is None:
        model.generation_config.pad_token_id = model.generation_config.eos_token_id

    return tokenizer, model, image_processor, context_len
serve/constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
# Model Constants
# presumably the label value ignored by the training loss -- TODO confirm
IGNORE_INDEX = -100
# presumably the placeholder token id marking where an image goes -- TODO confirm
IMAGE_TOKEN_INDEX = -200
# Textual image placeholder used in prompts.
DEFAULT_IMAGE_TOKEN = "<image>"
# Seconds the controller's heartbeat thread sleeps between expiration sweeps.
CONTROLLER_HEART_BEAT_EXPIRATION = 30
# Directory for conversation logs and saved images of the Gradio server.
LOGDIR = "gradio-logs"
# presumably the number of seconds between worker heartbeats -- TODO confirm
WORKER_HEART_BEAT_INTERVAL = 15
serve/controller.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import dataclasses
7
+ import threading
8
+ import json
9
+ import time
10
+ import numpy as np
11
+ import requests
12
+ import uvicorn
13
+
14
+ from typing import List
15
+ from enum import Enum, auto
16
+ from fastapi import FastAPI, Request
17
+ from fastapi.responses import StreamingResponse
18
+
19
+ from .constants import CONTROLLER_HEART_BEAT_EXPIRATION
20
+ from .utils import build_logger, server_error_msg
21
+
22
+ logger = build_logger("controller", "controller.log")
23
+
24
+
25
class DispatchMethod(Enum):
    """Strategies the controller can use to choose a worker for a request."""

    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        """Return the member named by ``name``.

        Args:
            name: ``"lottery"`` or ``"shortest_queue"``.

        Raises:
            ValueError: if ``name`` is not one of the accepted strings.
        """
        if name == "lottery":
            return cls.LOTTERY
        if name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        # Report the offending value so a misconfiguration can be diagnosed.
        raise ValueError(f"Invalid dispatch method: {name}")
37
+
38
+
39
@dataclasses.dataclass
class WorkerInfo:
    """Controller-side record describing one registered worker."""

    # Names of the models this worker reports serving.
    model_names: List[str]
    # Relative weight used when dispatching requests.
    speed: int
    # Last known number of queued requests on the worker.
    queue_length: int
    # Whether the worker is subject to heart-beat expiry.
    check_heart_beat: bool
    # ``time.time()`` value recorded at registration / last heart beat.
    last_heart_beat: float
46
+
47
+
48
def heart_beat_controller(controller):
    """Loop forever, periodically asking ``controller`` to drop expired workers.

    Intended to be run as a thread target; it never returns.
    """
    interval = CONTROLLER_HEART_BEAT_EXPIRATION
    while True:
        time.sleep(interval)
        controller.remove_stable_workers_by_expiration()
52
+
53
+
54
class Controller:
    """Registry of model workers that also routes requests to them."""

    def __init__(self, dispatch_method: str):
        """Create an empty registry and start the heart-beat expiry thread.

        Args:
            dispatch_method: a name accepted by ``DispatchMethod.from_str``.
        """
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,))
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        """Add a worker, or overwrite its entry if it is already known.

        When ``worker_status`` is falsy the status is fetched from the worker
        itself.

        Returns:
            True once the worker is stored, False if no status was obtained.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """POST to the worker's ``/worker_get_status`` endpoint.

        ``worker_name`` is used as the worker's base URL.

        Returns:
            The decoded JSON body, or None on a request error or non-200 reply.
        """
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        """Delete a worker entry; raises KeyError if it is not registered."""
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        """Re-register every known worker, dropping those that do not answer."""
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        """Return the distinct model names served by the registered workers."""
        model_names = set()

        for w_info in self.worker_info.values():
            model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        """Pick a worker serving ``model_name`` using the dispatch method.

        Returns:
            The chosen worker name, or ``""`` when no worker is eligible.

        Raises:
            ValueError: if the configured dispatch method is not handled.
        """
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            # Also covers the "no matching worker" case (empty sum is 0).
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm

            # Random draw weighted by each worker's share of the total speed.
            pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
            worker_name = worker_names[pt]
            return worker_name

        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            # Count the request now so the next dispatch sees it, ahead of
            # the worker's next heart beat.
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heart beat from a worker.

        Returns:
            True if the worker is registered, False otherwise.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        """Drop heart-beat-checked workers whose last beat is too old."""
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        # Collect first: the dict must not change size while being iterated.
        to_delete = []
        for worker_name, w_info in self.worker_info.items():
            if w_info.check_heart_beat and w_info.last_heart_beat < expire:
                to_delete.append(worker_name)

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        """Proxy a streaming generation request to a worker.

        Yields NUL-terminated byte chunks. When no worker is available or the
        request to the worker fails, a single JSON error chunk is yielded.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"
            # Stop here: with an empty address the request below can only
            # fail and emit a second, misleading error chunk.
            return

        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     json=params, stream=True, timeout=5)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"

    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        """Return the aggregated status of all workers that answer a query."""
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in self.worker_info:
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }
215
+
216
+
217
+ app = FastAPI()
218
+
219
+
220
@app.post("/register_worker")
async def register_worker(request: Request):
    """Register the worker described by the JSON request body."""
    payload = await request.json()
    worker_name = payload["worker_name"]
    check_heart_beat = payload["check_heart_beat"]
    worker_status = payload.get("worker_status", None)
    controller.register_worker(worker_name, check_heart_beat, worker_status)
226
+
227
+
228
@app.post("/refresh_all_workers")
async def refresh_all_workers():
    """Ask the controller to re-register every known worker."""
    # Controller.refresh_all_workers has no return value, so nothing is bound.
    controller.refresh_all_workers()
231
+
232
+
233
@app.post("/list_models")
async def list_models():
    """Return the model names known to the controller under ``"models"``."""
    return {"models": controller.list_models()}
237
+
238
+
239
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    """Return the address of a worker for the model named in the body."""
    payload = await request.json()
    model_name = payload["model"]
    return {"address": controller.get_worker_address(model_name)}
244
+
245
+
246
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    """Record a worker heart beat and report whether the worker is known."""
    payload = await request.json()
    worker_name = payload["worker_name"]
    queue_length = payload["queue_length"]
    return {"exist": controller.receive_heart_beat(worker_name, queue_length)}
252
+
253
+
254
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    """Stream a generation produced by the controller's proxy generator."""
    body = await request.json()
    return StreamingResponse(controller.worker_api_generate_stream(body))
259
+
260
+
261
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    """Return the controller's aggregated worker status."""
    status = controller.worker_api_get_status()
    return status
264
+
265
+
266
if __name__ == "__main__":
    # Command-line options for the controller server.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21001)
    parser.add_argument(
        "--dispatch-method",
        type=str,
        choices=["lottery", "shortest_queue"],
        default="shortest_queue",
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # The route handlers above read this module-level name at request time.
    controller = Controller(args.dispatch_method)

    # Point uvicorn's default log handler at stdout before starting it.
    log_config = uvicorn.config.LOGGING_CONFIG
    log_config['handlers']['default']['stream'] = 'ext://sys.stdout'
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
serve/conversation.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List
4
+
5
+
6
class SeparatorStyle(Enum):
    """Ways of joining conversation turns into one prompt string."""

    # Explicit values equal to what ``auto()`` assigns in this order.
    TWO = 1
    PLAIN = 2
    MPT = 3
11
+
12
+
13
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Each entry of ``messages`` is a ``[role, message]`` pair. ``message`` is
    a string, ``None`` for a reply that has not been produced yet, or a
    ``(text, image, image_process_mode)`` tuple for a turn carrying an image.
    """
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle
    sep: str = "###"
    # Second separator; only the TWO and PLAIN styles read it.
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        """Render the system text and all turns into one prompt string.

        Raises:
            ValueError: if ``sep_style`` is not a handled style.
        """
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # Work on a shallow copy so the stored history keeps its image.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                # Move the image placeholder to the front of the first turn.
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    # Empty turn: leave the role open for the model to fill.
                    ret += role + ":"

        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role

        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append one ``[role, message]`` turn to the history."""
        self.messages.append([role, message])

    @staticmethod
    def _bounded_size(image):
        """Return the ``(width, height)`` an image is scaled to.

        The shorter edge is capped at 400, the longer at 800, and the image
        is never enlarged; the aspect ratio is kept up to integer rounding.
        """
        width, height = image.size
        max_hw, min_hw = max(width, height), min(width, height)
        aspect_ratio = max_hw / min_hw
        max_len, min_len = 800, 400
        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
        longest_edge = int(shortest_edge * aspect_ratio)
        if height > width:
            return shortest_edge, longest_edge
        return longest_edge, shortest_edge

    @staticmethod
    def _expand2square(pil_img, background_color=(122, 116, 104)):
        """Pad ``pil_img`` to a square, centred on a solid background."""
        from PIL import Image

        width, height = pil_img.size
        if width == height:
            return pil_img
        if width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

    def _has_images(self):
        """Tell whether any turn that ``get_images`` would visit has an image."""
        return any(
            i % 2 == 0 and type(msg) is tuple
            for i, (_, msg) in enumerate(self.messages[self.offset:])
        )

    def get_images(self, return_pil=False):
        """Collect the images attached to even-indexed turns after ``offset``.

        Args:
            return_pil: when true return the image objects themselves,
                otherwise base64-encoded PNG strings.

        Raises:
            ValueError: if a turn carries an unknown ``image_process_mode``.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 != 0 or type(msg) is not tuple:
                continue
            import base64
            from io import BytesIO

            msg, image, image_process_mode = msg
            if image_process_mode == "Pad":
                image = self._expand2square(image)
            elif image_process_mode in ["Default", "Crop"]:
                pass
            elif image_process_mode == "Resize":
                image = image.resize((336, 336))
            else:
                raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

            new_size = self._bounded_size(image)
            # Skip the resize when the longest edge already has its target size.
            if max(new_size) != max(image.size):
                image = image.resize(new_size)

            if return_pil:
                images.append(image)
            else:
                buffered = BytesIO()
                image.save(buffered, format="PNG")
                img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Return the history as ``[user_message, reply]`` pairs.

        Image turns are rendered as an inline ``<img>`` tag followed by the
        text with the ``<image>`` placeholder removed.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    image = image.resize(self._bounded_size(image))
                    # JPEG cannot store e.g. an alpha channel; convert first.
                    if image.mode not in ("RGB", "L"):
                        image = image.convert("RGB")
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # The MIME type must match the JPEG encoding used above.
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a new conversation with a copied message list.

        ``skip_next`` is not carried over; the copy gets the default value.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Return a plain-dict view of the conversation.

        When images are present each image turn is reduced to its text.
        """
        # Only the presence of images matters here, so avoid encoding them.
        if self._has_images():
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
190
+
191
+
192
# Default chat template: "USER"/"ASSISTANT" roles with two separators.
conv_bunny = Conversation(
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
    roles=("USER", "ASSISTANT"),
    version="bunny",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)

# Bare template: no system text and no role names.
conv_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

# ChatML-style template using <|im_start|> / <|im_end|> markers.
conv_chatml_direct = Conversation(
    system="<|im_start|>system\nAnswer the questions.",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

default_conversation = conv_bunny

# Lookup table from template name to template instance.
conv_templates = {
    "default": conv_bunny,
    "bunny": conv_bunny,
    "plain": conv_plain,
    "chatml_direct": conv_chatml_direct,
}
232
+
233
+ if __name__ == "__main__":
234
+ print(default_conversation.get_prompt())
serve/examples/example_1.png ADDED
serve/examples/example_2.png ADDED
serve/examples/icon.jpg ADDED
serve/examples/user.png ADDED
serve/gradio_web_server.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+ import gradio as gr
7
+ import requests
8
+ import hashlib
9
+ import pypandoc
10
+ import base64
11
+
12
+ from io import BytesIO
13
+
14
+ from .conversation import (default_conversation, conv_templates, SeparatorStyle)
15
+ from .constants import LOGDIR
16
+ from .utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
17
+
18
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
19
+
20
+ headers = {"User-Agent": "Bunny Client"}
21
+
22
+ no_change_btn = gr.update()
23
+ enable_btn = gr.update(interactive=True)
24
+ disable_btn = gr.update(interactive=False)
25
+
26
+ priority = {
27
+ "Bunny": "aaaaaaa",
28
+ }
29
+
30
+
31
def get_conv_log_filename():
    """Return the path of today's conversation log file under ``LOGDIR``."""
    now = datetime.datetime.now()
    date_tag = f"{now.year}-{now.month:02d}-{now.day:02d}"
    return os.path.join(LOGDIR, f"{date_tag}-conv.json")
35
+
36
+
37
def get_model_list():
    """Refresh the controller's workers and return the sorted model names.

    Models listed in ``priority`` sort by their priority key, all others by
    their own name.
    """
    base_url = args.controller_url
    refresh_resp = requests.post(base_url + "/refresh_all_workers")
    assert refresh_resp.status_code == 200
    list_resp = requests.post(base_url + "/list_models")
    models = sorted(list_resp.json()["models"], key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models
45
+
46
+
47
+ get_window_url_params = """
48
+ function() {
49
+ const params = new URLSearchParams(window.location.search);
50
+ url_params = Object.fromEntries(params);
51
+ console.log(url_params);
52
+ return url_params;
53
+ }
54
+ """
55
+
56
+
57
def load_demo(url_params, request: gr.Request):
    """Build the initial state and dropdown update for a page load.

    A ``model`` URL parameter that names a known model preselects it.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    if "model" in url_params and url_params["model"] in models:
        dropdown_update = gr.update(value=url_params["model"], visible=True)
    else:
        dropdown_update = gr.update(visible=True)

    return default_conversation.copy(), dropdown_update
69
+
70
+
71
def load_demo_refresh_model_list(request: gr.Request):
    """Build the initial state after re-querying the available models."""
    logger.info(f"load_demo. ip: {request.client.host}")
    models = get_model_list()
    state = default_conversation.copy()
    first_choice = models[0] if models else ""
    return state, gr.update(choices=models, value=first_choice)
80
+
81
+
82
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one vote record, as a JSON line, to today's conversation log."""
    with open(get_conv_log_filename(), "a") as log_file:
        record = {
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "model": model_selector,
            "state": state.dict(),
            "ip": request.client.host,
        }
        log_file.write(json.dumps(record) + "\n")
92
+
93
+
94
def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote, then clear the textbox and disable both vote buttons.

    Returns:
        One value per output bound by the click listener:
        ``(textbox, upvote_btn, downvote_btn)``.
    """
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    # The listener binds three outputs, so return exactly three values.
    return ("",) + (disable_btn,) * 2
98
+
99
+
100
def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote, then clear the textbox and disable both vote buttons.

    Returns:
        One value per output bound by the click listener:
        ``(textbox, upvote_btn, downvote_btn)``.
    """
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    # The listener binds three outputs, so return exactly three values.
    return ("",) + (disable_btn,) * 2
104
+
105
+
106
def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag for the last response.

    Returns an empty string followed by three ``disable_btn`` updates.
    """
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    disabled = (disable_btn, disable_btn, disable_btn)
    return ("",) + disabled
110
+
111
+
112
def regenerate(state, image_process_mode, request: gr.Request):
    """Clear the last reply so that it can be generated again.

    If the preceding user turn carries an image, its processing mode is
    replaced by the currently selected ``image_process_mode``.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    last_user_turn = state.messages[-2]
    content = last_user_turn[1]
    if type(content) in (tuple, list):
        last_user_turn[1] = tuple(content[:2]) + (image_process_mode,)
    state.skip_next = False
    buttons = (disable_btn,) * 5
    return (state, state.to_gradio_chatbot(), "", None) + buttons
120
+
121
+
122
def clear_history(request: gr.Request):
    """Start over with a fresh copy of the default conversation."""
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh_state = default_conversation.copy()
    buttons = (disable_btn,) * 5
    return (fresh_state, fresh_state.to_gradio_chatbot(), "", None) + buttons
126
+
127
+
128
def save_conversation(conversation):
    """Export the conversation to ``./conversation.docx`` and return the path.

    The history is rendered to HTML (images embedded as base64 PNG) and
    converted to docx with pypandoc. Messages that are neither ``str`` nor
    ``tuple`` (for example a ``None`` reply) are skipped.
    """
    import html  # local import keeps this block self-contained

    print("save_conversation_wrapper is called")
    parts = ["<html><body>"]

    for role, message in conversation.messages:
        if isinstance(message, str):  # only text
            # Escape so characters such as "<" or "&" in chat text are shown
            # literally instead of being parsed as markup.
            parts.append(f"<p><b>{html.escape(role)}</b>: {html.escape(message)}</p>")
        elif isinstance(message, tuple):  # text+image
            text, image_obj, _ = message

            # add text
            if text:
                # Drop the image placeholder before escaping; otherwise it
                # would show up as literal text in the document.
                shown = html.escape(text.replace("<image>", "").strip())
                parts.append(f"<p><b>{html.escape(role)}</b>: {shown}</p>")

            # add image
            buffered = BytesIO()
            image_obj.save(buffered, format="PNG")
            encoded_image = base64.b64encode(buffered.getvalue()).decode()
            parts.append(f'<img src="data:image/png;base64,{encoded_image}" /><br>')

    parts.append("</body></html>")
    html_content = "".join(parts)

    doc_path = "./conversation.docx"
    pypandoc.convert_text(html_content, 'docx', format='html', outputfile=doc_path,
                          extra_args=["-M2GB", "+RTS", "-K64m", "-RTS"])
    return doc_path
154
+
155
+
156
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append the user's text (and optional image) as a new pending turn.

    Empty input, or text rejected by moderation when ``args.moderate`` is
    set, marks the state with ``skip_next`` and leaves the history unchanged.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")

    if not text and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5

    if args.moderate and violates_moderation(text):
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
            no_change_btn,) * 5

    message = text[:1536]  # Hard cut-off
    if image is not None:
        message = message[:1200]  # Hard cut-off for images
        if '<image>' not in message:
            message = message + '\n<image>'
        message = (message, image, image_process_mode)
        # Start a fresh conversation when the current one already has an image.
        if len(state.get_images(return_pil=True)) > 0:
            state = default_conversation.copy()

    logger.info(f"Input Text: {message}")
    user_role, assistant_role = state.roles[0], state.roles[1]
    state.append_message(user_role, message)
    state.append_message(assistant_role, None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
182
+
183
+
184
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the model's reply for the pending turn into the chat state.

    A generator for a Gradio listener: every yield is
    ``(state, chatbot)`` followed by one update for each of the five buttons
    ``[upvote, downvote, regenerate, clear, save]``. On success the finished
    turn is appended to the conversation log file.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    # Button updates used after a failure: voting stays off, the other three
    # buttons are re-enabled. The listener binds five buttons, so every
    # branch has to yield five button updates.
    error_btns = (disable_btn, disable_btn, enable_btn, enable_btn, enable_btn)

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of the conversation: rebuild it on the "bunny" template.
        template_name = "bunny"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    logger.info(f"Processed Input Text: {state.messages[-2][1]}")
    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    # Construct prompt
    prompt = state.get_prompt()

    # Keep a copy of every input image on disk, named by its content hash.
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, image_hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.PLAIN, ] else state.sep2,
        # Log a short summary instead of the base64 payload.
        "images": f'List of {len(all_images)} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")

    pload['images'] = state.get_images()
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    # Defined up front so the final log line works even if nothing streamed.
    output = ""
    try:
        # Stream output
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=1000)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    # Drop the leading len(prompt) characters of the text.
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + error_btns
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + error_btns
        return

    # Remove the trailing cursor character.
    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
288
+
289
+
290
+ title_markdown = ("""
291
+ # 🐰 Bunny: A family of lightweight multimodal models
292
+
293
+ [📖[Technical report](https://arxiv.org/abs/2402.11530)] | [🏠[Code](https://github.com/BAAI-DCAI/Bunny)] | [🤗[Model](https://huggingface.co/BAAI/Bunny-v1_0-3B)]
294
+
295
+ """)
296
+
297
+ tos_markdown = ("""
298
+ ### Terms of use
299
+ By using this service, users are required to agree to the following terms:
300
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
301
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
302
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
303
+ """)
304
+
305
+ learn_more_markdown = ("""
306
+ ### License
307
+ This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses. The content of this project itself is licensed under the Apache license 2.0.
308
+ """)
309
+
310
+ block_css = """
311
+ .centered {
312
+ text-align: center;
313
+ }
314
+ #buttons button {
315
+ min-width: min(120px,100%);
316
+ }
317
+ #file-downloader {
318
+ min-width: min(120px,100%);
319
+ height: 50px;
320
+ }
321
+ """
322
+
323
+
324
def trigger_download(doc_path):
    """Return ``doc_path`` unchanged."""
    path_for_download = doc_path
    return path_for_download
326
+
327
+
328
def build_demo(embed_mode):
    """Assemble the Gradio Blocks UI for the Bunny chat demo and wire its events.

    Args:
        embed_mode: when truthy, the title, terms-of-service and "learn more"
            markdown sections are not rendered.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).

    Raises:
        ValueError: if ``args.model_list_mode`` is neither "once" nor "reload".

    Reads the module-level names ``models`` and ``args`` plus the callback
    functions referenced below, so those must exist before this is called.

    NOTE(review): the nesting of the layout containers was reconstructed from
    syntax because the source lost its indentation; confirm against the
    repository copy.
    """
    # Created outside the Blocks context and placed later via textbox.render().
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="Bunny", theme=gr.themes.Default(primary_hue="blue", secondary_hue="lime"),
                   css=block_css) as demo:
        # State component passed in and out of every callback below.
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=4):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False,
                        allow_custom_value=True
                    )

                imagebox = gr.Image(type="pil")
                # Created hidden (visible=False); its value is still fed to add_text/regenerate.
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                # Example assets are resolved relative to this file's directory.
                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/example_1.png", "What is the astronaut holding in his hand?"],
                    [f"{cur_dir}/examples/example_2.png", "Why is the image funny?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
                                            label="Temperature", )
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P", )
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True,
                                                  label="Max output tokens", )

                # Receives the output of save_conversation (see the Save button below).
                file_output = gr.components.File(label="Download Document", visible=True, elem_id="file-downloader")
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="chatbot", label="Bunny Chatbot",
                                     avatar_images=[f"{cur_dir}/examples/user.png", f"{cur_dir}/examples/icon.jpg"],
                                     height=550)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")

                # All action buttons are created with interactive=False; the
                # callbacks include them in their outputs (btn_list).
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔁 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🚮 Clear", interactive=False)
                    save_conversation_btn = gr.Button(value="🗃️ Save", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        # Hidden JSON component used as the input of load_demo in "once" mode.
        url_params = gr.JSON(visible=False)

        # Register listeners
        # Order matters: callbacks return one update per button in this order.
        btn_list = [upvote_btn, downvote_btn, regenerate_btn, clear_btn, save_conversation_btn]

        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn]
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn]
        )

        # Regenerate: first callback updates state/UI, then http_bot is chained
        # with the current sampling parameters.
        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        save_conversation_btn.click(
            save_conversation,
            inputs=[state],
            outputs=file_output
        )

        # Pressing ENTER in the textbox and clicking "Send" are wired identically:
        # add_text first, then http_bot.
        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list
        )

        # Page-load hook: "once" runs load_demo with url_params and a JS snippet,
        # "reload" runs load_demo_refresh_model_list with no inputs.
        # NOTE(review): `_js` appears to be the keyword of older Gradio releases;
        # confirm it matches the installed Gradio version.
        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params,
                queue=False
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo
471
+
472
+
473
# Script entry point: parse CLI options, fetch the model list, build the UI
# and launch the Gradio server.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    # `args` is a module-level name; build_demo reads args.model_list_mode.
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # `models` is a module-level name read by build_demo for the dropdown.
    models = get_model_list()
    logger.info(args)
    demo = build_demo(args.embed)
    # NOTE(review): max_threads is hard-coded to 10 and --concurrency-count is
    # not referenced in this block; confirm whether that is intended.
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        debug=True,
        max_threads=10
    )
serve/mm_utils.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import torch
3
+
4
+ from PIL import Image
5
+ from io import BytesIO
6
+ from transformers import StoppingCriteria
7
+
8
+ from .constants import IMAGE_TOKEN_INDEX
9
+
10
+
11
def load_image_from_base64(image):
    """Decode a base64-encoded image payload and open it with PIL.

    Args:
        image: base64 text (or bytes) holding an encoded image file.

    Returns:
        The object returned by ``Image.open`` on the decoded bytes.
    """
    raw_bytes = base64.b64decode(image)
    buffer = BytesIO(raw_bytes)
    return Image.open(buffer)
13
+
14
+
15
def expand2square(pil_img, background_color):
    """Pad an image to a square canvas, centring it along the shorter axis.

    Args:
        pil_img: image exposing ``size`` and ``mode`` (a PIL image).
        background_color: fill colour for the new canvas.

    Returns:
        ``pil_img`` itself when it is already square; otherwise a new image
        whose side equals the longer dimension, with ``pil_img`` pasted in
        the middle.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    if width > height:
        # Landscape: pad above and below.
        side = width
        paste_at = (0, (width - height) // 2)
    else:
        # Portrait: pad left and right.
        side = height
        paste_at = ((height - width) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, paste_at)
    return canvas
27
+
28
+
29
def process_images(images, image_processor, model_cfg):
    """Turn a list of images into model-ready pixel tensors.

    Args:
        images: iterable of images.
        image_processor: callable processor that also exposes ``preprocess``
            and ``image_mean``.
        model_cfg: config object; its optional ``image_aspect_ratio``
            attribute selects the preprocessing path.

    Returns:
        For any aspect-ratio setting other than ``'pad'``, the processor's
        ``pixel_values`` for the whole batch. For ``'pad'``, each image is
        squared with ``expand2square`` first; the results are stacked into
        one tensor when all shapes match, otherwise returned as a list.
    """
    if getattr(model_cfg, "image_aspect_ratio", None) != 'pad':
        return image_processor(images, return_tensors='pt')['pixel_values']

    padded = []
    for img in images:
        # Fill colour is the processor mean scaled to 0-255.
        fill = tuple(int(channel * 255) for channel in image_processor.image_mean)
        squared = expand2square(img, fill)
        padded.append(image_processor.preprocess(squared, return_tensors='pt')['pixel_values'][0])

    reference_shape_matches = all(t.shape == padded[0].shape for t in padded)
    if reference_shape_matches:
        return torch.stack(padded, dim=0)
    return padded
42
+
43
+
44
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt, replacing each ``<image>`` marker with an image token id.

    The prompt is split on ``<image>``; every piece is tokenized on its own
    and ``image_token_index`` is inserted between consecutive pieces. When
    the first piece starts with the tokenizer's BOS id, that id is kept once
    at the front and dropped from every piece.

    Args:
        prompt: text that may contain ``<image>`` markers.
        tokenizer: callable returning an object with ``input_ids``; must
            expose ``bos_token_id``.
        image_token_index: id inserted for each marker.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a long tensor.

    Raises:
        ValueError: for any other ``return_tensors`` value.
    """
    pieces = [tokenizer(segment).input_ids for segment in prompt.split('<image>')]

    token_ids = []
    skip = 0
    if pieces and pieces[0] and pieces[0][0] == tokenizer.bos_token_id:
        skip = 1
        token_ids.append(pieces[0][0])

    # The separator carries `skip` extra leading entries so that slicing it
    # with [skip:] (like every text piece) leaves exactly one image token.
    separator = [image_token_index] * (skip + 1)
    for position, piece in enumerate(pieces):
        if position > 0:
            token_ids.extend(separator[skip:])
        token_ids.extend(piece[skip:])

    if return_tensors is None:
        return token_ids
    if return_tensors == 'pt':
        return torch.tensor(token_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
64
+
65
+
66
def get_model_name_from_path(model_path):
    """Derive a display name from a model path.

    Leading and trailing slashes are ignored. The last path component is the
    name, unless it starts with ``checkpoint-``; in that case the parent
    component is prepended, joined by an underscore.
    """
    components = model_path.strip("/").split("/")
    leaf = components[-1]
    if not leaf.startswith('checkpoint-'):
        return leaf
    return components[-2] + "_" + leaf
73
+
74
+
75
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once every sequence ends in a stop keyword.

    A sequence matches either when its trailing token ids equal a keyword's
    token ids, or when the decoded tail of the generated tokens contains the
    keyword text.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """Tokenize each keyword and remember the prompt length.

        Args:
            keywords: iterable of stop strings.
            tokenizer: tokenizer used to encode keywords and decode output.
            input_ids: prompt ids; ``shape[1]`` is stored as the start offset.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for word in keywords:
            ids = tokenizer(word).input_ids
            # Drop a leading BOS id so only the keyword's own tokens are compared.
            if len(ids) > 1 and ids[0] == tokenizer.bos_token_id:
                ids = ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(ids))
            self.keyword_ids.append(torch.tensor(ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Check a single sequence (batch of one) for any stop keyword."""
        generated = output_ids.shape[1] - self.start_len
        offset = min(generated, self.max_keyword_len)
        # Keep the keyword tensors on the same device as the output ids.
        self.keyword_ids = [ids.to(output_ids.device) for ids in self.keyword_ids]

        # Exact token-id match on the tail of the sequence.
        for ids in self.keyword_ids:
            tail = output_ids[0, -ids.shape[0]:]
            if torch.equal(tail, ids):
                return True

        # Fallback: substring match on the decoded tail.
        decoded_tail = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(word in decoded_tail for word in self.keywords)

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every sequence in the batch has hit a keyword."""
        # A list (not a generator) so every row is evaluated, as before.
        per_row = [
            self.call_for_batch(output_ids[row].unsqueeze(0), scores)
            for row in range(output_ids.shape[0])
        ]
        return all(per_row)
serve/model_worker.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import asyncio
3
+ import json
4
+ import time
5
+ import threading
6
+ import uuid
7
+ import requests
8
+ import torch
9
+ import uvicorn
10
+ import transformers
11
+
12
+ from fastapi import FastAPI, Request, BackgroundTasks
13
+ from fastapi.responses import StreamingResponse
14
+ from functools import partial
15
+ from transformers import TextIteratorStreamer
16
+ from threading import Thread
17
+
18
+ from .constants import WORKER_HEART_BEAT_INTERVAL
19
+ from .utils import (build_logger, server_error_msg, pretty_print_semaphore)
20
+ from .builder import load_pretrained_model
21
+ from .mm_utils import process_images, load_image_from_base64, tokenizer_image_token, get_model_name_from_path, \
22
+ KeywordsStoppingCriteria
23
+ from .constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
24
+
25
# Number of bytes in one gibibyte.
GB = 1 << 30

# Short random id for this worker process; also used in the log file name.
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# Incremented by the /worker_generate_stream endpoint on every request.
global_counter = 0

# Created lazily by the /worker_generate_stream endpoint on first use.
model_semaphore = None
32
+
33
+
34
def heart_beat_worker(controller):
    """Send a heartbeat through ``controller`` forever, at a fixed interval.

    Sleeps ``WORKER_HEART_BEAT_INTERVAL`` seconds before each call to
    ``controller.send_heart_beat()``. Never returns; intended as a thread
    target.
    """
    interval = WORKER_HEART_BEAT_INTERVAL
    while True:
        # Sleep first, so the first heartbeat goes out one interval after start.
        time.sleep(interval)
        controller.send_heart_beat()
38
+
39
+
40
class ModelWorker:
    """Load one multimodal model and serve streamed generations for it.

    On construction the model is loaded through ``load_pretrained_model``
    and, unless ``no_register`` is set, the worker registers itself with the
    controller and starts a background heartbeat thread.
    """

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name, model_type,
                 load_8bit, load_4bit, device):
        """Load the model and optionally register with the controller.

        Args:
            controller_addr: base URL of the controller.
            worker_addr: URL under which this worker is reachable.
            worker_id: identifier used in log messages.
            no_register: skip registration and heartbeats when truthy.
            model_path: model location; one trailing slash is removed.
            model_base: forwarded to ``load_pretrained_model``.
            model_name: explicit name, or ``None`` to derive it from the path.
            model_type: forwarded to ``load_pretrained_model``.
            load_8bit: forwarded to ``load_pretrained_model``.
            load_4bit: forwarded to ``load_pretrained_model``.
            device: device the model and inputs are placed on.
        """
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            self.model_name = get_model_name_from_path(model_path)
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        transformers.logging.disable_progress_bar()
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, model_type, load_8bit, load_4bit, device=self.device)
        self.is_multimodal = True

        if not no_register:
            self.register_to_controller()
            # daemon=True: the heartbeat loop never returns, so a non-daemon
            # thread would keep the process alive after the server stops.
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,), daemon=True)
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """Announce this worker and its status to the controller.

        Raises:
            RuntimeError: if the controller does not answer with HTTP 200.
        """
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if r.status_code != 200:
            raise RuntimeError(
                f"Registering with controller at {url} failed with HTTP status {r.status_code}")

    def send_heart_beat(self):
        """Post the current queue length to the controller, retrying until it answers.

        Re-registers when the controller reports that it does not know this worker.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return the number of requests running or waiting on the model semaphore.

        Returns 0 while the semaphore has not been created yet.
        """
        if model_semaphore is None:
            return 0
        # Relies on private asyncio.Semaphore attributes (_value, _waiters).
        waiting = len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0
        return args.limit_model_concurrency - model_semaphore._value + waiting

    def get_status(self):
        """Return the status dict reported to the controller."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        """Stream a generation as NUL-terminated JSON chunks.

        Args:
            params: request dict. ``prompt`` is required; ``images`` (list of
                base64 strings), ``temperature``, ``top_p``,
                ``max_new_tokens`` and ``stop`` are optional.

        Yields:
            ``bytes``: a JSON object with ``text`` (prompt plus everything
            generated so far) and ``error_code``, followed by ``b"\\0"``.

        Raises:
            ValueError: if the number of images differs from the number of
                image tokens in the prompt.
        """
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        image_args = {}
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                raise ValueError("Number of images does not match number of <image> tokens in prompt")

            images = [load_image_from_base64(image) for image in images]
            images = process_images(images, image_processor, model.config)
            # process_images returns a list when the image shapes differ, so
            # handle both forms (the old debug torch.sum(images) call crashed
            # on the list form and has been removed together with the full
            # tensor dump).
            if isinstance(images, list):
                images = [image.to(self.model.device, dtype=model.dtype) for image in images]
            else:
                images = images.to(self.model.device, dtype=model.dtype)

            # Every image marker expands to num_patches positions in the context.
            num_image_tokens = prompt.count(DEFAULT_IMAGE_TOKEN) * model.get_vision_tower().num_patches
            image_args = {"images": images}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = temperature > 0.001

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(
            self.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

        # Leave room in the context window for the prompt and the image patches.
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.",
                              "error_code": 0}).encode() + b"\0"
            return
        print("max_new_tokens", max_new_tokens)
        print("start!")

        generate_kwargs = dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            **image_args
        )
        # Only install the keyword criterion when a stop string was supplied;
        # a missing/empty stop string used to crash the tokenizer call.
        if stop_str:
            generate_kwargs["stopping_criteria"] = [
                KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)]

        # Generation runs in a helper thread; this generator drains the streamer.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            if generated_text and not generated_text.endswith(' '):
                generated_text += ' '
            generated_text += new_text
            # Guarded: endswith(None) raises, and an empty stop string would
            # make the [:-0] slice erase the whole text.
            if stop_str and generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            logger.info(f"new_text: {new_text}")
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Wrap ``generate_stream`` so failures become an error chunk instead of raising.

        Yields the same chunks as ``generate_stream``; on any exception a
        single chunk carrying ``server_error_msg`` and ``error_code`` 1 is
        yielded instead.
        """
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            # Top-level boundary of the stream: report and degrade gracefully.
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
218
+
219
+
220
# FastAPI application that hosts the worker endpoints defined below.
app = FastAPI()
221
+
222
+
223
def release_model_semaphore(fn=None):
    """Release one slot of the module-level model semaphore, then run ``fn``.

    Args:
        fn: optional zero-argument callable invoked after the release.
    """
    model_semaphore.release()
    if fn is None:
        return
    fn()
227
+
228
+
229
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """Handle a streaming generation request.

    Reads the JSON body as generation parameters, takes one slot of the
    model semaphore, and returns a ``StreamingResponse`` over
    ``worker.generate_stream_gate``. The slot is released by a background
    task attached to the response, which then sends another heartbeat.
    """
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    # Created lazily on the first request, sized by --limit-model-concurrency.
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    # NOTE(review): send_heart_beat uses blocking `requests` calls inside this
    # async handler — confirm that is acceptable.
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)
243
+
244
+
245
@app.post("/worker_get_status")
async def get_status(request: Request):
    """Return the worker status dict produced by ``worker.get_status()``."""
    return worker.get_status()
248
+
249
+
250
# Script entry point: parse CLI options, create the worker (which loads the
# model) and serve the FastAPI app with uvicorn.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--model-type", type=str, default=None)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--multi-modal", action="store_true",
                        help="Multimodal mode is automatically detected with model name.")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    # `args` is a module-level name read by ModelWorker.get_queue_length and
    # by the /worker_generate_stream endpoint.
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning("Multimodal mode is automatically detected with model name.")

    # `worker` is a module-level name used by the endpoint handlers.
    # NOTE(review): --model-path defaults to None, but ModelWorker.__init__
    # calls model_path.endswith("/"); confirm the option is always supplied.
    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.model_type,
                         args.load_8bit,
                         args.load_4bit,
                         args.device)

    # Mutates uvicorn's LOGGING_CONFIG dict in place; log_config is not passed
    # to uvicorn.run explicitly.
    log_config = uvicorn.config.LOGGING_CONFIG
    log_config['handlers']['default']['stream'] = 'ext://sys.stdout'
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
serve/utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import logging.handlers
3
+ import os
4
+ import sys
5
+
6
+ from .constants import LOGDIR
7
+
8
# User-facing message for server-side failures.
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
# User-facing message for moderation rejections.
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

# Shared file handler; created by the first build_logger call and reused afterwards.
handler = None
12
+
13
+
14
def disable_torch_init():
    """
    Turn off torch's default parameter initialization to speed up model creation.

    Replaces ``reset_parameters`` on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` with a no-op for the rest of the process.
    """
    import torch

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(layer_cls, "reset_parameters", lambda self: None)
21
+
22
+
23
def build_logger(logger_name, logger_filename):
    """Configure process-wide logging and return the named logger.

    Side effects: sets the formatter on the first root handler, replaces
    ``sys.stdout`` and ``sys.stderr`` with ``StreamToLogger`` objects, and
    on the first call creates a daily-rotating file handler under ``LOGDIR``
    that is attached to every logger registered at that moment.

    Args:
        logger_name: name passed to ``logging.getLogger``.
        logger_filename: file name (inside ``LOGDIR``) for the file handler;
            only used on the call that creates the handler.

    Returns:
        The ``logging.Logger`` for ``logger_name``, set to INFO level.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    # StreamToLogger captures sys.stdout when constructed, so this second
    # instance wraps the stdout replacement installed just above.
    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

        # Only loggers that already exist at this point receive the handler.
        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger
64
+
65
+
66
class StreamToLogger(object):
    """
    File-like stand-in for a text stream that forwards complete lines to a logger.

    Text without a trailing newline is buffered until the rest of the line
    arrives or ``flush`` is called. Any attribute not defined here is looked
    up on the ``sys.stdout`` object captured at construction time.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        # Only reached for attributes missing on this object.
        return getattr(self.terminal, attr)

    def write(self, buf):
        """Log every complete line in ``buf``; keep any unfinished remainder."""
        pending = self.linebuf + buf
        self.linebuf = ''
        # splitlines(True) keeps the line endings, so a piece ending in '\n'
        # is a complete line and anything else is still being written.
        for piece in pending.splitlines(True):
            if piece.endswith('\n'):
                self.logger.log(self.log_level, piece.rstrip())
                continue
            self.linebuf += piece

    def flush(self):
        """Emit whatever partial line is buffered, then clear the buffer."""
        if not self.linebuf:
            return
        self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''
98
+
99
+
100
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Args:
        text: user text to check; newlines are removed before sending.

    Returns:
        The ``flagged`` value of the first moderation result, or ``False``
        when the request fails or the response lacks the expected fields
        (the check is best-effort and fails open).

    Raises:
        KeyError: if the ``OPENAI_API_KEY`` environment variable is not set.
    """
    # Imported here because this module's import block provides neither name;
    # previously `requests` was referenced without any import.
    import json
    import requests

    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Serialize with json.dumps so quotes and backslashes in the text are
    # escaped; hand-built JSON broke on such input.
    data = json.dumps({"input": text}, ensure_ascii=False).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError, ValueError):
        # Unexpected or non-JSON response body (e.g. an API error payload).
        flagged = False

    return flagged
119
+
120
+
121
def pretty_print_semaphore(semaphore):
    """Return a short description of a semaphore for log output.

    Gives the literal string ``"None"`` when no semaphore exists; otherwise
    reports the private ``_value`` counter and the result of ``locked()``.
    """
    if semaphore is not None:
        return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
    return "None"