Spaces:
Running
Running
wwwillchen
commited on
Commit
·
89542ff
1
Parent(s):
bb4138d
Completed - part 4
Browse files
claude.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import anthropic
|
2 |
+
from typing import Iterable
|
3 |
+
|
4 |
+
from data_model import ChatMessage, State
|
5 |
+
import mesop as me
|
6 |
+
|
7 |
+
def call_claude_sonnet(input: str, history: list[ChatMessage]) -> Iterable[str]:
|
8 |
+
state = me.state(State)
|
9 |
+
client = anthropic.Anthropic(api_key=state.claude_api_key)
|
10 |
+
messages = [
|
11 |
+
{
|
12 |
+
"role": "assistant" if message.role == "model" else message.role,
|
13 |
+
"content": message.content,
|
14 |
+
}
|
15 |
+
for message in history
|
16 |
+
] + [{"role": "user", "content": input}]
|
17 |
+
|
18 |
+
with client.messages.stream(
|
19 |
+
max_tokens=1024,
|
20 |
+
messages=messages,
|
21 |
+
model="claude-3-sonnet-20240229",
|
22 |
+
) as stream:
|
23 |
+
for text in stream.text_stream:
|
24 |
+
yield text
|
gemini.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import google.generativeai as genai
|
2 |
+
from typing import Iterable
|
3 |
+
|
4 |
+
from data_model import ChatMessage, State
|
5 |
+
import mesop as me
|
6 |
+
|
7 |
+
generation_config = {
|
8 |
+
"temperature": 1,
|
9 |
+
"top_p": 0.95,
|
10 |
+
"top_k": 64,
|
11 |
+
"max_output_tokens": 8192,
|
12 |
+
}
|
13 |
+
|
14 |
+
def configure_gemini():
|
15 |
+
state = me.state(State)
|
16 |
+
genai.configure(api_key=state.gemini_api_key)
|
17 |
+
|
18 |
+
def send_prompt_pro(prompt: str, history: list[ChatMessage]) -> Iterable[str]:
|
19 |
+
configure_gemini()
|
20 |
+
model = genai.GenerativeModel(
|
21 |
+
model_name="gemini-1.5-pro-latest",
|
22 |
+
generation_config=generation_config,
|
23 |
+
)
|
24 |
+
chat_session = model.start_chat(
|
25 |
+
history=[
|
26 |
+
{"role": message.role, "parts": [message.content]} for message in history
|
27 |
+
]
|
28 |
+
)
|
29 |
+
for chunk in chat_session.send_message(prompt, stream=True):
|
30 |
+
yield chunk.text
|
31 |
+
|
32 |
+
def send_prompt_flash(prompt: str, history: list[ChatMessage]) -> Iterable[str]:
|
33 |
+
configure_gemini()
|
34 |
+
model = genai.GenerativeModel(
|
35 |
+
model_name="gemini-1.5-flash-latest",
|
36 |
+
generation_config=generation_config,
|
37 |
+
)
|
38 |
+
chat_session = model.start_chat(
|
39 |
+
history=[
|
40 |
+
{"role": message.role, "parts": [message.content]} for message in history
|
41 |
+
]
|
42 |
+
)
|
43 |
+
for chunk in chat_session.send_message(prompt, stream=True):
|
44 |
+
yield chunk.text
|
main.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
import mesop as me
|
2 |
-
from data_model import State, Models, ModelDialogState
|
3 |
from dialog import dialog, dialog_actions
|
|
|
|
|
4 |
|
5 |
def change_model_option(e: me.CheckboxChangeEvent):
|
6 |
s = me.state(ModelDialogState)
|
@@ -94,6 +96,31 @@ def page():
|
|
94 |
style=me.Style(font_size=20, margin=me.Margin(bottom=24)),
|
95 |
)
|
96 |
chat_input()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
def header():
|
99 |
with me.box(
|
@@ -167,6 +194,32 @@ def on_blur(e: me.InputBlurEvent):
|
|
167 |
|
168 |
def send_prompt(e: me.ClickEvent):
|
169 |
state = me.state(State)
|
170 |
-
|
171 |
-
|
|
|
|
|
172 |
state.input = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import mesop as me
|
2 |
+
from data_model import State, Models, ModelDialogState, Conversation, ChatMessage
|
3 |
from dialog import dialog, dialog_actions
|
4 |
+
import claude
|
5 |
+
import gemini
|
6 |
|
7 |
def change_model_option(e: me.CheckboxChangeEvent):
|
8 |
s = me.state(ModelDialogState)
|
|
|
96 |
style=me.Style(font_size=20, margin=me.Margin(bottom=24)),
|
97 |
)
|
98 |
chat_input()
|
99 |
+
display_conversations()
|
100 |
+
|
101 |
+
def display_conversations():
|
102 |
+
state = me.state(State)
|
103 |
+
for conversation in state.conversations:
|
104 |
+
with me.box(style=me.Style(margin=me.Margin(bottom=24))):
|
105 |
+
me.text(f"Model: {conversation.model}", style=me.Style(font_weight=500))
|
106 |
+
for message in conversation.messages:
|
107 |
+
display_message(message)
|
108 |
+
|
109 |
+
def display_message(message: ChatMessage):
|
110 |
+
style = me.Style(
|
111 |
+
padding=me.Padding.all(12),
|
112 |
+
border_radius=8,
|
113 |
+
margin=me.Margin(bottom=8),
|
114 |
+
)
|
115 |
+
if message.role == "user":
|
116 |
+
style.background = "#e7f2ff"
|
117 |
+
else:
|
118 |
+
style.background = "#ffffff"
|
119 |
+
|
120 |
+
with me.box(style=style):
|
121 |
+
me.markdown(message.content)
|
122 |
+
if message.in_progress:
|
123 |
+
me.progress_spinner()
|
124 |
|
125 |
def header():
|
126 |
with me.box(
|
|
|
194 |
|
195 |
def send_prompt(e: me.ClickEvent):
|
196 |
state = me.state(State)
|
197 |
+
if not state.conversations:
|
198 |
+
for model in state.models:
|
199 |
+
state.conversations.append(Conversation(model=model, messages=[]))
|
200 |
+
input = state.input
|
201 |
state.input = ""
|
202 |
+
|
203 |
+
for conversation in state.conversations:
|
204 |
+
model = conversation.model
|
205 |
+
messages = conversation.messages
|
206 |
+
history = messages[:]
|
207 |
+
messages.append(ChatMessage(role="user", content=input))
|
208 |
+
messages.append(ChatMessage(role="model", in_progress=True))
|
209 |
+
yield
|
210 |
+
|
211 |
+
if model == Models.GEMINI_1_5_FLASH.value:
|
212 |
+
llm_response = gemini.send_prompt_flash(input, history)
|
213 |
+
elif model == Models.GEMINI_1_5_PRO.value:
|
214 |
+
llm_response = gemini.send_prompt_pro(input, history)
|
215 |
+
elif model == Models.CLAUDE_3_5_SONNET.value:
|
216 |
+
llm_response = claude.call_claude_sonnet(input, history)
|
217 |
+
else:
|
218 |
+
raise Exception("Unhandled model", model)
|
219 |
+
|
220 |
+
for chunk in llm_response:
|
221 |
+
messages[-1].content += chunk
|
222 |
+
yield
|
223 |
+
messages[-1].in_progress = False
|
224 |
+
yield
|
225 |
+
|