Abid Ali Awan committed
Commit dbc2cc3 · 1 Parent(s): c84dd0f

improve the text trigger function

Files changed (1)
  1. app.py +14 -7
app.py CHANGED
@@ -1,11 +1,12 @@
 import os
+
 import gradio as gr
+from langchain_chroma import Chroma
 from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnablePassthrough
 from langchain_groq import ChatGroq
 from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_chroma import Chroma
-from langchain_core.prompts import PromptTemplate
 
 # Load the API key from environment variables
 groq_api_key = os.getenv("Groq_API_Key")
@@ -14,8 +15,9 @@ groq_api_key = os.getenv("Groq_API_Key")
 llm = ChatGroq(model="llama-3.1-70b-versatile", api_key=groq_api_key)
 
 # Initialize the embedding model
-embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1",
-                                    model_kwargs = {'device': 'cpu'})
+embed_model = HuggingFaceEmbeddings(
+    model_name="mixedbread-ai/mxbai-embed-large-v1", model_kwargs={"device": "cpu"}
+)
 
 # Load the vector store from a local directory
 vectorstore = Chroma(
@@ -48,6 +50,7 @@ rag_chain = (
     | StrOutputParser()
 )
 
+
 # Define the function to stream the RAG memory
 def rag_memory_stream(text):
     partial_text = ""
@@ -56,6 +59,7 @@ def rag_memory_stream(text):
     # Yield the updated conversation history
     yield partial_text
 
+
 # Set up the Gradio interface
 title = "Real-time AI App with Groq API and LangChain"
 description = """
@@ -68,15 +72,18 @@ demo = gr.Interface(
     title=title,
     description=description,
     fn=rag_memory_stream,
-    inputs="text",
-    outputs="text",
+    inputs=gr.Textbox(
+        label="Enter your Star Wars question:",
+        trigger_mode="always_last",
+        default="Who is luke?",
+    ),
+    outputs=gr.Textbox(label="Answer:", default="...", trigger_mode="auto"),
     live=True,
     batch=True,
     max_batch_size=10000,
     concurrency_limit=12,
     allow_flagging="never",
     theme=gr.themes.Soft(),
-    trigger_mode="always_last",
 )
 
 # Launch the Gradio interface
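
For readers without the full file: the hunk headers only show the edges of the retrieval pipeline (rag_chain) and the streaming helper (rag_memory_stream). A minimal sketch of how the elided middle presumably fits together, assuming the canonical LangChain LCEL pattern over the Chroma store; the retriever setup, prompt wording, and loop body are assumptions, not code from this commit:

# Hypothetical reconstruction of the code the diff truncates; the names
# match app.py, everything else is assumed.
retriever = vectorstore.as_retriever()

prompt = PromptTemplate.from_template(
    "Answer the question using only this context:\n\n{context}\n\nQuestion: {question}"
)

def format_docs(docs):
    # Join the retrieved documents into one context string for the prompt.
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

def rag_memory_stream(text):
    partial_text = ""
    # Stream chunks from the chain and yield the growing answer so the
    # output Textbox updates incrementally.
    for new_text in rag_chain.stream(text):
        partial_text += new_text
        yield partial_text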
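
One caveat on the new keyword arguments: in the Gradio releases I can verify, trigger_mode (with values "once", "multiple", or "always_last") is documented on event listeners such as .change(), not on the gr.Textbox constructor; "auto" is not among the listed values, and a component's initial text is set with value rather than default. A minimal gr.Blocks sketch of the equivalent wiring, under those assumptions:

import gradio as gr

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    question = gr.Textbox(label="Enter your Star Wars question:", value="Who is luke?")
    answer = gr.Textbox(label="Answer:")
    # "always_last" debounces live typing: while a run is in flight, only
    # the most recent pending edit triggers the next call.
    question.change(
        rag_memory_stream,  # the streaming generator defined in app.py
        inputs=question,
        outputs=answer,
        trigger_mode="always_last",
    )

demo.launch()

Either way, "always_last" is the sensible mode for a live textbox: it collapses a burst of keystrokes into one re-run on the latest input instead of queueing a run per character.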