|
import hashlib
import logging
import os
import tempfile
from typing import List, Tuple, Dict, Any, Optional

import streamlit as st
import ollama

from langchain_ollama import ChatOllama
from langchain_community.llms import Ollama
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings

import faiss
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore

from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_docs(docs):
    """Join the ``page_content`` of each document, separated by a blank line.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string
            (e.g. documents returned by a retriever).

    Returns:
        A single string with the documents' text in iteration order.
    """
    separator = "\n\n"
    return separator.join(document.page_content for document in docs)
|
|
|
def _model_identifier(model: Dict[str, Any]) -> str:
    """Return one model entry's identifier, preferring the ``"name"`` key.

    NOTE(review): newer ollama Python clients appear to expose the identifier
    under ``"model"`` rather than ``"name"``; fall back to it so both shapes
    work. Confirm against the installed client version.
    """
    try:
        return model["name"]
    except KeyError:
        return model["model"]


@st.cache_resource(show_spinner=True)
def extract_model_names(
    models_info: Dict[str, List[Dict[str, Any]]],
) -> Tuple[str, ...]:
    """
    Extract model names from the provided models information.

    Logging is deliberately not configured here; the app entry point owns
    logging configuration, so calling this helper has no global side effects.

    Args:
        models_info (Dict[str, List[Dict[str, Any]]]): Mapping whose
            ``"models"`` entry is a list of per-model records.

    Returns:
        Tuple[str, ...]: A tuple of model names, in the order listed.

    Raises:
        KeyError: If ``models_info`` has no ``"models"`` entry, or a model
            record has neither a ``"name"`` nor a ``"model"`` key.
    """
    logger = logging.getLogger(__name__)

    logger.info("Extracting model names from models_info")
    model_names = tuple(_model_identifier(model) for model in models_info["models"])
    # Lazy %-formatting: the tuple is only rendered if INFO is enabled.
    logger.info("Extracted model names: %s", model_names)
    return model_names
|
|
|
|
|
def generate_response(rag_chain, input_text):
    """Answer a single user question with the given RAG chain.

    Args:
        rag_chain: Any runnable exposing ``invoke`` (here, the LangChain
            pipeline assembled in ``main``).
        input_text: The user's question, passed through unchanged.

    Returns:
        Whatever ``rag_chain.invoke`` returns.
    """
    return rag_chain.invoke(input_text)
|
|
|
def get_pdf(uploaded_file):
    """Load an uploaded PDF into LangChain documents.

    ``PyPDFLoader`` reads from a filesystem path, so the upload's bytes are
    written to a file inside a private temporary directory. That directory is
    removed once loading finishes. This replaces the shared ``./temp.pdf``,
    which concurrent sessions would overwrite and which was never cleaned up.

    Args:
        uploaded_file: Object exposing ``getvalue()`` that returns the PDF
            bytes (e.g. a Streamlit ``UploadedFile``). ``None`` or any other
            falsy value means nothing was uploaded.

    Returns:
        The list returned by ``PyPDFLoader.load()``, or an empty list when
        ``uploaded_file`` is falsy. Previously this case hit an unbound
        local or returned ``None``.
    """
    if not uploaded_file:
        return []

    with tempfile.TemporaryDirectory() as tmp_dir:
        pdf_path = os.path.join(tmp_dir, "upload.pdf")
        # Close the file before loading so all bytes are flushed to disk.
        with open(pdf_path, "wb") as pdf_file:
            pdf_file.write(uploaded_file.getvalue())
        return PyPDFLoader(pdf_path).load()
|
|
|
def inference(chain, input_query):
    """Invoke the processing chain with the input query.

    Args:
        chain: Any object with an ``invoke`` method.
        input_query: Value forwarded verbatim to ``chain.invoke``.

    Returns:
        The result of ``chain.invoke(input_query)``.
    """
    run = chain.invoke
    output = run(input_query)
    return output
|
|
|
|
|
# Prompt for the RAG chain; {question} and {context} are filled on each request.
_RAG_PROMPT_TEMPLATE = """
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Answer in bullet points. Make sure your answer is relevant to the question and it is answered from the context only.
Question: {question}
Context: {context}
Answer:
"""


@st.cache_resource(show_spinner="Loading embedding model...")
def _load_embeddings():
    """Load the sentence-transformer embedding model once per process."""
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


def _build_retriever(uploaded_file):
    """Chunk, embed and index the uploaded PDF, and return an MMR retriever.

    Returns None when the PDF yields no text chunks, so the caller can warn
    instead of indexing an empty document set.
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_documents(get_pdf(uploaded_file))
    if not chunks:
        return None

    embeddings = _load_embeddings()
    # Embed a probe string once to learn the vector dimension FAISS needs.
    dimension = len(embeddings.embed_query("this is some text data"))
    vector_store = FAISS(
        embedding_function=embeddings,
        index=faiss.IndexFlatL2(dimension),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )
    vector_store.add_documents(documents=chunks)

    # NOTE(review): per the LangChain docs, lambda_mult=1 asks MMR for minimum
    # diversity, so this behaves like plain similarity ranking over fetch_k
    # candidates. Kept as in the original; lower it if diversity is wanted.
    return vector_store.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 3, "fetch_k": 100, "lambda_mult": 1},
    )


def _get_retriever(uploaded_file):
    """Return this session's retriever, rebuilding it only when the PDF changes.

    Streamlit reruns the whole script on every interaction. Keying on a
    content hash avoids re-embedding the same PDF for every question.
    """
    file_key = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    if st.session_state.get("retriever_key") != file_key:
        with st.spinner("Indexing PDF..."):
            st.session_state["retriever"] = _build_retriever(uploaded_file)
        # Set only after a successful build, so a failure is retried next run.
        st.session_state["retriever_key"] = file_key
    return st.session_state["retriever"]


def _build_rag_chain(retriever):
    """Compose retriever -> prompt -> Ollama chat model -> plain-string output."""
    prompt = ChatPromptTemplate.from_template(_RAG_PROMPT_TEMPLATE)
    model = ChatOllama(model="unsloth/Llama-3.2-3B")
    return (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )


def main() -> None:
    """Render the PDF-chat page.

    Steps:
    1. Configure logging and draw the header and the sidebar PDF uploader.
    2. If no PDF is uploaded, show a hint and stop.
    3. Otherwise build (or reuse) a retriever over the PDF and wire it into
       the RAG chain.
    4. Handle the question form and render this session's chat history.
    """
    # Logging is configured here, at the app entry point, rather than in
    # helpers. basicConfig is a no-op once the root logger has handlers.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    st.title("🧠 This is a RAG Chatbot with Ollama and Langchain !!!")
    st.write("The LLM model Llama-3.2 is used")
    st.write("You can upload a PDF to chat with !!!")

    with st.sidebar:
        st.title("PDF FILE UPLOAD:")
        uploaded_file = st.file_uploader(
            "Upload your PDF File and Click on the Submit & Process Button",
            type=["pdf"],
            accept_multiple_files=False,
            key="pdf_uploader",
        )

    # Without an upload there is nothing to index. The original crashed here
    # on the app's first render.
    if uploaded_file is None:
        st.info("Upload a PDF in the sidebar to start chatting.")
        return

    retriever = _get_retriever(uploaded_file)
    if retriever is None:
        st.warning("No text could be extracted from this PDF.")
        return

    rag_chain = _build_rag_chain(retriever)

    with st.form("llm-form"):
        text = st.text_area("Enter your question or statement:")
        submit = st.form_submit_button("Submit")

    if "chat_history" not in st.session_state:
        st.session_state["chat_history"] = []

    if submit and text:
        with st.spinner("Generating response..."):
            response = generate_response(rag_chain, text)
            st.session_state["chat_history"].append({"user": text, "ollama": response})
            st.write(response)

    st.write("## Chat History")
    for chat in reversed(st.session_state["chat_history"]):
        st.write(f"**🧑 User**: {chat['user']}")
        st.write(f"**🧠 Assistant**: {chat['ollama']}")
        st.write("---")
|
|
|
# Script entry point: build and render the app when this file is run directly.
if __name__ == "__main__":
    main()
|
|
|
|