Jae-Won Chung commited on
Commit
7ed0b8b
·
unverified ·
1 Parent(s): 10ee5bf

Make one model selectable by user (#23)

Browse files
app.py CHANGED
@@ -363,6 +363,12 @@ Every benchmark is limited in some sense -- Before you interpret the results, pl
363
  controller_addr = os.environ["COLOSSEUM_CONTROLLER_ADDR"]
364
  global_controller_client = ControllerClient(controller_addr=controller_addr, timeout=15)
365
 
 
 
 
 
 
 
366
  # Colosseum helper functions.
367
  def enable_interact():
368
  return [gr.update(interactive=True)] * 2
@@ -394,21 +400,23 @@ def consumed_more_energy_message(energy_a, energy_b):
394
 
395
  # Colosseum event handlers
396
  def add_prompt_disable_submit(prompt, history_a, history_b):
397
- """Add the user's prompt to the two model's history and disable the submit button."""
398
  client = global_controller_client.fork()
399
  return [
400
  gr.Textbox.update(value=" ", interactive=False),
401
  gr.Button.update(interactive=False),
 
402
  history_a + [[prompt, ""]],
403
  history_b + [[prompt, ""]],
404
  client,
405
  ]
406
 
407
- def generate_responses(client: ControllerClient, history_a, history_b):
408
  """Generate responses for the two models."""
 
409
  for resp_a, resp_b in itertools.zip_longest(
410
- client.prompt(prompt=history_a[-1][0], index=0),
411
- client.prompt(prompt=history_b[-1][0], index=1),
412
  ):
413
  if resp_a is not None:
414
  history_a[-1][1] += resp_a
@@ -475,12 +483,14 @@ def play_again():
475
  return [
476
  # Clear chatbot history
477
  None, None,
478
- # Turn on prompt textbox and submit button
479
  gr.Textbox.update(value="", interactive=True), gr.Button.update(interactive=True),
480
  # Mask model names
481
  gr.Markdown.update(value="", visible=False), gr.Markdown.update(value="", visible=False),
482
  # Hide energy vote buttons and message
483
  gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Markdown.update(visible=False),
 
 
484
  # Disable reset button
485
  gr.Button.update(interactive=False, visible=False),
486
  ]
@@ -506,6 +516,14 @@ with gr.Blocks(css=custom_css) as block:
506
  with gr.TabItem("Colosseum ⚔️️"):
507
  gr.Markdown(open("docs/colosseum_top.md").read())
508
 
 
 
 
 
 
 
 
 
509
  with gr.Group():
510
  with gr.Row():
511
  prompt_input = gr.Textbox(
@@ -561,12 +579,12 @@ with gr.Blocks(css=custom_css) as block:
561
 
562
 
563
  (prompt_input
564
- .submit(add_prompt_disable_submit, [prompt_input, *chatbots], [prompt_input, prompt_submit_btn, *chatbots, controller_client], queue=False)
565
- .then(generate_responses, [controller_client, *chatbots], [*chatbots], queue=True, show_progress="hidden")
566
  .then(enable_interact, None, resp_vote_btn_list, queue=False))
567
  (prompt_submit_btn
568
- .click(add_prompt_disable_submit, [prompt_input, *chatbots], [prompt_input, prompt_submit_btn, *chatbots, controller_client], queue=False)
569
- .then(generate_responses, [controller_client, *chatbots], [*chatbots], queue=True, show_progress="hidden")
570
  .then(enable_interact, None, resp_vote_btn_list, queue=False))
571
 
572
  left_resp_vote_btn.click(
@@ -599,7 +617,7 @@ with gr.Blocks(css=custom_css) as block:
599
  .click(
600
  play_again,
601
  None,
602
- [*chatbots, prompt_input, prompt_submit_btn, *masked_model_names, *energy_vote_btn_list, energy_comparison_message, play_again_btn],
603
  queue=False,
604
  )
605
  .then(None, _js=focus_prompt_input_js, queue=False))
 
363
  controller_addr = os.environ["COLOSSEUM_CONTROLLER_ADDR"]
364
  global_controller_client = ControllerClient(controller_addr=controller_addr, timeout=15)
365
 
366
+ # Load the list of models. To reload, the app should be restarted.
367
+ available_models = global_controller_client.get_available_models()
368
+ model_preference_dropdown_choices = [f"One is {model}" for model in available_models]
369
+ model_preference_dropdown_choices = ["Two random models"] + model_preference_dropdown_choices
370
+ user_pref_to_model_name = dict(zip(model_preference_dropdown_choices, ["Random"] + available_models))
371
+
372
  # Colosseum helper functions.
373
  def enable_interact():
374
  return [gr.update(interactive=True)] * 2
 
400
 
401
  # Colosseum event handlers
402
  def add_prompt_disable_submit(prompt, history_a, history_b):
403
+ """Add the user's prompt to the two model's history and disable further submission."""
404
  client = global_controller_client.fork()
405
  return [
406
  gr.Textbox.update(value=" ", interactive=False),
407
  gr.Button.update(interactive=False),
408
+ gr.Dropdown.update(interactive=False),
409
  history_a + [[prompt, ""]],
410
  history_b + [[prompt, ""]],
411
  client,
412
  ]
413
 
414
+ def generate_responses(client: ControllerClient, user_preference, history_a, history_b):
415
  """Generate responses for the two models."""
416
+ model_preference = user_pref_to_model_name[user_preference]
417
  for resp_a, resp_b in itertools.zip_longest(
418
+ client.prompt(prompt=history_a[-1][0], index=0, model_preference=model_preference),
419
+ client.prompt(prompt=history_b[-1][0], index=1, model_preference=model_preference),
420
  ):
421
  if resp_a is not None:
422
  history_a[-1][1] += resp_a
 
483
  return [
484
  # Clear chatbot history
485
  None, None,
486
+ # Enable prompt textbox and submit button
487
  gr.Textbox.update(value="", interactive=True), gr.Button.update(interactive=True),
488
  # Mask model names
489
  gr.Markdown.update(value="", visible=False), gr.Markdown.update(value="", visible=False),
490
  # Hide energy vote buttons and message
491
  gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Markdown.update(visible=False),
492
+ # Enable model preference dropdown
493
+ gr.Dropdown.update(interactive=True),
494
  # Disable reset button
495
  gr.Button.update(interactive=False, visible=False),
496
  ]
 
516
  with gr.TabItem("Colosseum ⚔️️"):
517
  gr.Markdown(open("docs/colosseum_top.md").read())
518
 
519
+ with gr.Row():
520
+ model_preference_dropdown = gr.Dropdown(
521
+ choices=model_preference_dropdown_choices,
522
+ value=model_preference_dropdown_choices[0],
523
+ label="Prefer a specific model?",
524
+ interactive=True,
525
+ )
526
+
527
  with gr.Group():
528
  with gr.Row():
529
  prompt_input = gr.Textbox(
 
579
 
580
 
581
  (prompt_input
582
+ .submit(add_prompt_disable_submit, [prompt_input, *chatbots], [prompt_input, prompt_submit_btn, model_preference_dropdown, *chatbots, controller_client], queue=False)
583
+ .then(generate_responses, [controller_client, model_preference_dropdown, *chatbots], [*chatbots], queue=True, show_progress="hidden")
584
  .then(enable_interact, None, resp_vote_btn_list, queue=False))
585
  (prompt_submit_btn
586
+ .click(add_prompt_disable_submit, [prompt_input, *chatbots], [prompt_input, prompt_submit_btn, model_preference_dropdown, *chatbots, controller_client], queue=False)
587
+ .then(generate_responses, [controller_client, model_preference_dropdown, *chatbots], [*chatbots], queue=True, show_progress="hidden")
588
  .then(enable_interact, None, resp_vote_btn_list, queue=False))
589
 
590
  left_resp_vote_btn.click(
 
617
  .click(
618
  play_again,
619
  None,
620
+ [*chatbots, prompt_input, prompt_submit_btn, *masked_model_names, *energy_vote_btn_list, energy_comparison_message, model_preference_dropdown, play_again_btn],
621
  queue=False,
622
  )
623
  .then(None, _js=focus_prompt_input_js, queue=False))
spitfight/colosseum/client.py CHANGED
@@ -9,9 +9,11 @@ import requests
9
  import gradio as gr
10
 
11
  from spitfight.colosseum.common import (
 
12
  COLOSSEUM_PROMPT_ROUTE,
13
  COLOSSEUM_RESP_VOTE_ROUTE,
14
  COLOSSEUM_ENERGY_VOTE_ROUTE,
 
15
  PromptRequest,
16
  ResponseVoteRequest,
17
  ResponseVoteResponse,
@@ -37,9 +39,33 @@ class ControllerClient:
37
  request_id=uuid4(),
38
  )
39
 
40
- def prompt(self, prompt: str, index: Literal[0, 1]) -> Generator[str, None, None]:
41
- """Generate the response of the `index`th model with the prompt."""
42
- prompt_request = PromptRequest(request_id=self.request_id, prompt=prompt, model_index=index)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  with _catch_requests_exceptions():
44
  resp = requests.post(
45
  f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
 
9
  import gradio as gr
10
 
11
  from spitfight.colosseum.common import (
12
+ COLOSSEUM_MODELS_ROUTE,
13
  COLOSSEUM_PROMPT_ROUTE,
14
  COLOSSEUM_RESP_VOTE_ROUTE,
15
  COLOSSEUM_ENERGY_VOTE_ROUTE,
16
+ ModelsResponse,
17
  PromptRequest,
18
  ResponseVoteRequest,
19
  ResponseVoteResponse,
 
39
  request_id=uuid4(),
40
  )
41
 
42
+ def get_available_models(self) -> list[str]:
43
+ """Retrieve the list of available models."""
44
+ with _catch_requests_exceptions():
45
+ resp = requests.get(
46
+ f"http://{self.controller_addr}{COLOSSEUM_MODELS_ROUTE}",
47
+ timeout=self.timeout,
48
+ )
49
+ _check_response(resp)
50
+ return ModelsResponse(**resp.json()).available_models
51
+
52
+ def prompt(
53
+ self,
54
+ prompt: str,
55
+ index: Literal[0, 1],
56
+ model_preference: str,
57
+ ) -> Generator[str, None, None]:
58
+ """Generate the response of the `index`th model with the prompt.
59
+
60
+ `user_pref` is the user's preference for the model to use. It can be
61
+ `"random"` or one of the models in the list returned by `get_available_models`.
62
+ """
63
+ prompt_request = PromptRequest(
64
+ request_id=self.request_id,
65
+ prompt=prompt,
66
+ model_index=index,
67
+ model_preference=model_preference,
68
+ )
69
  with _catch_requests_exceptions():
70
  resp = requests.post(
71
  f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
spitfight/colosseum/common.py CHANGED
@@ -4,16 +4,22 @@ from typing import Literal
4
 
5
  from pydantic import BaseModel
6
 
 
7
  COLOSSEUM_PROMPT_ROUTE = "/prompt"
8
  COLOSSEUM_RESP_VOTE_ROUTE = "/response_vote"
9
  COLOSSEUM_ENERGY_VOTE_ROUTE = "/energy_vote"
10
  COLOSSEUM_HEALTH_ROUTE = "/health"
11
 
12
 
 
 
 
 
13
  class PromptRequest(BaseModel):
14
  request_id: str
15
  prompt: str
16
  model_index: Literal[0, 1]
 
17
 
18
 
19
  class ResponseVoteRequest(BaseModel):
 
4
 
5
  from pydantic import BaseModel
6
 
7
+ COLOSSEUM_MODELS_ROUTE = "/models"
8
  COLOSSEUM_PROMPT_ROUTE = "/prompt"
9
  COLOSSEUM_RESP_VOTE_ROUTE = "/response_vote"
10
  COLOSSEUM_ENERGY_VOTE_ROUTE = "/energy_vote"
11
  COLOSSEUM_HEALTH_ROUTE = "/health"
12
 
13
 
14
+ class ModelsResponse(BaseModel):
15
+ available_models: list[str]
16
+
17
+
18
  class PromptRequest(BaseModel):
19
  request_id: str
20
  prompt: str
21
  model_index: Literal[0, 1]
22
+ model_preference: str
23
 
24
 
25
  class ResponseVoteRequest(BaseModel):
spitfight/colosseum/controller/controller.py CHANGED
@@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
12
  from spitfight.log import get_logger
13
  from spitfight.utils import BoundedExpiringDict, TokenGenerationBuffer, create_task
14
  from spitfight.colosseum.controller.worker import WorkerService
15
- from spitfight.prompt import get_system_prompt, apply_model_characteristics
16
 
17
  if TYPE_CHECKING:
18
  from spitfight.colosseum.controller.router import ControllerConfig
@@ -46,6 +46,7 @@ class RequestState(BaseModel):
46
  request_id: str
47
  model_names: list[str]
48
  raw_prompt: str
 
49
  responses: list[str] = ["UNSET", "UNSET"]
50
  model_prompts: list[str] = ["UNSET", "UNSET"]
51
  energy_consumptions: list[float] = [0.0, 0.0]
@@ -140,6 +141,14 @@ class Controller:
140
  prev_num_req_states - len(self.request_states),
141
  )
142
 
 
 
 
 
 
 
 
 
143
  def response_vote(self, request_id: str, victory_index: Literal[0, 1]) -> RequestState | None:
144
  """Record the user's response vote and return the new state."""
145
  if (state := self.request_states.get(request_id)) is not None:
@@ -165,16 +174,18 @@ class Controller:
165
  request_id: str,
166
  prompt: str,
167
  model_index: Literal[0, 1],
 
168
  ) -> AsyncGenerator[bytes, None]:
169
  # This method is called twice for the same request, once for each model.
170
  # If it's the first time this method is called, assign models to the request.
171
  if request_id not in self.request_states:
172
- workers = self.worker_service.choose_two()
173
  model_names = [worker.model_name for worker in workers]
174
  self.request_states[request_id] = RequestState(
175
  request_id=request_id,
176
  raw_prompt=prompt,
177
  model_names=model_names,
 
178
  )
179
  request_state = self.request_states[request_id]
180
  model_name = request_state.model_names[model_index]
 
12
  from spitfight.log import get_logger
13
  from spitfight.utils import BoundedExpiringDict, TokenGenerationBuffer, create_task
14
  from spitfight.colosseum.controller.worker import WorkerService
15
+ from spitfight.prompt import apply_model_characteristics
16
 
17
  if TYPE_CHECKING:
18
  from spitfight.colosseum.controller.router import ControllerConfig
 
46
  request_id: str
47
  model_names: list[str]
48
  raw_prompt: str
49
+ model_preference: str
50
  responses: list[str] = ["UNSET", "UNSET"]
51
  model_prompts: list[str] = ["UNSET", "UNSET"]
52
  energy_consumptions: list[float] = [0.0, 0.0]
 
141
  prev_num_req_states - len(self.request_states),
142
  )
143
 
144
+ def get_available_models(self) -> list[str]:
145
+ """Return the names of available models."""
146
+ return [
147
+ worker.model_name
148
+ for worker in self.worker_service.workers
149
+ if worker.status == "up"
150
+ ]
151
+
152
  def response_vote(self, request_id: str, victory_index: Literal[0, 1]) -> RequestState | None:
153
  """Record the user's response vote and return the new state."""
154
  if (state := self.request_states.get(request_id)) is not None:
 
174
  request_id: str,
175
  prompt: str,
176
  model_index: Literal[0, 1],
177
+ model_preference: str,
178
  ) -> AsyncGenerator[bytes, None]:
179
  # This method is called twice for the same request, once for each model.
180
  # If it's the first time this method is called, assign models to the request.
181
  if request_id not in self.request_states:
182
+ workers = self.worker_service.choose_based_on_preference(model_preference)
183
  model_names = [worker.model_name for worker in workers]
184
  self.request_states[request_id] = RequestState(
185
  request_id=request_id,
186
  raw_prompt=prompt,
187
  model_names=model_names,
188
+ model_preference=model_preference,
189
  )
190
  request_state = self.request_states[request_id]
191
  model_name = request_state.model_names[model_index]
spitfight/colosseum/controller/router.py CHANGED
@@ -10,10 +10,12 @@ from text_generation.errors import OverloadedError, UnknownError, ValidationErro
10
 
11
  from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers
12
  from spitfight.colosseum.common import (
 
13
  COLOSSEUM_PROMPT_ROUTE,
14
  COLOSSEUM_RESP_VOTE_ROUTE,
15
  COLOSSEUM_ENERGY_VOTE_ROUTE,
16
  COLOSSEUM_HEALTH_ROUTE,
 
17
  PromptRequest,
18
  ResponseVoteRequest,
19
  ResponseVoteResponse,
@@ -67,12 +69,21 @@ async def shutdown_event():
67
  get_global_controller().shutdown()
68
  shutdown_queued_root_loggers()
69
 
 
 
 
 
70
  @app.post(COLOSSEUM_PROMPT_ROUTE)
71
  async def prompt(
72
  request: PromptRequest,
73
  controller: Controller = Depends(get_global_controller),
74
  ):
75
- generator = controller.prompt(request.request_id, request.prompt, request.model_index)
 
 
 
 
 
76
 
77
  # First try to get the first token in order to catch TGI errors.
78
  try:
 
10
 
11
  from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers
12
  from spitfight.colosseum.common import (
13
+ COLOSSEUM_MODELS_ROUTE,
14
  COLOSSEUM_PROMPT_ROUTE,
15
  COLOSSEUM_RESP_VOTE_ROUTE,
16
  COLOSSEUM_ENERGY_VOTE_ROUTE,
17
  COLOSSEUM_HEALTH_ROUTE,
18
+ ModelsResponse,
19
  PromptRequest,
20
  ResponseVoteRequest,
21
  ResponseVoteResponse,
 
69
  get_global_controller().shutdown()
70
  shutdown_queued_root_loggers()
71
 
72
+ @app.get(COLOSSEUM_MODELS_ROUTE, response_model=ModelsResponse)
73
+ async def models(controller: Controller = Depends(get_global_controller)):
74
+ return ModelsResponse(available_models=controller.get_available_models())
75
+
76
  @app.post(COLOSSEUM_PROMPT_ROUTE)
77
  async def prompt(
78
  request: PromptRequest,
79
  controller: Controller = Depends(get_global_controller),
80
  ):
81
+ generator = controller.prompt(
82
+ request.request_id,
83
+ request.prompt,
84
+ request.model_index,
85
+ request.model_preference,
86
+ )
87
 
88
  # First try to get the first token in order to catch TGI errors.
89
  try:
spitfight/colosseum/controller/worker.py CHANGED
@@ -19,7 +19,7 @@ class Worker(BaseModel):
19
  hostname: str
20
  # For TGI, this would always be 80.
21
  port: int
22
- # User-friendly model name, e.g. "metaai/llama2-13b-chat".
23
  model_name: str
24
  # Hugging Face model ID, e.g. "metaai/Llama-2-13b-chat-hf".
25
  model_id: str
@@ -146,6 +146,21 @@ class WorkerService:
146
  worker_a, worker_b = random.sample(live_workers, 2)
147
  return worker_a, worker_b
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  async def check_workers(self) -> None:
150
  """Check the status of all workers."""
151
  await asyncio.gather(*[worker.check_status() for worker in self.workers])
 
19
  hostname: str
20
  # For TGI, this would always be 80.
21
  port: int
22
+ # User-friendly model name, e.g. "Llama2-7B".
23
  model_name: str
24
  # Hugging Face model ID, e.g. "metaai/Llama-2-13b-chat-hf".
25
  model_id: str
 
146
  worker_a, worker_b = random.sample(live_workers, 2)
147
  return worker_a, worker_b
148
 
149
+ def choose_based_on_preference(self, preference: str) -> tuple[Worker, Worker]:
150
+ """Choose two different workers based on user preference.
151
+
152
+ Specifically, if `preference` is `"Random"`, this is equivalent to
153
+ choosing two models at random. Otherwise, if `preference` is a model
154
+ name, this is equivalent to choosing that model and another model at
155
+ random. In that case, the order of the two models is also randomized.
156
+ """
157
+ if preference == "Random":
158
+ return self.choose_two()
159
+ else:
160
+ worker_a = self.get_worker(preference)
161
+ worker_b = random.choice([worker for worker in self.workers if worker != worker_a])
162
+ return tuple(random.sample([worker_a, worker_b], 2))
163
+
164
  async def check_workers(self) -> None:
165
  """Check the status of all workers."""
166
  await asyncio.gather(*[worker.check_status() for worker in self.workers])