Abid Ali Awan commited on
Commit
ed71130
·
1 Parent(s): dbc2cc3

fixing threading

Browse files
Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import os
2
-
3
  import gradio as gr
4
- from langchain_chroma import Chroma
5
  from langchain_core.output_parsers import StrOutputParser
6
- from langchain_core.prompts import PromptTemplate
7
  from langchain_core.runnables import RunnablePassthrough
8
  from langchain_groq import ChatGroq
9
  from langchain_huggingface import HuggingFaceEmbeddings
 
 
 
10
 
11
  # Load the API key from environment variables
12
  groq_api_key = os.getenv("Groq_API_Key")
@@ -50,11 +50,24 @@ rag_chain = (
50
  | StrOutputParser()
51
  )
52
 
 
 
 
 
 
 
53
 
54
  # Define the function to stream the RAG memory
55
  def rag_memory_stream(text):
 
 
 
56
  partial_text = ""
57
  for new_text in rag_chain.stream(text):
 
 
 
 
58
  partial_text += new_text
59
  # Yield the updated conversation history
60
  yield partial_text
@@ -72,15 +85,11 @@ demo = gr.Interface(
72
  title=title,
73
  description=description,
74
  fn=rag_memory_stream,
75
- inputs=gr.Textbox(
76
- label="Enter your Star Wars question:",
77
- trigger_mode="always_last",
78
- default="Who is luke?",
79
- ),
80
- outputs=gr.Textbox(label="Awnser:", default="...", trigger_mode="auto"),
81
  live=True,
82
- batch=True,
83
- max_batch_size=10000,
84
  concurrency_limit=12,
85
  allow_flagging="never",
86
  theme=gr.themes.Soft(),
 
1
  import os
 
2
  import gradio as gr
 
3
  from langchain_core.output_parsers import StrOutputParser
 
4
  from langchain_core.runnables import RunnablePassthrough
5
  from langchain_groq import ChatGroq
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
+ from langchain_chroma import Chroma
8
+ from langchain_core.prompts import PromptTemplate
9
+ import threading
10
 
11
  # Load the API key from environment variables
12
  groq_api_key = os.getenv("Groq_API_Key")
 
50
  | StrOutputParser()
51
  )
52
 
53
+ # Global variable to store the current input text
54
+ current_text = ""
55
+
56
+ # Lock to synchronize access to current_text
57
+ text_lock = threading.Lock()
58
+
59
 
60
  # Define the function to stream the RAG memory
61
  def rag_memory_stream(text):
62
+ global current_text
63
+ with text_lock:
64
+ current_text = text # Update the current text input
65
  partial_text = ""
66
  for new_text in rag_chain.stream(text):
67
+ with text_lock:
68
+ # If the input text has changed, reset the generation
69
+ if text != current_text:
70
+ return # Exit the generator if new input is detected
71
  partial_text += new_text
72
  # Yield the updated conversation history
73
  yield partial_text
 
85
  title=title,
86
  description=description,
87
  fn=rag_memory_stream,
88
+ inputs="text",
89
+ outputs="text",
 
 
 
 
90
  live=True,
91
+ batch=False, # Disable batching to handle each input separately
92
+ max_batch_size=1, # Set batch size to 1 to process inputs one by one
93
  concurrency_limit=12,
94
  allow_flagging="never",
95
  theme=gr.themes.Soft(),