wwwillchen committed on
Commit 89542ff · 1 Parent(s): bb4138d

Completed - part 4

Files changed (3)
  1. claude.py +24 -0
  2. gemini.py +44 -0
  3. main.py +56 -3
claude.py ADDED
@@ -0,0 +1,24 @@
+ import anthropic
+ from typing import Iterable
+
+ from data_model import ChatMessage, State
+ import mesop as me
+
+ def call_claude_sonnet(input: str, history: list[ChatMessage]) -> Iterable[str]:
+   state = me.state(State)
+   client = anthropic.Anthropic(api_key=state.claude_api_key)
+   messages = [
+     {
+       "role": "assistant" if message.role == "model" else message.role,
+       "content": message.content,
+     }
+     for message in history
+   ] + [{"role": "user", "content": input}]
+
+   with client.messages.stream(
+     max_tokens=1024,
+     messages=messages,
+     model="claude-3-sonnet-20240229",
+   ) as stream:
+     for text in stream.text_stream:
+       yield text
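
The list comprehension above remaps this app's "model" role to the "assistant" role that the Anthropic Messages API expects. For a two-turn history (values hypothetical, not part of this commit), the payload works out to:

history = [
  ChatMessage(role="user", content="Hi"),
  ChatMessage(role="model", content="Hello!"),
]
# messages passed to client.messages.stream:
# [{"role": "user", "content": "Hi"},
#  {"role": "assistant", "content": "Hello!"},  # "model" remapped
#  {"role": "user", "content": input}]

One thing to flag: main.py routes Models.CLAUDE_3_5_SONNET to this function, but the model string pinned here is claude-3-sonnet-20240229 (Claude 3 Sonnet, not 3.5); if 3.5 Sonnet is intended, the model id should be updated to match.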
gemini.py ADDED
@@ -0,0 +1,44 @@
+ import google.generativeai as genai
+ from typing import Iterable
+
+ from data_model import ChatMessage, State
+ import mesop as me
+
+ generation_config = {
+   "temperature": 1,
+   "top_p": 0.95,
+   "top_k": 64,
+   "max_output_tokens": 8192,
+ }
+
+ def configure_gemini():
+   state = me.state(State)
+   genai.configure(api_key=state.gemini_api_key)
+
+ def send_prompt_pro(prompt: str, history: list[ChatMessage]) -> Iterable[str]:
+   configure_gemini()
+   model = genai.GenerativeModel(
+     model_name="gemini-1.5-pro-latest",
+     generation_config=generation_config,
+   )
+   chat_session = model.start_chat(
+     history=[
+       {"role": message.role, "parts": [message.content]} for message in history
+     ]
+   )
+   for chunk in chat_session.send_message(prompt, stream=True):
+     yield chunk.text
+
+ def send_prompt_flash(prompt: str, history: list[ChatMessage]) -> Iterable[str]:
+   configure_gemini()
+   model = genai.GenerativeModel(
+     model_name="gemini-1.5-flash-latest",
+     generation_config=generation_config,
+   )
+   chat_session = model.start_chat(
+     history=[
+       {"role": message.role, "parts": [message.content]} for message in history
+     ]
+   )
+   for chunk in chat_session.send_message(prompt, stream=True):
+     yield chunk.text
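
Unlike the Anthropic client, the Gemini API natively uses the "user"/"model" role names, so the history passes through unchanged. Also, send_prompt_pro and send_prompt_flash are identical except for model_name; a possible consolidation (a sketch, not part of this commit; _send_prompt is a hypothetical helper):

def _send_prompt(model_name: str, prompt: str, history: list[ChatMessage]) -> Iterable[str]:
  # Shared streaming path for both Gemini variants.
  configure_gemini()
  model = genai.GenerativeModel(
    model_name=model_name,
    generation_config=generation_config,
  )
  chat_session = model.start_chat(
    history=[
      {"role": message.role, "parts": [message.content]} for message in history
    ]
  )
  for chunk in chat_session.send_message(prompt, stream=True):
    yield chunk.text

def send_prompt_pro(prompt: str, history: list[ChatMessage]) -> Iterable[str]:
  return _send_prompt("gemini-1.5-pro-latest", prompt, history)

def send_prompt_flash(prompt: str, history: list[ChatMessage]) -> Iterable[str]:
  return _send_prompt("gemini-1.5-flash-latest", prompt, history)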
main.py CHANGED
@@ -1,6 +1,8 @@
  import mesop as me
- from data_model import State, Models, ModelDialogState
+ from data_model import State, Models, ModelDialogState, Conversation, ChatMessage
  from dialog import dialog, dialog_actions
+ import claude
+ import gemini

  def change_model_option(e: me.CheckboxChangeEvent):
    s = me.state(ModelDialogState)
@@ -94,6 +96,31 @@ def page():
      style=me.Style(font_size=20, margin=me.Margin(bottom=24)),
    )
    chat_input()
+   display_conversations()
+
+ def display_conversations():
+   state = me.state(State)
+   for conversation in state.conversations:
+     with me.box(style=me.Style(margin=me.Margin(bottom=24))):
+       me.text(f"Model: {conversation.model}", style=me.Style(font_weight=500))
+       for message in conversation.messages:
+         display_message(message)
+
+ def display_message(message: ChatMessage):
+   style = me.Style(
+     padding=me.Padding.all(12),
+     border_radius=8,
+     margin=me.Margin(bottom=8),
+   )
+   if message.role == "user":
+     style.background = "#e7f2ff"
+   else:
+     style.background = "#ffffff"
+
+   with me.box(style=style):
+     me.markdown(message.content)
+     if message.in_progress:
+       me.progress_spinner()

  def header():
    with me.box(
@@ -167,6 +194,32 @@ def on_blur(e: me.InputBlurEvent):

  def send_prompt(e: me.ClickEvent):
    state = me.state(State)
-   print(f"Sending prompt: {state.input}")
-   print(f"Selected models: {state.models}")
+   if not state.conversations:
+     for model in state.models:
+       state.conversations.append(Conversation(model=model, messages=[]))
+   input = state.input
    state.input = ""
+
+   for conversation in state.conversations:
+     model = conversation.model
+     messages = conversation.messages
+     history = messages[:]
+     messages.append(ChatMessage(role="user", content=input))
+     messages.append(ChatMessage(role="model", in_progress=True))
+     yield
+
+     if model == Models.GEMINI_1_5_FLASH.value:
+       llm_response = gemini.send_prompt_flash(input, history)
+     elif model == Models.GEMINI_1_5_PRO.value:
+       llm_response = gemini.send_prompt_pro(input, history)
+     elif model == Models.CLAUDE_3_5_SONNET.value:
+       llm_response = claude.call_claude_sonnet(input, history)
+     else:
+       raise Exception("Unhandled model", model)
+
+     for chunk in llm_response:
+       messages[-1].content += chunk
+       yield
+     messages[-1].in_progress = False
+     yield
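
Two details in send_prompt are easy to miss. First, history = messages[:] snapshots the conversation before the new turns are appended; the claude and gemini modules append the fresh prompt themselves, so passing messages directly would send the user's input twice. Second, send_prompt is a Python generator used as a Mesop event handler: every bare yield hands control back to Mesop to re-render, so the spinner appears immediately and the reply grows chunk by chunk. A stripped-down sketch of that pattern (the output field is hypothetical, not part of this app's State):

def streaming_handler(e: me.ClickEvent):
  state = me.state(State)
  for chunk in ["Hel", "lo, ", "world"]:  # stand-in for an LLM stream
    state.output += chunk                 # hypothetical str field on State
    yield                                 # Mesop re-renders the partial text
  yield                                   # final re-render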