Fabrice-TIERCELIN commited on
Commit
c6701f2
·
verified ·
1 Parent(s): 949b6f6

Upload 7 files

Browse files
llava/serve/__init__.py ADDED
File without changes
llava/serve/cli.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5
+ from llava.conversation import conv_templates, SeparatorStyle
6
+ from llava.model.builder import load_pretrained_model
7
+ from llava.utils import disable_torch_init
8
+ from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9
+
10
+ from PIL import Image
11
+
12
+ import requests
13
+ from PIL import Image
14
+ from io import BytesIO
15
+ from transformers import TextStreamer
16
+
17
+
18
+ def load_image(image_file):
19
+ if image_file.startswith('http://') or image_file.startswith('https://'):
20
+ response = requests.get(image_file)
21
+ image = Image.open(BytesIO(response.content)).convert('RGB')
22
+ else:
23
+ image = Image.open(image_file).convert('RGB')
24
+ return image
25
+
26
+
27
+ def main(args):
28
+ # Model
29
+ disable_torch_init()
30
+
31
+ model_name = get_model_name_from_path(args.model_path)
32
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33
+
34
+ if 'llama-2' in model_name.lower():
35
+ conv_mode = "llava_llama_2"
36
+ elif "v1" in model_name.lower():
37
+ conv_mode = "llava_v1"
38
+ elif "mpt" in model_name.lower():
39
+ conv_mode = "mpt"
40
+ else:
41
+ conv_mode = "llava_v0"
42
+
43
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
44
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
45
+ else:
46
+ args.conv_mode = conv_mode
47
+
48
+ conv = conv_templates[args.conv_mode].copy()
49
+ if "mpt" in model_name.lower():
50
+ roles = ('user', 'assistant')
51
+ else:
52
+ roles = conv.roles
53
+
54
+ image = load_image(args.image_file)
55
+ # Similar operation in model_worker.py
56
+ image_tensor = process_images([image], image_processor, args)
57
+ if type(image_tensor) is list:
58
+ image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
59
+ else:
60
+ image_tensor = image_tensor.to(model.device, dtype=torch.float16)
61
+
62
+ while True:
63
+ try:
64
+ inp = input(f"{roles[0]}: ")
65
+ except EOFError:
66
+ inp = ""
67
+ if not inp:
68
+ print("exit...")
69
+ break
70
+
71
+ print(f"{roles[1]}: ", end="")
72
+
73
+ if image is not None:
74
+ # first message
75
+ if model.config.mm_use_im_start_end:
76
+ inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
77
+ else:
78
+ inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
79
+ conv.append_message(conv.roles[0], inp)
80
+ image = None
81
+ else:
82
+ # later messages
83
+ conv.append_message(conv.roles[0], inp)
84
+ conv.append_message(conv.roles[1], None)
85
+ prompt = conv.get_prompt()
86
+
87
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
88
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
89
+ keywords = [stop_str]
90
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
91
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
92
+
93
+ with torch.inference_mode():
94
+ output_ids = model.generate(
95
+ input_ids,
96
+ images=image_tensor,
97
+ do_sample=True,
98
+ temperature=args.temperature,
99
+ max_new_tokens=args.max_new_tokens,
100
+ streamer=streamer,
101
+ use_cache=True,
102
+ stopping_criteria=[stopping_criteria])
103
+
104
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
105
+ conv.messages[-1][-1] = outputs
106
+
107
+ if args.debug:
108
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
109
+
110
+
111
+ if __name__ == "__main__":
112
+ parser = argparse.ArgumentParser()
113
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
114
+ parser.add_argument("--model-base", type=str, default=None)
115
+ parser.add_argument("--image-file", type=str, required=True)
116
+ parser.add_argument("--device", type=str, default="cuda")
117
+ parser.add_argument("--conv-mode", type=str, default=None)
118
+ parser.add_argument("--temperature", type=float, default=0.2)
119
+ parser.add_argument("--max-new-tokens", type=int, default=512)
120
+ parser.add_argument("--load-8bit", action="store_true")
121
+ parser.add_argument("--load-4bit", action="store_true")
122
+ parser.add_argument("--debug", action="store_true")
123
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
124
+ args = parser.parse_args()
125
+ main(args)
llava/serve/controller.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A controller manages distributed workers.
3
+ It sends worker addresses to clients.
4
+ """
5
+ import argparse
6
+ import asyncio
7
+ import dataclasses
8
+ from enum import Enum, auto
9
+ import json
10
+ import logging
11
+ import time
12
+ from typing import List, Union
13
+ import threading
14
+
15
+ from fastapi import FastAPI, Request
16
+ from fastapi.responses import StreamingResponse
17
+ import numpy as np
18
+ import requests
19
+ import uvicorn
20
+
21
+ from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22
+ from llava.utils import build_logger, server_error_msg
23
+
24
+
25
+ logger = build_logger("controller", "controller.log")
26
+
27
+
28
+ class DispatchMethod(Enum):
29
+ LOTTERY = auto()
30
+ SHORTEST_QUEUE = auto()
31
+
32
+ @classmethod
33
+ def from_str(cls, name):
34
+ if name == "lottery":
35
+ return cls.LOTTERY
36
+ elif name == "shortest_queue":
37
+ return cls.SHORTEST_QUEUE
38
+ else:
39
+ raise ValueError(f"Invalid dispatch method")
40
+
41
+
42
+ @dataclasses.dataclass
43
+ class WorkerInfo:
44
+ model_names: List[str]
45
+ speed: int
46
+ queue_length: int
47
+ check_heart_beat: bool
48
+ last_heart_beat: str
49
+
50
+
51
+ def heart_beat_controller(controller):
52
+ while True:
53
+ time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54
+ controller.remove_stable_workers_by_expiration()
55
+
56
+
57
+ class Controller:
58
+ def __init__(self, dispatch_method: str):
59
+ # Dict[str -> WorkerInfo]
60
+ self.worker_info = {}
61
+ self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62
+
63
+ self.heart_beat_thread = threading.Thread(
64
+ target=heart_beat_controller, args=(self,))
65
+ self.heart_beat_thread.start()
66
+
67
+ logger.info("Init controller")
68
+
69
+ def register_worker(self, worker_name: str, check_heart_beat: bool,
70
+ worker_status: dict):
71
+ if worker_name not in self.worker_info:
72
+ logger.info(f"Register a new worker: {worker_name}")
73
+ else:
74
+ logger.info(f"Register an existing worker: {worker_name}")
75
+
76
+ if not worker_status:
77
+ worker_status = self.get_worker_status(worker_name)
78
+ if not worker_status:
79
+ return False
80
+
81
+ self.worker_info[worker_name] = WorkerInfo(
82
+ worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83
+ check_heart_beat, time.time())
84
+
85
+ logger.info(f"Register done: {worker_name}, {worker_status}")
86
+ return True
87
+
88
+ def get_worker_status(self, worker_name: str):
89
+ try:
90
+ r = requests.post(worker_name + "/worker_get_status", timeout=5)
91
+ except requests.exceptions.RequestException as e:
92
+ logger.error(f"Get status fails: {worker_name}, {e}")
93
+ return None
94
+
95
+ if r.status_code != 200:
96
+ logger.error(f"Get status fails: {worker_name}, {r}")
97
+ return None
98
+
99
+ return r.json()
100
+
101
+ def remove_worker(self, worker_name: str):
102
+ del self.worker_info[worker_name]
103
+
104
+ def refresh_all_workers(self):
105
+ old_info = dict(self.worker_info)
106
+ self.worker_info = {}
107
+
108
+ for w_name, w_info in old_info.items():
109
+ if not self.register_worker(w_name, w_info.check_heart_beat, None):
110
+ logger.info(f"Remove stale worker: {w_name}")
111
+
112
+ def list_models(self):
113
+ model_names = set()
114
+
115
+ for w_name, w_info in self.worker_info.items():
116
+ model_names.update(w_info.model_names)
117
+
118
+ return list(model_names)
119
+
120
+ def get_worker_address(self, model_name: str):
121
+ if self.dispatch_method == DispatchMethod.LOTTERY:
122
+ worker_names = []
123
+ worker_speeds = []
124
+ for w_name, w_info in self.worker_info.items():
125
+ if model_name in w_info.model_names:
126
+ worker_names.append(w_name)
127
+ worker_speeds.append(w_info.speed)
128
+ worker_speeds = np.array(worker_speeds, dtype=np.float32)
129
+ norm = np.sum(worker_speeds)
130
+ if norm < 1e-4:
131
+ return ""
132
+ worker_speeds = worker_speeds / norm
133
+ if True: # Directly return address
134
+ pt = np.random.choice(np.arange(len(worker_names)),
135
+ p=worker_speeds)
136
+ worker_name = worker_names[pt]
137
+ return worker_name
138
+
139
+ # Check status before returning
140
+ while True:
141
+ pt = np.random.choice(np.arange(len(worker_names)),
142
+ p=worker_speeds)
143
+ worker_name = worker_names[pt]
144
+
145
+ if self.get_worker_status(worker_name):
146
+ break
147
+ else:
148
+ self.remove_worker(worker_name)
149
+ worker_speeds[pt] = 0
150
+ norm = np.sum(worker_speeds)
151
+ if norm < 1e-4:
152
+ return ""
153
+ worker_speeds = worker_speeds / norm
154
+ continue
155
+ return worker_name
156
+ elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157
+ worker_names = []
158
+ worker_qlen = []
159
+ for w_name, w_info in self.worker_info.items():
160
+ if model_name in w_info.model_names:
161
+ worker_names.append(w_name)
162
+ worker_qlen.append(w_info.queue_length / w_info.speed)
163
+ if len(worker_names) == 0:
164
+ return ""
165
+ min_index = np.argmin(worker_qlen)
166
+ w_name = worker_names[min_index]
167
+ self.worker_info[w_name].queue_length += 1
168
+ logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169
+ return w_name
170
+ else:
171
+ raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172
+
173
+ def receive_heart_beat(self, worker_name: str, queue_length: int):
174
+ if worker_name not in self.worker_info:
175
+ logger.info(f"Receive unknown heart beat. {worker_name}")
176
+ return False
177
+
178
+ self.worker_info[worker_name].queue_length = queue_length
179
+ self.worker_info[worker_name].last_heart_beat = time.time()
180
+ logger.info(f"Receive heart beat. {worker_name}")
181
+ return True
182
+
183
+ def remove_stable_workers_by_expiration(self):
184
+ expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185
+ to_delete = []
186
+ for worker_name, w_info in self.worker_info.items():
187
+ if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188
+ to_delete.append(worker_name)
189
+
190
+ for worker_name in to_delete:
191
+ self.remove_worker(worker_name)
192
+
193
+ def worker_api_generate_stream(self, params):
194
+ worker_addr = self.get_worker_address(params["model"])
195
+ if not worker_addr:
196
+ logger.info(f"no worker: {params['model']}")
197
+ ret = {
198
+ "text": server_error_msg,
199
+ "error_code": 2,
200
+ }
201
+ yield json.dumps(ret).encode() + b"\0"
202
+
203
+ try:
204
+ response = requests.post(worker_addr + "/worker_generate_stream",
205
+ json=params, stream=True, timeout=5)
206
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207
+ if chunk:
208
+ yield chunk + b"\0"
209
+ except requests.exceptions.RequestException as e:
210
+ logger.info(f"worker timeout: {worker_addr}")
211
+ ret = {
212
+ "text": server_error_msg,
213
+ "error_code": 3,
214
+ }
215
+ yield json.dumps(ret).encode() + b"\0"
216
+
217
+
218
+ # Let the controller act as a worker to achieve hierarchical
219
+ # management. This can be used to connect isolated sub networks.
220
+ def worker_api_get_status(self):
221
+ model_names = set()
222
+ speed = 0
223
+ queue_length = 0
224
+
225
+ for w_name in self.worker_info:
226
+ worker_status = self.get_worker_status(w_name)
227
+ if worker_status is not None:
228
+ model_names.update(worker_status["model_names"])
229
+ speed += worker_status["speed"]
230
+ queue_length += worker_status["queue_length"]
231
+
232
+ return {
233
+ "model_names": list(model_names),
234
+ "speed": speed,
235
+ "queue_length": queue_length,
236
+ }
237
+
238
+
239
+ app = FastAPI()
240
+
241
+
242
+ @app.post("/register_worker")
243
+ async def register_worker(request: Request):
244
+ data = await request.json()
245
+ controller.register_worker(
246
+ data["worker_name"], data["check_heart_beat"],
247
+ data.get("worker_status", None))
248
+
249
+
250
+ @app.post("/refresh_all_workers")
251
+ async def refresh_all_workers():
252
+ models = controller.refresh_all_workers()
253
+
254
+
255
+ @app.post("/list_models")
256
+ async def list_models():
257
+ models = controller.list_models()
258
+ return {"models": models}
259
+
260
+
261
+ @app.post("/get_worker_address")
262
+ async def get_worker_address(request: Request):
263
+ data = await request.json()
264
+ addr = controller.get_worker_address(data["model"])
265
+ return {"address": addr}
266
+
267
+
268
+ @app.post("/receive_heart_beat")
269
+ async def receive_heart_beat(request: Request):
270
+ data = await request.json()
271
+ exist = controller.receive_heart_beat(
272
+ data["worker_name"], data["queue_length"])
273
+ return {"exist": exist}
274
+
275
+
276
+ @app.post("/worker_generate_stream")
277
+ async def worker_api_generate_stream(request: Request):
278
+ params = await request.json()
279
+ generator = controller.worker_api_generate_stream(params)
280
+ return StreamingResponse(generator)
281
+
282
+
283
+ @app.post("/worker_get_status")
284
+ async def worker_api_get_status(request: Request):
285
+ return controller.worker_api_get_status()
286
+
287
+
288
+ if __name__ == "__main__":
289
+ parser = argparse.ArgumentParser()
290
+ parser.add_argument("--host", type=str, default="localhost")
291
+ parser.add_argument("--port", type=int, default=21001)
292
+ parser.add_argument("--dispatch-method", type=str, choices=[
293
+ "lottery", "shortest_queue"], default="shortest_queue")
294
+ args = parser.parse_args()
295
+ logger.info(f"args: {args}")
296
+
297
+ controller = Controller(args.dispatch_method)
298
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
llava/serve/gradio_web_server.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import time
6
+
7
+ import gradio as gr
8
+ import requests
9
+
10
+ from llava.conversation import (default_conversation, conv_templates,
11
+ SeparatorStyle)
12
+ from llava.constants import LOGDIR
13
+ from llava.utils import (build_logger, server_error_msg,
14
+ violates_moderation, moderation_msg)
15
+ import hashlib
16
+
17
+
18
+ logger = build_logger("gradio_web_server", "gradio_web_server.log")
19
+
20
+ headers = {"User-Agent": "LLaVA Client"}
21
+
22
+ no_change_btn = gr.Button.update()
23
+ enable_btn = gr.Button.update(interactive=True)
24
+ disable_btn = gr.Button.update(interactive=False)
25
+
26
+ priority = {
27
+ "vicuna-13b": "aaaaaaa",
28
+ "koala-13b": "aaaaaab",
29
+ }
30
+
31
+
32
+ def get_conv_log_filename():
33
+ t = datetime.datetime.now()
34
+ name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
35
+ return name
36
+
37
+
38
+ def get_model_list():
39
+ ret = requests.post(args.controller_url + "/refresh_all_workers")
40
+ assert ret.status_code == 200
41
+ ret = requests.post(args.controller_url + "/list_models")
42
+ models = ret.json()["models"]
43
+ models.sort(key=lambda x: priority.get(x, x))
44
+ logger.info(f"Models: {models}")
45
+ return models
46
+
47
+
48
+ get_window_url_params = """
49
+ function() {
50
+ const params = new URLSearchParams(window.location.search);
51
+ url_params = Object.fromEntries(params);
52
+ console.log(url_params);
53
+ return url_params;
54
+ }
55
+ """
56
+
57
+
58
+ def load_demo(url_params, request: gr.Request):
59
+ logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
60
+
61
+ dropdown_update = gr.Dropdown.update(visible=True)
62
+ if "model" in url_params:
63
+ model = url_params["model"]
64
+ if model in models:
65
+ dropdown_update = gr.Dropdown.update(
66
+ value=model, visible=True)
67
+
68
+ state = default_conversation.copy()
69
+ return state, dropdown_update
70
+
71
+
72
+ def load_demo_refresh_model_list(request: gr.Request):
73
+ logger.info(f"load_demo. ip: {request.client.host}")
74
+ models = get_model_list()
75
+ state = default_conversation.copy()
76
+ dropdown_update = gr.Dropdown.update(
77
+ choices=models,
78
+ value=models[0] if len(models) > 0 else ""
79
+ )
80
+ return state, dropdown_update
81
+
82
+
83
+ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
84
+ with open(get_conv_log_filename(), "a") as fout:
85
+ data = {
86
+ "tstamp": round(time.time(), 4),
87
+ "type": vote_type,
88
+ "model": model_selector,
89
+ "state": state.dict(),
90
+ "ip": request.client.host,
91
+ }
92
+ fout.write(json.dumps(data) + "\n")
93
+
94
+
95
+ def upvote_last_response(state, model_selector, request: gr.Request):
96
+ logger.info(f"upvote. ip: {request.client.host}")
97
+ vote_last_response(state, "upvote", model_selector, request)
98
+ return ("",) + (disable_btn,) * 3
99
+
100
+
101
+ def downvote_last_response(state, model_selector, request: gr.Request):
102
+ logger.info(f"downvote. ip: {request.client.host}")
103
+ vote_last_response(state, "downvote", model_selector, request)
104
+ return ("",) + (disable_btn,) * 3
105
+
106
+
107
+ def flag_last_response(state, model_selector, request: gr.Request):
108
+ logger.info(f"flag. ip: {request.client.host}")
109
+ vote_last_response(state, "flag", model_selector, request)
110
+ return ("",) + (disable_btn,) * 3
111
+
112
+
113
+ def regenerate(state, image_process_mode, request: gr.Request):
114
+ logger.info(f"regenerate. ip: {request.client.host}")
115
+ state.messages[-1][-1] = None
116
+ prev_human_msg = state.messages[-2]
117
+ if type(prev_human_msg[1]) in (tuple, list):
118
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
119
+ state.skip_next = False
120
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
121
+
122
+
123
+ def clear_history(request: gr.Request):
124
+ logger.info(f"clear_history. ip: {request.client.host}")
125
+ state = default_conversation.copy()
126
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
127
+
128
+
129
+ def add_text(state, text, image, image_process_mode, request: gr.Request):
130
+ logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
131
+ if len(text) <= 0 and image is None:
132
+ state.skip_next = True
133
+ return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
134
+ if args.moderate:
135
+ flagged = violates_moderation(text)
136
+ if flagged:
137
+ state.skip_next = True
138
+ return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
139
+ no_change_btn,) * 5
140
+
141
+ text = text[:1536] # Hard cut-off
142
+ if image is not None:
143
+ text = text[:1200] # Hard cut-off for images
144
+ if '<image>' not in text:
145
+ # text = '<Image><image></Image>' + text
146
+ text = text + '\n<image>'
147
+ text = (text, image, image_process_mode)
148
+ if len(state.get_images(return_pil=True)) > 0:
149
+ state = default_conversation.copy()
150
+ state.append_message(state.roles[0], text)
151
+ state.append_message(state.roles[1], None)
152
+ state.skip_next = False
153
+ return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
154
+
155
+
156
+ def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
157
+ logger.info(f"http_bot. ip: {request.client.host}")
158
+ start_tstamp = time.time()
159
+ model_name = model_selector
160
+
161
+ if state.skip_next:
162
+ # This generate call is skipped due to invalid inputs
163
+ yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
164
+ return
165
+
166
+ if len(state.messages) == state.offset + 2:
167
+ # First round of conversation
168
+ if "llava" in model_name.lower():
169
+ if 'llama-2' in model_name.lower():
170
+ template_name = "llava_llama_2"
171
+ elif "v1" in model_name.lower():
172
+ if 'mmtag' in model_name.lower():
173
+ template_name = "v1_mmtag"
174
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
175
+ template_name = "v1_mmtag"
176
+ else:
177
+ template_name = "llava_v1"
178
+ elif "mpt" in model_name.lower():
179
+ template_name = "mpt"
180
+ else:
181
+ if 'mmtag' in model_name.lower():
182
+ template_name = "v0_mmtag"
183
+ elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
184
+ template_name = "v0_mmtag"
185
+ else:
186
+ template_name = "llava_v0"
187
+ elif "mpt" in model_name:
188
+ template_name = "mpt_text"
189
+ elif "llama-2" in model_name:
190
+ template_name = "llama_2"
191
+ else:
192
+ template_name = "vicuna_v1"
193
+ new_state = conv_templates[template_name].copy()
194
+ new_state.append_message(new_state.roles[0], state.messages[-2][1])
195
+ new_state.append_message(new_state.roles[1], None)
196
+ state = new_state
197
+
198
+ # Query worker address
199
+ controller_url = args.controller_url
200
+ ret = requests.post(controller_url + "/get_worker_address",
201
+ json={"model": model_name})
202
+ worker_addr = ret.json()["address"]
203
+ logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
204
+
205
+ # No available worker
206
+ if worker_addr == "":
207
+ state.messages[-1][-1] = server_error_msg
208
+ yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
209
+ return
210
+
211
+ # Construct prompt
212
+ prompt = state.get_prompt()
213
+
214
+ all_images = state.get_images(return_pil=True)
215
+ all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
216
+ for image, hash in zip(all_images, all_image_hash):
217
+ t = datetime.datetime.now()
218
+ filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
219
+ if not os.path.isfile(filename):
220
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
221
+ image.save(filename)
222
+
223
+ # Make requests
224
+ pload = {
225
+ "model": model_name,
226
+ "prompt": prompt,
227
+ "temperature": float(temperature),
228
+ "top_p": float(top_p),
229
+ "max_new_tokens": min(int(max_new_tokens), 1536),
230
+ "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
231
+ "images": f'List of {len(state.get_images())} images: {all_image_hash}',
232
+ }
233
+ logger.info(f"==== request ====\n{pload}")
234
+
235
+ pload['images'] = state.get_images()
236
+
237
+ state.messages[-1][-1] = "▌"
238
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
239
+
240
+ try:
241
+ # Stream output
242
+ response = requests.post(worker_addr + "/worker_generate_stream",
243
+ headers=headers, json=pload, stream=True, timeout=10)
244
+ for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
245
+ if chunk:
246
+ data = json.loads(chunk.decode())
247
+ if data["error_code"] == 0:
248
+ output = data["text"][len(prompt):].strip()
249
+ state.messages[-1][-1] = output + "▌"
250
+ yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
251
+ else:
252
+ output = data["text"] + f" (error_code: {data['error_code']})"
253
+ state.messages[-1][-1] = output
254
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
255
+ return
256
+ time.sleep(0.03)
257
+ except requests.exceptions.RequestException as e:
258
+ state.messages[-1][-1] = server_error_msg
259
+ yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
260
+ return
261
+
262
+ state.messages[-1][-1] = state.messages[-1][-1][:-1]
263
+ yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
264
+
265
+ finish_tstamp = time.time()
266
+ logger.info(f"{output}")
267
+
268
+ with open(get_conv_log_filename(), "a") as fout:
269
+ data = {
270
+ "tstamp": round(finish_tstamp, 4),
271
+ "type": "chat",
272
+ "model": model_name,
273
+ "start": round(start_tstamp, 4),
274
+ "finish": round(start_tstamp, 4),
275
+ "state": state.dict(),
276
+ "images": all_image_hash,
277
+ "ip": request.client.host,
278
+ }
279
+ fout.write(json.dumps(data) + "\n")
280
+
281
+ title_markdown = ("""
282
+ # 🌋 LLaVA: Large Language and Vision Assistant
283
+ [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
284
+ """)
285
+
286
+ tos_markdown = ("""
287
+ ### Terms of use
288
+ By using this service, users are required to agree to the following terms:
289
+ The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
290
+ Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
291
+ For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
292
+ """)
293
+
294
+
295
+ learn_more_markdown = ("""
296
+ ### License
297
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
298
+ """)
299
+
300
+ block_css = """
301
+
302
+ #buttons button {
303
+ min-width: min(120px,100%);
304
+ }
305
+
306
+ """
307
+
308
+ def build_demo(embed_mode):
309
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
310
+ with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
311
+ state = gr.State()
312
+
313
+ if not embed_mode:
314
+ gr.Markdown(title_markdown)
315
+
316
+ with gr.Row():
317
+ with gr.Column(scale=3):
318
+ with gr.Row(elem_id="model_selector_row"):
319
+ model_selector = gr.Dropdown(
320
+ choices=models,
321
+ value=models[0] if len(models) > 0 else "",
322
+ interactive=True,
323
+ show_label=False,
324
+ container=False)
325
+
326
+ imagebox = gr.Image(type="pil")
327
+ image_process_mode = gr.Radio(
328
+ ["Crop", "Resize", "Pad", "Default"],
329
+ value="Default",
330
+ label="Preprocess for non-square image", visible=False)
331
+
332
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
333
+ gr.Examples(examples=[
334
+ [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
335
+ [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
336
+ ], inputs=[imagebox, textbox])
337
+
338
+ with gr.Accordion("Parameters", open=False) as parameter_row:
339
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
340
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
341
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
342
+
343
+ with gr.Column(scale=8):
344
+ chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
345
+ with gr.Row():
346
+ with gr.Column(scale=8):
347
+ textbox.render()
348
+ with gr.Column(scale=1, min_width=50):
349
+ submit_btn = gr.Button(value="Send", variant="primary")
350
+ with gr.Row(elem_id="buttons") as button_row:
351
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
352
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
353
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
354
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
355
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
356
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
357
+
358
+ if not embed_mode:
359
+ gr.Markdown(tos_markdown)
360
+ gr.Markdown(learn_more_markdown)
361
+ url_params = gr.JSON(visible=False)
362
+
363
+ # Register listeners
364
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
365
+ upvote_btn.click(upvote_last_response,
366
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
367
+ downvote_btn.click(downvote_last_response,
368
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
369
+ flag_btn.click(flag_last_response,
370
+ [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
371
+ regenerate_btn.click(regenerate, [state, image_process_mode],
372
+ [state, chatbot, textbox, imagebox] + btn_list).then(
373
+ http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
374
+ [state, chatbot] + btn_list)
375
+ clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
376
+
377
+ textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
378
+ ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
379
+ [state, chatbot] + btn_list)
380
+ submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
381
+ ).then(http_bot, [state, model_selector, temperature, top_p, max_output_tokens],
382
+ [state, chatbot] + btn_list)
383
+
384
+ if args.model_list_mode == "once":
385
+ demo.load(load_demo, [url_params], [state, model_selector],
386
+ _js=get_window_url_params)
387
+ elif args.model_list_mode == "reload":
388
+ demo.load(load_demo_refresh_model_list, None, [state, model_selector])
389
+ else:
390
+ raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
391
+
392
+ return demo
393
+
394
+
395
+ if __name__ == "__main__":
396
+ parser = argparse.ArgumentParser()
397
+ parser.add_argument("--host", type=str, default="0.0.0.0")
398
+ parser.add_argument("--port", type=int)
399
+ parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
400
+ parser.add_argument("--concurrency-count", type=int, default=10)
401
+ parser.add_argument("--model-list-mode", type=str, default="once",
402
+ choices=["once", "reload"])
403
+ parser.add_argument("--share", action="store_true")
404
+ parser.add_argument("--moderate", action="store_true")
405
+ parser.add_argument("--embed", action="store_true")
406
+ args = parser.parse_args()
407
+ logger.info(f"args: {args}")
408
+
409
+ models = get_model_list()
410
+
411
+ logger.info(args)
412
+ demo = build_demo(args.embed)
413
+ demo.queue(
414
+ concurrency_count=args.concurrency_count,
415
+ api_open=False
416
+ ).launch(
417
+ server_name=args.host,
418
+ server_port=args.port,
419
+ share=args.share
420
+ )
llava/serve/model_worker.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A model worker executes the model.
3
+ """
4
+ import argparse
5
+ import asyncio
6
+ import json
7
+ import time
8
+ import threading
9
+ import uuid
10
+
11
+ from fastapi import FastAPI, Request, BackgroundTasks
12
+ from fastapi.responses import StreamingResponse
13
+ import requests
14
+ import torch
15
+ import uvicorn
16
+ from functools import partial
17
+
18
+ from llava.constants import WORKER_HEART_BEAT_INTERVAL
19
+ from llava.utils import (build_logger, server_error_msg,
20
+ pretty_print_semaphore)
21
+ from llava.model.builder import load_pretrained_model
22
+ from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
23
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+ from transformers import TextIteratorStreamer
25
+ from threading import Thread
26
+
27
+
28
+ GB = 1 << 30
29
+
30
+ worker_id = str(uuid.uuid4())[:6]
31
+ logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
32
+ global_counter = 0
33
+
34
+ model_semaphore = None
35
+
36
+
37
+ def heart_beat_worker(controller):
38
+
39
+ while True:
40
+ time.sleep(WORKER_HEART_BEAT_INTERVAL)
41
+ controller.send_heart_beat()
42
+
43
+
44
+ class ModelWorker:
45
+ def __init__(self, controller_addr, worker_addr,
46
+ worker_id, no_register,
47
+ model_path, model_base, model_name,
48
+ load_8bit, load_4bit, device):
49
+ self.controller_addr = controller_addr
50
+ self.worker_addr = worker_addr
51
+ self.worker_id = worker_id
52
+ if model_path.endswith("/"):
53
+ model_path = model_path[:-1]
54
+ if model_name is None:
55
+ model_paths = model_path.split("/")
56
+ if model_paths[-1].startswith('checkpoint-'):
57
+ self.model_name = model_paths[-2] + "_" + model_paths[-1]
58
+ else:
59
+ self.model_name = model_paths[-1]
60
+ else:
61
+ self.model_name = model_name
62
+
63
+ self.device = device
64
+ logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
65
+ self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
66
+ model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
67
+ self.is_multimodal = 'llava' in self.model_name.lower()
68
+
69
+ if not no_register:
70
+ self.register_to_controller()
71
+ self.heart_beat_thread = threading.Thread(
72
+ target=heart_beat_worker, args=(self,))
73
+ self.heart_beat_thread.start()
74
+
75
+ def register_to_controller(self):
76
+ logger.info("Register to controller")
77
+
78
+ url = self.controller_addr + "/register_worker"
79
+ data = {
80
+ "worker_name": self.worker_addr,
81
+ "check_heart_beat": True,
82
+ "worker_status": self.get_status()
83
+ }
84
+ r = requests.post(url, json=data)
85
+ assert r.status_code == 200
86
+
87
+ def send_heart_beat(self):
88
+ logger.info(f"Send heart beat. Models: {[self.model_name]}. "
89
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
90
+ f"global_counter: {global_counter}")
91
+
92
+ url = self.controller_addr + "/receive_heart_beat"
93
+
94
+ while True:
95
+ try:
96
+ ret = requests.post(url, json={
97
+ "worker_name": self.worker_addr,
98
+ "queue_length": self.get_queue_length()}, timeout=5)
99
+ exist = ret.json()["exist"]
100
+ break
101
+ except requests.exceptions.RequestException as e:
102
+ logger.error(f"heart beat error: {e}")
103
+ time.sleep(5)
104
+
105
+ if not exist:
106
+ self.register_to_controller()
107
+
108
+ def get_queue_length(self):
109
+ if model_semaphore is None:
110
+ return 0
111
+ else:
112
+ return args.limit_model_concurrency - model_semaphore._value + (len(
113
+ model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
114
+
115
+ def get_status(self):
116
+ return {
117
+ "model_names": [self.model_name],
118
+ "speed": 1,
119
+ "queue_length": self.get_queue_length(),
120
+ }
121
+
122
+ @torch.inference_mode()
123
+ def generate_stream(self, params):
124
+ tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
125
+
126
+ prompt = params["prompt"]
127
+ ori_prompt = prompt
128
+ images = params.get("images", None)
129
+ num_image_tokens = 0
130
+ if images is not None and len(images) > 0 and self.is_multimodal:
131
+ if len(images) > 0:
132
+ if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
133
+ raise ValueError("Number of images does not match number of <image> tokens in prompt")
134
+
135
+ images = [load_image_from_base64(image) for image in images]
136
+ images = process_images(images, image_processor, model.config)
137
+
138
+ if type(images) is list:
139
+ images = [image.to(self.model.device, dtype=torch.float16) for image in images]
140
+ else:
141
+ images = images.to(self.model.device, dtype=torch.float16)
142
+
143
+ replace_token = DEFAULT_IMAGE_TOKEN
144
+ if getattr(self.model.config, 'mm_use_im_start_end', False):
145
+ replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
146
+ prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
147
+
148
+ num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
149
+ else:
150
+ images = None
151
+ image_args = {"images": images}
152
+ else:
153
+ images = None
154
+ image_args = {}
155
+
156
+ temperature = float(params.get("temperature", 1.0))
157
+ top_p = float(params.get("top_p", 1.0))
158
+ max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
159
+ max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
160
+ stop_str = params.get("stop", None)
161
+ do_sample = True if temperature > 0.001 else False
162
+
163
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
164
+ keywords = [stop_str]
165
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
166
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
167
+
168
+ max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
169
+
170
+ if max_new_tokens < 1:
171
+ yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
172
+ return
173
+
174
+ thread = Thread(target=model.generate, kwargs=dict(
175
+ inputs=input_ids,
176
+ do_sample=do_sample,
177
+ temperature=temperature,
178
+ top_p=top_p,
179
+ max_new_tokens=max_new_tokens,
180
+ streamer=streamer,
181
+ stopping_criteria=[stopping_criteria],
182
+ use_cache=True,
183
+ **image_args
184
+ ))
185
+ thread.start()
186
+
187
+ generated_text = ori_prompt
188
+ for new_text in streamer:
189
+ generated_text += new_text
190
+ if generated_text.endswith(stop_str):
191
+ generated_text = generated_text[:-len(stop_str)]
192
+ yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
193
+
194
+ def generate_stream_gate(self, params):
195
+ try:
196
+ for x in self.generate_stream(params):
197
+ yield x
198
+ except ValueError as e:
199
+ print("Caught ValueError:", e)
200
+ ret = {
201
+ "text": server_error_msg,
202
+ "error_code": 1,
203
+ }
204
+ yield json.dumps(ret).encode() + b"\0"
205
+ except torch.cuda.CudaError as e:
206
+ print("Caught torch.cuda.CudaError:", e)
207
+ ret = {
208
+ "text": server_error_msg,
209
+ "error_code": 1,
210
+ }
211
+ yield json.dumps(ret).encode() + b"\0"
212
+ except Exception as e:
213
+ print("Caught Unknown Error", e)
214
+ ret = {
215
+ "text": server_error_msg,
216
+ "error_code": 1,
217
+ }
218
+ yield json.dumps(ret).encode() + b"\0"
219
+
220
+
221
+ app = FastAPI()
222
+
223
+
224
+ def release_model_semaphore(fn=None):
225
+ model_semaphore.release()
226
+ if fn is not None:
227
+ fn()
228
+
229
+
230
+ @app.post("/worker_generate_stream")
231
+ async def generate_stream(request: Request):
232
+ global model_semaphore, global_counter
233
+ global_counter += 1
234
+ params = await request.json()
235
+
236
+ if model_semaphore is None:
237
+ model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
238
+ await model_semaphore.acquire()
239
+ worker.send_heart_beat()
240
+ generator = worker.generate_stream_gate(params)
241
+ background_tasks = BackgroundTasks()
242
+ background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
243
+ return StreamingResponse(generator, background=background_tasks)
244
+
245
+
246
+ @app.post("/worker_get_status")
247
+ async def get_status(request: Request):
248
+ return worker.get_status()
249
+
250
+
251
+ if __name__ == "__main__":
252
+ parser = argparse.ArgumentParser()
253
+ parser.add_argument("--host", type=str, default="localhost")
254
+ parser.add_argument("--port", type=int, default=21002)
255
+ parser.add_argument("--worker-address", type=str,
256
+ default="http://localhost:21002")
257
+ parser.add_argument("--controller-address", type=str,
258
+ default="http://localhost:21001")
259
+ parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
260
+ parser.add_argument("--model-base", type=str, default=None)
261
+ parser.add_argument("--model-name", type=str)
262
+ parser.add_argument("--device", type=str, default="cuda")
263
+ parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
264
+ parser.add_argument("--limit-model-concurrency", type=int, default=5)
265
+ parser.add_argument("--stream-interval", type=int, default=1)
266
+ parser.add_argument("--no-register", action="store_true")
267
+ parser.add_argument("--load-8bit", action="store_true")
268
+ parser.add_argument("--load-4bit", action="store_true")
269
+ args = parser.parse_args()
270
+ logger.info(f"args: {args}")
271
+
272
+ if args.multi_modal:
273
+ logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
274
+
275
+ worker = ModelWorker(args.controller_address,
276
+ args.worker_address,
277
+ worker_id,
278
+ args.no_register,
279
+ args.model_path,
280
+ args.model_base,
281
+ args.model_name,
282
+ args.load_8bit,
283
+ args.load_4bit,
284
+ args.device)
285
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
llava/serve/register_worker.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Manually register workers.
3
+
4
+ Usage:
5
+ python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
6
+ """
7
+
8
+ import argparse
9
+
10
+ import requests
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--controller-address", type=str)
15
+ parser.add_argument("--worker-name", type=str)
16
+ parser.add_argument("--check-heart-beat", action="store_true")
17
+ args = parser.parse_args()
18
+
19
+ url = args.controller_address + "/register_worker"
20
+ data = {
21
+ "worker_name": args.worker_name,
22
+ "check_heart_beat": args.check_heart_beat,
23
+ "worker_status": None,
24
+ }
25
+ r = requests.post(url, json=data)
26
+ assert r.status_code == 200
llava/serve/test_message.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+
4
+ import requests
5
+
6
+ from llava.conversation import default_conversation
7
+
8
+
9
+ def main():
10
+ if args.worker_address:
11
+ worker_addr = args.worker_address
12
+ else:
13
+ controller_addr = args.controller_address
14
+ ret = requests.post(controller_addr + "/refresh_all_workers")
15
+ ret = requests.post(controller_addr + "/list_models")
16
+ models = ret.json()["models"]
17
+ models.sort()
18
+ print(f"Models: {models}")
19
+
20
+ ret = requests.post(controller_addr + "/get_worker_address",
21
+ json={"model": args.model_name})
22
+ worker_addr = ret.json()["address"]
23
+ print(f"worker_addr: {worker_addr}")
24
+
25
+ if worker_addr == "":
26
+ return
27
+
28
+ conv = default_conversation.copy()
29
+ conv.append_message(conv.roles[0], args.message)
30
+ prompt = conv.get_prompt()
31
+
32
+ headers = {"User-Agent": "LLaVA Client"}
33
+ pload = {
34
+ "model": args.model_name,
35
+ "prompt": prompt,
36
+ "max_new_tokens": args.max_new_tokens,
37
+ "temperature": 0.7,
38
+ "stop": conv.sep,
39
+ }
40
+ response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41
+ json=pload, stream=True)
42
+
43
+ print(prompt.replace(conv.sep, "\n"), end="")
44
+ for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45
+ if chunk:
46
+ data = json.loads(chunk.decode("utf-8"))
47
+ output = data["text"].split(conv.sep)[-1]
48
+ print(output, end="\r")
49
+ print("")
50
+
51
+
52
+ if __name__ == "__main__":
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55
+ parser.add_argument("--worker-address", type=str)
56
+ parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57
+ parser.add_argument("--max-new-tokens", type=int, default=32)
58
+ parser.add_argument("--message", type=str, default=
59
+ "Tell me a story with more than 1000 words.")
60
+ args = parser.parse_args()
61
+
62
+ main()