martin-gorner commited on
Commit
4ef1969
1 Parent(s): ba64cc5

added ability to chat with only the left or right bot

Browse files
Files changed (4) hide show
  1. app.py +154 -93
  2. img/arrowL.png +0 -0
  3. img/arrowR.png +0 -0
  4. img/arrowRL.png +0 -0
app.py CHANGED
@@ -18,6 +18,7 @@ from gradio import ChatMessage
18
  import keras_hub
19
 
20
  from chatstate import ChatState
 
21
  from models import (
22
  model_presets,
23
  load_model,
@@ -26,6 +27,13 @@ from models import (
26
  get_appropriate_chat_template,
27
  )
28
 
 
 
 
 
 
 
 
29
  model_labels_list = list(model_labels)
30
 
31
  # load and warm up (compile) all the models
@@ -41,21 +49,23 @@ for preset in model_presets:
41
 
42
  # For local debugging
43
  # model = keras_hub.models.Llama3CausalLM.from_preset(
44
- # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
 
45
  # )
46
  # models = [model, model, model, model, model]
47
 
48
 
49
- def chat_turn_assistant_1(
50
- model,
51
  message,
 
52
  history,
53
  system_message,
54
- preset,
55
  # max_tokens,
56
  # temperature,
57
  # top_p,
58
  ):
 
 
59
  chat_template = get_appropriate_chat_template(preset)
60
  chat_state = ChatState(model, system_message, chat_template)
61
 
@@ -71,35 +81,25 @@ def chat_turn_assistant_1(
71
  return history
72
 
73
 
74
- def chat_turn_assistant(
75
- message,
76
- sel1,
77
- history1,
78
- sel2,
79
- history2,
80
- system_message,
81
- # max_tokens,
82
- # temperature,
83
- # top_p,
84
  ):
85
- history1 = chat_turn_assistant_1(
86
- models[sel1], message, history1, system_message, model_presets[sel1]
87
- )
88
- history2 = chat_turn_assistant_1(
89
- models[sel2], message, history2, system_message, model_presets[sel2]
90
  )
91
- return "", history1, history2
92
 
93
 
94
- def chat_turn_user_1(message, history):
95
  history.append(ChatMessage(role="user", content=message))
96
  return history
97
 
98
 
99
- def chat_turn_user(message, history1, history2):
100
- history1 = chat_turn_user_1(message, history1)
101
- history2 = chat_turn_user_1(message, history2)
102
- return "", history1, history2
 
103
 
104
 
105
  def bot_icon_select(model_name):
@@ -115,57 +115,72 @@ def bot_icon_select(model_name):
115
  return "img/bot.png"
116
 
117
 
118
- def instantiate_chatbots(sel1, sel2):
119
- model_name1 = model_presets[sel1]
120
- chatbot1 = gr.Chatbot(
121
- type="messages",
122
  show_label=False,
123
- show_share_button=False,
124
- avatar_images=("img/usr.png", bot_icon_select(model_name1)),
 
 
 
 
125
  )
126
- model_name2 = model_presets[sel2]
127
- chatbot2 = gr.Chatbot(
 
 
 
128
  type="messages",
 
129
  show_label=False,
130
  show_share_button=False,
131
- avatar_images=("img/usr.png", bot_icon_select(model_name2)),
132
  )
133
- return chatbot1, chatbot2
134
 
135
 
136
- def instantiate_select_boxes(sel1, sel2, model_labels):
137
- sel1 = gr.Dropdown(
138
- choices=[(name, i) for i, name in enumerate(model_labels)],
139
- show_label=False,
140
- info="<span style='color:black'>Selected model 1:</span> "
141
- + "<a href='"
142
- + preset_to_website_url(model_presets[sel1])
143
- + "'>"
144
- + preset_to_website_url(model_presets[sel1])
145
- + "</a>",
146
- value=sel1,
 
147
  )
148
- sel2 = gr.Dropdown(
149
- choices=[(name, i) for i, name in enumerate(model_labels)],
150
- show_label=False,
151
- info="<span style='color:black'>Selected model 2:</span> "
152
- + "<a href='"
153
- + preset_to_website_url(model_presets[sel2])
154
- + "'>"
155
- + preset_to_website_url(model_presets[sel2])
156
- + "</a>",
157
- value=sel2,
158
- )
159
- return sel1, sel2
160
 
161
 
162
- def instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels):
163
- chatbot1, chatbot2 = instantiate_chatbots(sel1, sel2)
164
- sel1, sel2 = instantiate_select_boxes(sel1, sel2, model_labels)
165
- return sel1, chatbot1, sel2, chatbot2
 
 
 
 
166
 
 
 
 
 
167
 
168
- with gr.Blocks(fill_width=True, title="Keras demo") as demo:
 
 
 
 
 
 
169
 
170
  with gr.Row():
171
  gr.Image(
@@ -189,45 +204,91 @@ with gr.Blocks(fill_width=True, title="Keras demo") as demo:
189
  + "This demo is runnig on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision."
190
  )
191
  with gr.Row():
192
- sel1, sel2 = instantiate_select_boxes(0, 1, model_labels_list)
193
-
194
- with gr.Row():
195
- chatbot1, chatbot2 = instantiate_chatbots(sel1.value, sel2.value)
196
-
197
- msg = gr.Textbox(label="Your message:", submit_btn=True)
198
 
199
  with gr.Row():
200
- gr.ClearButton([msg, chatbot1, chatbot2])
201
- with gr.Accordion("Additional settings", open=False):
202
- system_message = gr.Textbox(
203
- label="Sytem prompt",
204
- value="You are a helpful assistant and your name is Eliza.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  sel1.select(
208
- lambda sel1, sel2: instantiate_chatbots_and_select_boxes(
209
- sel1, sel2, model_labels_list
210
- ),
211
- inputs=[sel1, sel2],
212
- outputs=[sel1, chatbot1, sel2, chatbot2],
 
 
213
  )
214
 
215
  sel2.select(
216
- lambda sel1, sel2: instantiate_chatbots_and_select_boxes(
217
- sel1, sel2, model_labels_list
218
- ),
219
- inputs=[sel1, sel2],
220
- outputs=[sel1, chatbot1, sel2, chatbot2],
221
- )
222
-
223
- msg.submit(
224
- chat_turn_user,
225
- inputs=[msg, chatbot1, chatbot2],
226
- outputs=[msg, chatbot1, chatbot2],
227
  ).then(
228
- chat_turn_assistant,
229
- [msg, sel1, chatbot1, sel2, chatbot2, system_message],
230
- outputs=[msg, chatbot1, chatbot2],
231
  )
232
 
233
 
 
18
  import keras_hub
19
 
20
  from chatstate import ChatState
21
+ from enum import Enum
22
  from models import (
23
  model_presets,
24
  load_model,
 
27
  get_appropriate_chat_template,
28
  )
29
 
30
+
31
+ class TextRoute(Enum):
32
+ LEFT = 0
33
+ RIGHT = 1
34
+ BOTH = 2
35
+
36
+
37
  model_labels_list = list(model_labels)
38
 
39
  # load and warm up (compile) all the models
 
49
 
50
  # For local debugging
51
  # model = keras_hub.models.Llama3CausalLM.from_preset(
52
+ # # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
53
+ # "../misc-code/ari_tiny_llama3"
54
  # )
55
  # models = [model, model, model, model, model]
56
 
57
 
58
+ def chat_turn_assistant(
 
59
  message,
60
+ sel,
61
  history,
62
  system_message,
 
63
  # max_tokens,
64
  # temperature,
65
  # top_p,
66
  ):
67
+ model = models[sel]
68
+ preset = model_presets[sel]
69
  chat_template = get_appropriate_chat_template(preset)
70
  chat_state = ChatState(model, system_message, chat_template)
71
 
 
81
  return history
82
 
83
 
84
+ def chat_turn_both_assistant(
85
+ message, sel1, sel2, history1, history2, system_message
 
 
 
 
 
 
 
 
86
  ):
87
+ return (
88
+ chat_turn_assistant(message, sel1, history1, system_message),
89
+ chat_turn_assistant(message, sel2, history2, system_message),
 
 
90
  )
 
91
 
92
 
93
+ def chat_turn_user(message, history):
94
  history.append(ChatMessage(role="user", content=message))
95
  return history
96
 
97
 
98
+ def chat_turn_both_user(message, history1, history2):
99
+ return (
100
+ chat_turn_user(message, history1),
101
+ chat_turn_user(message, history2),
102
+ )
103
 
104
 
105
  def bot_icon_select(model_name):
 
115
  return "img/bot.png"
116
 
117
 
118
+ def instantiate_select_box(sel, model_labels):
119
+ return gr.Dropdown(
120
+ choices=[(name, i) for i, name in enumerate(model_labels)],
 
121
  show_label=False,
122
+ value=sel,
123
+ info="<span style='color:black'>Selected model:</span> <a href='"
124
+ + preset_to_website_url(model_presets[sel])
125
+ + "'>"
126
+ + preset_to_website_url(model_presets[sel])
127
+ + "</a>",
128
  )
129
+
130
+
131
+ def instantiate_chatbot(sel, key):
132
+ model_name = model_presets[sel]
133
+ return gr.Chatbot(
134
  type="messages",
135
+ key=key,
136
  show_label=False,
137
  show_share_button=False,
138
+ avatar_images=("img/usr.png", bot_icon_select(model_name)),
139
  )
 
140
 
141
 
142
+ def instantiate_arrow_button(route, text_route):
143
+ icons = {
144
+ TextRoute.LEFT: "img/arrowL.png",
145
+ TextRoute.RIGHT: "img/arrowR.png",
146
+ TextRoute.BOTH: "img/arrowRL.png",
147
+ }
148
+ button = gr.Button(
149
+ "",
150
+ size="sm",
151
+ scale=0,
152
+ min_width=40,
153
+ icon=icons[route],
154
  )
155
+ button.click(lambda: route, outputs=[text_route])
156
+ return button
157
+
158
+
159
+ def instantiate_text_box():
160
+ return gr.Textbox(label="Your message:", submit_btn=True, key="msg")
 
 
 
 
 
 
161
 
162
 
163
+ def instantiate_additional_settings():
164
+ with gr.Accordion("Additional settings", open=False):
165
+ system_message = gr.Textbox(
166
+ label="Sytem prompt",
167
+ value="You are a helpful assistant and your name is Eliza.",
168
+ )
169
+ return system_message
170
+
171
 
172
+ sel1 = instantiate_select_box(0, model_labels_list)
173
+ sel2 = instantiate_select_box(1, model_labels_list)
174
+ chatbot1 = instantiate_chatbot(sel1.value, "chat1")
175
+ chatbot2 = instantiate_chatbot(sel2.value, "chat2")
176
 
177
+ # to correctly align the left/right arrows
178
+ CSS = ".elems_justif_right {align-items: end;}"
179
+
180
+ with gr.Blocks(fill_width=True, title="Keras demo", css=CSS) as demo:
181
+
182
+ # Where do messages go
183
+ text_route = gr.State(TextRoute.BOTH)
184
 
185
  with gr.Row():
186
  gr.Image(
 
204
  + "This demo is runnig on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision."
205
  )
206
  with gr.Row():
207
+ sel1.render(),
208
+ sel2.render(),
 
 
 
 
209
 
210
  with gr.Row():
211
+ chatbot1.render()
212
+ chatbot2.render()
213
+
214
+ @gr.render(inputs=text_route)
215
+ def render_text_area(route):
216
+
217
+ if route == TextRoute.BOTH:
218
+ with gr.Row():
219
+ msg = instantiate_text_box()
220
+ with gr.Column(scale=0, min_width=40):
221
+ instantiate_arrow_button(TextRoute.RIGHT, text_route)
222
+ instantiate_arrow_button(TextRoute.LEFT, text_route)
223
+
224
+ elif route == TextRoute.LEFT:
225
+ with gr.Row():
226
+ with gr.Column(scale=1):
227
+ msg = instantiate_text_box()
228
+ with gr.Column(scale=1):
229
+ instantiate_arrow_button(TextRoute.RIGHT, text_route)
230
+ instantiate_arrow_button(TextRoute.BOTH, text_route)
231
+
232
+ elif route == TextRoute.RIGHT:
233
+ with gr.Row():
234
+ with gr.Column(scale=1, elem_classes="elems_justif_right"):
235
+ instantiate_arrow_button(TextRoute.LEFT, text_route)
236
+ instantiate_arrow_button(TextRoute.BOTH, text_route)
237
+ with gr.Column(scale=1):
238
+ msg = instantiate_text_box()
239
+
240
+ with gr.Row():
241
+ clear = gr.ClearButton([msg, chatbot1, chatbot2])
242
+ system_message = instantiate_additional_settings()
243
+
244
+ # Route the submitted message to the left, right or both chatbots
245
+ if route == TextRoute.LEFT:
246
+ submission = msg.submit(
247
+ chat_turn_user, inputs=[msg, chatbot1], outputs=[chatbot1]
248
+ ).then(
249
+ chat_turn_assistant,
250
+ [msg, sel1, chatbot1, system_message],
251
+ outputs=[chatbot1],
252
  )
253
+ elif route == TextRoute.RIGHT:
254
+ submission = msg.submit(
255
+ chat_turn_user, inputs=[msg, chatbot2], outputs=[chatbot2]
256
+ ).then(
257
+ chat_turn_assistant,
258
+ [msg, sel2, chatbot2, system_message],
259
+ outputs=[chatbot2],
260
+ )
261
+ elif route == TextRoute.BOTH:
262
+ submission = msg.submit(
263
+ chat_turn_both_user,
264
+ inputs=[msg, chatbot1, chatbot2],
265
+ outputs=[chatbot1, chatbot2],
266
+ ).then(
267
+ chat_turn_both_assistant,
268
+ [msg, sel1, sel2, chatbot1, chatbot2, system_message],
269
+ outputs=[chatbot1, chatbot2],
270
+ )
271
+ # In all cases reset text box after submission
272
+ submission.then(lambda: "", outputs=msg)
273
 
274
  sel1.select(
275
+ lambda sel: instantiate_chatbot(sel, "chat1"),
276
+ inputs=[sel1],
277
+ outputs=[chatbot1],
278
+ ).then(
279
+ lambda sel: instantiate_select_box(sel, model_labels_list),
280
+ inputs=[sel1],
281
+ outputs=[sel1],
282
  )
283
 
284
  sel2.select(
285
+ lambda sel: instantiate_chatbot(sel, "chat2"),
286
+ inputs=[sel2],
287
+ outputs=[chatbot2],
 
 
 
 
 
 
 
 
288
  ).then(
289
+ lambda sel: instantiate_select_box(sel, model_labels_list),
290
+ inputs=[sel2],
291
+ outputs=[sel2],
292
  )
293
 
294
 
img/arrowL.png ADDED
img/arrowR.png ADDED
img/arrowRL.png ADDED