Kang Suhyun committed
[#67] Check models at the start of the app (#68)
* [#67] Check models at the start of the app
This change adds a check for the models at the start of the app. If any model is not available, the app raises an error and fails to start (see the sketch below).
* Apply code review
* Update
* Refactor completion function and fix error handling
* Add verbose logging in litellm module
* Refactor model and response code
* Swap content and role order in completion messages
- app.py +3 -0
- model.py +77 -0
- response.py +18 -40
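To make the described behavior concrete, here is a minimal, self-contained sketch of the fail-fast check; the model names and the simulated API failure are hypothetical stand-ins for `model.supported_models` and the one-token litellm call in the new model.py:

```python
# Hypothetical stand-in for the startup check in model.py: models are
# plain strings and the simulated failure replaces a real API call.
def check_models(models):
    for model in models:
        print(f"Checking model {model}...")
        try:
            if model.endswith("-unavailable"):  # simulated failed API call
                raise ConnectionError("connection refused")
            print(f"Model {model} is available.")
        except Exception as e:
            # Any failure aborts startup with a descriptive error, so a
            # misconfigured model is caught before the app serves traffic.
            raise RuntimeError(f"Model {model} is not available: {e}") from e


check_models(["model-a", "model-b"])  # both pass
# check_models(["model-a", "model-unavailable"])  # would raise RuntimeError
```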
app.py CHANGED

```diff
@@ -13,6 +13,8 @@ from credentials import set_credentials
 from leaderboard import build_leaderboard
 from leaderboard import db
 from leaderboard import SUPPORTED_TRANSLATION_LANGUAGES
+from model import check_models
+from model import supported_models
 import response
 from response import get_responses
 
@@ -189,6 +191,7 @@ with gr.Blocks(title="Arena", css=css) as app:
 
 if __name__ == "__main__":
   set_credentials(credentials.CREDENTIALS, credentials.CREDENTIALS_PATH)
+  check_models(supported_models)
 
   # We need to enable queue to use generators.
   app.queue()
```
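As the retained comment notes, Gradio streams generator output only when the queue is enabled, which is why `app.queue()` runs before the app is launched. A minimal illustration, with a hypothetical handler unrelated to this app:

```python
import gradio as gr


def stream():
    # Yields progressively longer strings, like get_responses does.
    text = ""
    for char in "hello":
        text += char
        yield text


with gr.Blocks() as demo:
    button = gr.Button("Go")
    output = gr.Textbox()
    button.click(fn=stream, inputs=None, outputs=output)

demo.queue()  # generator handlers require the queue to stream
demo.launch()
```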
model.py ADDED

```diff
@@ -0,0 +1,77 @@
+"""
+This module contains functions to interact with the models.
+"""
+
+import json
+import os
+from typing import List
+
+from google.cloud import secretmanager
+from google.oauth2 import service_account
+import litellm
+
+from credentials import get_credentials_json
+
+GOOGLE_CLOUD_PROJECT = os.environ.get("GOOGLE_CLOUD_PROJECT")
+MODELS_SECRET = os.environ.get("MODELS_SECRET")
+
+secretmanager_client = secretmanager.SecretManagerServiceClient(
+    credentials=service_account.Credentials.from_service_account_info(
+        get_credentials_json()))
+models_secret = secretmanager_client.access_secret_version(
+    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
+                                                  MODELS_SECRET, "latest"))
+decoded_secret = models_secret.payload.data.decode("UTF-8")
+
+supported_models_json = json.loads(decoded_secret)
+
+
+class Model:
+
+  def __init__(
+      self,
+      name: str,
+      provider: str = None,
+      # The JSON keys are in camelCase. To unpack these keys into
+      # Model attributes, we need to use the same camelCase names.
+      apiKey: str = None,  # pylint: disable=invalid-name
+      apiBase: str = None):  # pylint: disable=invalid-name
+    self.name = name
+    self.provider = provider
+    self.api_key = apiKey
+    self.api_base = apiBase
+
+
+supported_models: List[Model] = [
+    Model(name=model_name, **model_config)
+    for model_name, model_config in supported_models_json.items()
+]
+
+
+def completion(model: Model, messages: List, max_tokens: float = None) -> str:
+  response = litellm.completion(model=model.provider + "/" +
+                                model.name if model.provider else model.name,
+                                api_key=model.api_key,
+                                api_base=model.api_base,
+                                messages=messages,
+                                max_tokens=max_tokens)
+
+  return response.choices[0].message.content
+
+
+def check_models(models: List[Model]):
+  for model in models:
+    print(f"Checking model {model.name}...")
+    try:
+      completion(model=model,
+                 messages=[{
+                     "content": "Hello.",
+                     "role": "user"
+                 }],
+                 max_tokens=5)
+      print(f"Model {model.name} is available.")
+
+    # This check is designed to verify the availability of the models
+    # without any issues. Therefore, we need to catch all exceptions.
+    except Exception as e:  # pylint: disable=broad-except
+      raise RuntimeError(f"Model {model.name} is not available: {e}") from e
```
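To make the secret format concrete, here is a hypothetical `MODELS_SECRET` payload (every name, key, and URL below is invented). The camelCase JSON keys deliberately match the `Model` constructor's parameter names, so each entry can be unpacked as `Model(name=model_name, **model_config)`; `completion()` then joins provider and name into a litellm-style identifier such as `openai/gpt-3.5-turbo`:

```python
import json

# Invented payload; the real one is read from Secret Manager at startup.
decoded_secret = """
{
  "gpt-3.5-turbo": {"provider": "openai", "apiKey": "sk-example"},
  "local-model": {"provider": "openai",
                  "apiBase": "http://localhost:8000/v1",
                  "apiKey": "unused"}
}
"""

for model_name, model_config in json.loads(decoded_secret).items():
    # model.py unpacks each entry as Model(name=model_name, **model_config),
    # so these JSON keys must match the constructor's parameter names.
    provider = model_config.get("provider")
    litellm_id = f"{provider}/{model_name}" if provider else model_name
    print(litellm_id, "->", model_config.get("apiBase"))
```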
response.py CHANGED

```diff
@@ -3,32 +3,17 @@ This module contains functions for generating responses using LLMs.
 """
 
 import enum
-import json
-import os
 from random import sample
+from typing import List
 from uuid import uuid4
 
 from firebase_admin import firestore
-from google.cloud import secretmanager
-from google.oauth2 import service_account
 import gradio as gr
-from litellm import completion
 
-from credentials import get_credentials_json
 from leaderboard import db
-
-
-
-
-secretmanager_client = secretmanager.SecretManagerServiceClient(
-    credentials=service_account.Credentials.from_service_account_info(
-        get_credentials_json()))
-models_secret = secretmanager_client.access_secret_version(
-    name=secretmanager_client.secret_version_path(GOOGLE_CLOUD_PROJECT,
-                                                  MODELS_SECRET, "latest"))
-decoded_secret = models_secret.payload.data.decode("UTF-8")
-
-supported_models = json.loads(decoded_secret)
+from model import completion
+from model import Model
+from model import supported_models
 
 
 def create_history(model_name: str, instruction: str, prompt: str,
@@ -69,42 +54,35 @@ def get_responses(user_prompt, category, source_lang, target_lang):
       not target_lang):
     raise gr.Error("Please select source and target languages.")
 
-  models = sample(list(supported_models), 2)
+  models: List[Model] = sample(list(supported_models), 2)
   instruction = get_instruction(category, source_lang, target_lang)
 
   responses = []
   for model in models:
-    model_config = supported_models[model]
-
-    model_name = model_config[
-        "provider"] + "/" + model if "provider" in model_config else model
-    api_key = model_config.get("apiKey", None)
-    api_base = model_config.get("apiBase", None)
-
     try:
       # TODO(#1): Allow user to set configuration.
-      response = completion(model=model_name,
-                            api_key=api_key,
-                            api_base=api_base,
+      response = completion(model=model,
                             messages=[{
-                                "content": instruction,
-                                "role": "system"
+                                "role": "system",
+                                "content": instruction
                            }, {
-                                "content": user_prompt,
-                                "role": "user"
+                                "role": "user",
+                                "content": user_prompt
                            }])
-
-      content = response.choices[0].message.content
-      responses.append(content)
+      create_history(model.name, instruction, user_prompt, response)
+      responses.append(response)
 
    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
-      print(f"Error with model {model}: {e}")
+      print(f"Error with model {model.name}: {e}")
      raise gr.Error("Failed to get response. Please try again.")
 
+  model_names = [model.name for model in models]
+
  # It simulates concurrent stream response generation.
  max_response_length = max(len(response) for response in responses)
  for i in range(max_response_length):
-    yield [response[:i + 1] for response in responses] + models + [instruction]
+    yield [response[:i + 1] for response in responses
+          ] + model_names + [instruction]
 
-  yield responses + models + [instruction]
+  yield responses + model_names + [instruction]
```
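For reference, a small sketch of the shape `get_responses` now streams (the responses, model names, and instruction below are invented). Each intermediate yield carries the growing partial responses followed by the model names and the instruction; the final yield carries the complete responses:

```python
# Invented data standing in for two sampled models' outputs.
responses = ["Bonjour.", "Salut !"]
model_names = ["model-a", "model-b"]
instruction = "Translate the following English text into French."

# Mirrors the streaming loop in get_responses: progressively longer
# prefixes of each response, plus the model names and the instruction.
max_response_length = max(len(response) for response in responses)
for i in range(max_response_length):
    print([response[:i + 1] for response in responses] + model_names +
          [instruction])

print(responses + model_names + [instruction])  # final, complete yield
```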