vpcom commited on
Commit
3842ab5
·
1 Parent(s): 9ffbb95

feat: overwrite the chatbot class

Browse files
Files changed (1) hide show
  1. app.py +95 -2
app.py CHANGED
@@ -230,7 +230,101 @@ CSS = """
230
  p {direction: rtl; white-space: pre-line;}
231
  """
232
 
233
- chatbot = gr.Chatbot(label="PersianGPT",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  rtl=True,
235
  show_share_button=True,
236
  show_copy_button=True,
@@ -337,7 +431,6 @@ def vote(data: gr.LikeData):
337
  ]
338
  jsonlfile.write("\n".join(json_data) + "\n")
339
 
340
-
341
  def save_whotheyare(x):
342
  global USERNAME
343
  global NAME
 
230
  p {direction: rtl; white-space: pre-line;}
231
  """
232
 
233
+ class Chatbot(gr.Chatbot):
234
+ def _preprocess_chat_messages(
235
+ self, chat_message: str | dict | None
236
+ ) -> str | tuple[str] | tuple[str, str] | None:
237
+ if chat_message is None:
238
+ return None
239
+ elif isinstance(chat_message, dict):
240
+ if chat_message["alt_text"] is not None:
241
+ return (chat_message["name"], chat_message["alt_text"])
242
+ else:
243
+ return (chat_message["name"],)
244
+ else: # string
245
+ return chat_message
246
+
247
+ def preprocess(
248
+ self,
249
+ y: list[list[str | dict | None] | tuple[str | dict | None, str | dict | None]],
250
+ ) -> list[list[str | tuple[str] | tuple[str, str] | None]]:
251
+ if y is None:
252
+ return y
253
+ processed_messages = []
254
+ for message_pair in y:
255
+ if not isinstance(message_pair, (tuple, list)):
256
+ raise TypeError(
257
+ f"Expected a list of lists or list of tuples. Received: {message_pair}"
258
+ )
259
+ if len(message_pair) != 2:
260
+ raise TypeError(
261
+ f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
262
+ )
263
+ processed_messages.append(
264
+ [
265
+ self._preprocess_chat_messages(message_pair[0]),
266
+ self._preprocess_chat_messages(message_pair[1]),
267
+ ]
268
+ )
269
+ return processed_messages
270
+
271
+ def _postprocess_chat_messages(
272
+ self, chat_message: str | tuple | list | None
273
+ ) -> str | dict | None:
274
+ if chat_message is None:
275
+ return None
276
+ elif isinstance(chat_message, (tuple, list)):
277
+ file_uri = str(chat_message[0])
278
+ if utils.validate_url(file_uri):
279
+ filepath = file_uri
280
+ else:
281
+ filepath = self.make_temp_copy_if_needed(file_uri)
282
+
283
+ mime_type = client_utils.get_mimetype(filepath)
284
+ return {
285
+ "name": filepath,
286
+ "mime_type": mime_type,
287
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
288
+ "data": None, # These last two fields are filled in by the frontend
289
+ "is_file": True,
290
+ }
291
+ elif isinstance(chat_message, str):
292
+ chat_message = inspect.cleandoc(chat_message)
293
+ return chat_message
294
+ else:
295
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
296
+
297
+ def postprocess(
298
+ self,
299
+ y: list[list[str | tuple[str] | tuple[str, str] | None] | tuple],
300
+ ) -> list[list[str | dict | None]]:
301
+ """
302
+ Parameters:
303
+ y: List of lists representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string or pathlib.Path filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
304
+ Returns:
305
+ List of lists representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. Or None if the message is not to be displayed.
306
+ """
307
+ if y is None:
308
+ return []
309
+ processed_messages = []
310
+ for message_pair in y:
311
+ if not isinstance(message_pair, (tuple, list)):
312
+ raise TypeError(
313
+ f"Expected a list of lists or list of tuples. Received: {message_pair}"
314
+ )
315
+ if len(message_pair) != 2:
316
+ raise TypeError(
317
+ f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
318
+ )
319
+ processed_messages.append(
320
+ [
321
+ self._postprocess_chat_messages(message_pair[0]),
322
+ self._postprocess_chat_messages(message_pair[1]),
323
+ ]
324
+ )
325
+ return processed_messages
326
+
327
+ chatbot = Chatbot(label="PersianGPT",
328
  rtl=True,
329
  show_share_button=True,
330
  show_copy_button=True,
 
431
  ]
432
  jsonlfile.write("\n".join(json_data) + "\n")
433
 
 
434
  def save_whotheyare(x):
435
  global USERNAME
436
  global NAME