Kang Suhyun commited on
Commit
2a0aa5a
·
unverified ·
1 Parent(s): fa7ac61

[#67] Check models at the start of the app (#68)

Browse files

* [#67] Check models at the start of the app

This change adds a check for the models at the start of the app.
If the models are not available, the app will throw an error.

* Apply code review

* Update

* Refactor completion function and fix error handling

* Add verbose logging in litellm module

* Refactor model and response code

* Swap content and role order in completion messages

Files changed (3) hide show
  1. app.py +3 -0
  2. model.py +77 -0
  3. response.py +18 -40
app.py CHANGED
@@ -13,6 +13,8 @@ from credentials import set_credentials
13
  from leaderboard import build_leaderboard
14
  from leaderboard import db
15
  from leaderboard import SUPPORTED_TRANSLATION_LANGUAGES
 
 
16
  import response
17
  from response import get_responses
18
 
@@ -189,6 +191,7 @@ with gr.Blocks(title="Arena", css=css) as app:
189
 
190
  if __name__ == "__main__":
191
  set_credentials(credentials.CREDENTIALS, credentials.CREDENTIALS_PATH)
 
192
 
193
  # We need to enable queue to use generators.
194
  app.queue()
 
13
  from leaderboard import build_leaderboard
14
  from leaderboard import db
15
  from leaderboard import SUPPORTED_TRANSLATION_LANGUAGES
16
+ from model import check_models
17
+ from model import supported_models
18
  import response
19
  from response import get_responses
20
 
 
191
 
192
  if __name__ == "__main__":
193
  set_credentials(credentials.CREDENTIALS, credentials.CREDENTIALS_PATH)
194
+ check_models(supported_models)
195
 
196
  # We need to enable queue to use generators.
197
  app.queue()
model.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains functions to interact with the models.
3
+ """
4
+
5
+ import json
6
+ import os
7
+ from typing import List
8
+
9
+ from google.cloud import secretmanager
10
+ from google.oauth2 import service_account
11
+ import litellm
12
+
13
+ from credentials import get_credentials_json
14
+
15
+ GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
16
+ MODELS_SECRET = os.environ.get("MODELS_SECRET")
17
+
18
+ secretmanager_client = secretmanager.SecretManagerServiceClient(
19
+ credentials=service_account.Credentials.from_service_account_info(
20
+ get_credentials_json()))
21
+ models_secret = secretmanager_client.access_secret_version(
22
+ name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
23
+ MODELS_SECRET, "latest"))
24
+ decoded_secret = models_secret.payload.data.decode("UTF-8")
25
+
26
+ supported_models_json = json.loads(decoded_secret)
27
+
28
+
29
+ class Model:
30
+
31
+ def __init__(
32
+ self,
33
+ name: str,
34
+ provider: str = None,
35
+ # The JSON keys are in camelCase. To unpack these keys into
36
+ # Model attributes, we need to use the same camelCase names.
37
+ apiKey: str = None, # pylint: disable=invalid-name
38
+ apiBase: str = None): # pylint: disable=invalid-name
39
+ self.name = name
40
+ self.provider = provider
41
+ self.api_key = apiKey
42
+ self.api_base = apiBase
43
+
44
+
45
+ supported_models: List[Model] = [
46
+ Model(name=model_name, **model_config)
47
+ for model_name, model_config in supported_models_json.items()
48
+ ]
49
+
50
+
51
+ def completion(model: Model, messages: List, max_tokens: float = None) -> str:
52
+ response = litellm.completion(model=model.provider + "/" +
53
+ model.name if model.provider else model.name,
54
+ api_key=model.api_key,
55
+ api_base=model.api_base,
56
+ messages=messages,
57
+ max_tokens=max_tokens)
58
+
59
+ return response.choices[0].message.content
60
+
61
+
62
+ def check_models(models: List[Model]):
63
+ for model in models:
64
+ print(f"Checking model {model.name}...")
65
+ try:
66
+ completion(model=model,
67
+ messages=[{
68
+ "content": "Hello.",
69
+ "role": "user"
70
+ }],
71
+ max_tokens=5)
72
+ print(f"Model {model.name} is available.")
73
+
74
+ # This check is designed to verify the availability of the models
75
+ # without any issues. Therefore, we need to catch all exceptions.
76
+ except Exception as e: # pylint: disable=broad-except
77
+ raise RuntimeError(f"Model {model.name} is not available: {e}") from e
response.py CHANGED
@@ -3,32 +3,17 @@ This module contains functions for generating responses using LLMs.
3
  """
4
 
5
  import enum
6
- import json
7
- import os
8
  from random import sample
 
9
  from uuid import uuid4
10
 
11
  from firebase_admin import firestore
12
- from google.cloud import secretmanager
13
- from google.oauth2 import service_account
14
  import gradio as gr
15
- from litellm import completion
16
 
17
- from credentials import get_credentials_json
18
  from leaderboard import db
19
-
20
- GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
21
- MODELS_SECRET = os.environ.get("MODELS_SECRET")
22
-
23
- secretmanager_client = secretmanager.SecretManagerServiceClient(
24
- credentials=service_account.Credentials.from_service_account_info(
25
- get_credentials_json()))
26
- models_secret = secretmanager_client.access_secret_version(
27
- name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
28
- MODELS_SECRET, "latest"))
29
- decoded_secret = models_secret.payload.data.decode("UTF-8")
30
-
31
- supported_models = json.loads(decoded_secret)
32
 
33
 
34
  def create_history(model_name: str, instruction: str, prompt: str,
@@ -69,42 +54,35 @@ def get_responses(user_prompt, category, source_lang, target_lang):
69
  not target_lang):
70
  raise gr.Error("Please select source and target languages.")
71
 
72
- models = sample(list(supported_models), 2)
73
  instruction = get_instruction(category, source_lang, target_lang)
74
 
75
  responses = []
76
  for model in models:
77
- model_config = supported_models[model]
78
-
79
- model_name = model_config[
80
- "provider"] + "/" + model if "provider" in model_config else model
81
- api_key = model_config.get("apiKey", None)
82
- api_base = model_config.get("apiBase", None)
83
-
84
  try:
85
  # TODO(#1): Allow user to set configuration.
86
- response = completion(model=model_name,
87
- api_key=api_key,
88
- api_base=api_base,
89
  messages=[{
90
- "content": instruction,
91
- "role": "system"
92
  }, {
93
- "content": user_prompt,
94
- "role": "user"
95
  }])
96
- content = response.choices[0].message.content
97
- create_history(model, instruction, user_prompt, content)
98
- responses.append(content)
99
 
100
  # TODO(#1): Narrow down the exception type.
101
  except Exception as e: # pylint: disable=broad-except
102
- print(f"Error with model {model}: {e}")
103
  raise gr.Error("Failed to get response. Please try again.")
104
 
 
 
105
  # It simulates concurrent stream response generation.
106
  max_response_length = max(len(response) for response in responses)
107
  for i in range(max_response_length):
108
- yield [response[:i + 1] for response in responses] + models + [instruction]
 
109
 
110
- yield responses + models + [instruction]
 
3
  """
4
 
5
  import enum
 
 
6
  from random import sample
7
+ from typing import List
8
  from uuid import uuid4
9
 
10
  from firebase_admin import firestore
 
 
11
  import gradio as gr
 
12
 
 
13
  from leaderboard import db
14
+ from model import completion
15
+ from model import Model
16
+ from model import supported_models
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def create_history(model_name: str, instruction: str, prompt: str,
 
54
  not target_lang):
55
  raise gr.Error("Please select source and target languages.")
56
 
57
+ models: List[Model] = sample(list(supported_models), 2)
58
  instruction = get_instruction(category, source_lang, target_lang)
59
 
60
  responses = []
61
  for model in models:
 
 
 
 
 
 
 
62
  try:
63
  # TODO(#1): Allow user to set configuration.
64
+ response = completion(model=model,
 
 
65
  messages=[{
66
+ "role": "system",
67
+ "content": instruction
68
  }, {
69
+ "role": "user",
70
+ "content": user_prompt
71
  }])
72
+ create_history(model.name, instruction, user_prompt, response)
73
+ responses.append(response)
 
74
 
75
  # TODO(#1): Narrow down the exception type.
76
  except Exception as e: # pylint: disable=broad-except
77
+ print(f"Error with model {model.name}: {e}")
78
  raise gr.Error("Failed to get response. Please try again.")
79
 
80
+ model_names = [model.name for model in models]
81
+
82
  # It simulates concurrent stream response generation.
83
  max_response_length = max(len(response) for response in responses)
84
  for i in range(max_response_length):
85
+ yield [response[:i + 1] for response in responses
86
+ ] + model_names + [instruction]
87
 
88
+ yield responses + model_names + [instruction]