# Real-Time-RAG / app.py
import os
import threading

import gradio as gr
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
# Load the API key from environment variables
groq_api_key = os.getenv("Groq_API_Key")
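# Guard against a missing key: without it, ChatGroq would fail later with a
# less obvious authentication error. The name "Groq_API_Key" must match the
# secret configured for the Space (or the local environment).
if not groq_api_key:
    raise RuntimeError("Groq_API_Key environment variable is not set.")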
# Initialize the language model with the specified model and API key
llm = ChatGroq(model="llama-3.1-70b-versatile", api_key=groq_api_key)
# Initialize the embedding model
embed_model = HuggingFaceEmbeddings(
    model_name="mixedbread-ai/mxbai-embed-large-v1",
    model_kwargs={"device": "cpu"},
)
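# Note: the query-time embedding model must be the same one used to build the
# vector index below; mismatched embeddings make similarity search meaningless.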
# Load the persisted vector store from the local "Starwars_Vectordb" directory
# (passed as persist_directory so Chroma reads the existing index from disk
# rather than treating the name as a new, empty collection)
vectorstore = Chroma(
    persist_directory="Starwars_Vectordb",
    embedding_function=embed_model,
)
# Convert the vector store to a retriever
retriever = vectorstore.as_retriever()
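# as_retriever() defaults to similarity search returning the top 4 chunks;
# pass search_kwargs={"k": ...} to widen or narrow the retrieved context.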
# Define the prompt template for the language model
template = """You are a Star Wars assistant for answering questions.
Use the provided context to answer the question.
If you don't know the answer, say so. Explain your answer in detail.
Do not discuss the context in your response; just provide the answer directly.
Context: {context}
Question: {question}
Answer:"""
rag_prompt = PromptTemplate.from_template(template)
# Create the RAG (Retrieval-Augmented Generation) chain: the retriever fills
# {context} with relevant chunks while RunnablePassthrough forwards the raw
# question; the prompt and LLM then generate the answer, and the parser
# returns it as a plain string
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)
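# Quick sanity check (illustrative question; uncomment to smoke-test the chain
# before launching the UI):
# print(rag_chain.invoke("Who trained Luke Skywalker?"))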
# Global variable to store the current input text
current_text = ""
# Lock to synchronize access to current_text
text_lock = threading.Lock()
# Stream the RAG answer token by token, abandoning generation as soon as a
# newer input arrives (with live=True, Gradio calls this on every input change)
def rag_memory_stream(text):
    global current_text
    with text_lock:
        current_text = text  # Record the latest input
    partial_text = ""
    for new_text in rag_chain.stream(text):
        with text_lock:
            # If the input has changed since this call started, stop generating
            if text != current_text:
                return  # Exit the generator; a fresher call takes over
        partial_text += new_text
        # Yield the answer accumulated so far (the lock is released before
        # yielding so concurrent calls are never blocked mid-stream)
        yield partial_text
# Set up the Gradio interface
title = "Real-time AI App with Groq API and LangChain"
description = """
<center>
<img src="https://huggingface.co/spaces/kingabzpro/Real-Time-RAG/resolve/main/Images/cover.png" alt="logo" width="550"/>
</center>
"""
demo = gr.Interface(
    title=title,
    description=description,
    fn=rag_memory_stream,
    inputs="text",
    outputs="text",
    live=True,  # Re-run the function on every input change
    batch=False,  # Disable batching to handle each input separately
    max_batch_size=1,  # Process inputs one at a time
    concurrency_limit=12,  # Allow up to 12 requests to run concurrently
    allow_flagging="never",
    theme=gr.themes.Soft(),
)
# Enable request queuing (needed for streaming output and the concurrency
# limit above), then launch the Gradio interface
demo.queue()
demo.launch()