"""Chain configuration helpers for an Azure OpenAI document-chat app.

Builds retrievers, retrieval chains, a DuckDuckGo search agent, and a plain
conversation chain, with credentials read from Streamlit secrets.
"""

import logging
import os
import tempfile

import streamlit as st
from langchain.agents import AgentType, initialize_agent
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.schema import BaseRetriever, Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.agent_toolkits.load_tools import load_tools
from langchain_community.vectorstores import DocArrayInMemorySearch
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings

from utils import MEMORY, load_document

logging.basicConfig(encoding="utf-8", level=logging.INFO)
LOGGER = logging.getLogger()


def build_llm(temperature=0.1):
    """Create the Azure chat model shared by all chains and agents."""
    return AzureChatOpenAI(
        api_key=st.secrets['key'],
        openai_api_version=st.secrets['chat_version'],
        azure_deployment=st.secrets['chat_name'],
        azure_endpoint=st.secrets['endpoint'],
        temperature=temperature,
    )


def config_retriever(docs: list[Document], use_compression=False, chunk_size=1500):
    """Split documents, embed them into an in-memory vector store, and return
    an MMR retriever, optionally wrapped with an embeddings filter."""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=200
    )
    splits = text_splitter.split_documents(docs)
    embeddings = AzureOpenAIEmbeddings(
        api_key=st.secrets['key'],
        azure_deployment=st.secrets['embedding_name'],
        openai_api_version=st.secrets['embedding_version'],
        azure_endpoint=st.secrets['endpoint'],
    )
    vector_db = DocArrayInMemorySearch.from_documents(splits, embeddings)
    retriever = vector_db.as_retriever(
        search_type='mmr',
        search_kwargs={"k": 5, "fetch_k": 7, "include_metadata": True},
    )
    if not use_compression:
        return retriever
    # Drop retrieved chunks whose embedding similarity to the query falls
    # below the threshold before they reach the LLM.
    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings, similarity_threshold=0.2
    )
    return ContextualCompressionRetriever(
        base_compressor=embeddings_filter, base_retriever=retriever
    )


def config_baseretrieval_chain(retriever: BaseRetriever, temperature=0.1):
    """Build a conversational retrieval chain backed by the shared memory."""
    llm = build_llm(temperature)
    # The chain returns several keys; tell the memory which one to store.
    MEMORY.output_key = 'answer'
    return ConversationalRetrievalChain.from_llm(
        llm=llm, retriever=retriever, memory=MEMORY, verbose=True
    )


def ddg_search_agent(temperature=0.1):
    """Build a zero-shot ReAct agent equipped with DuckDuckGo web search."""
    llm = build_llm(temperature)
    tools = load_tools(tool_names=['ddg-search'], llm=llm)
    return initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True,
    )


def config_retrieval_chain(
    upload_files,
    use_compression=False,
    use_chunksize=1500,
    use_temperature=0.1,
    use_zeroshoot=False,
):
    """Persist uploaded files to a temporary directory, load them, and return
    a retrieval chain over their contents (or a web-search agent instead)."""
    if use_zeroshoot:
        # Web-search mode never touches the uploaded documents, so skip
        # loading, splitting, and embedding them entirely.
        return ddg_search_agent(temperature=use_temperature)

    docs = []
    with tempfile.TemporaryDirectory() as temp_dir:
        for file in upload_files:
            temp_filepath = os.path.join(temp_dir, file.name)
            with open(temp_filepath, "wb") as f:
                f.write(file.getvalue())
            docs.extend(load_document(temp_filepath))

    retriever = config_retriever(
        docs=docs, use_compression=use_compression, chunk_size=use_chunksize
    )
    return config_baseretrieval_chain(retriever=retriever, temperature=use_temperature)


def config_noretrieval_chain(use_temperature=0.1, use_zeroshoot=False):
    """Return a plain conversation chain (or a web-search agent) when no
    documents are available for retrieval."""
    if use_zeroshoot:
        return ddg_search_agent(temperature=use_temperature)
    return ConversationChain(llm=build_llm(use_temperature))
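

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module): one way a
# Streamlit page might drive these helpers. The widget labels and response
# handling below are illustrative; `st.secrets` must contain the keys used
# above ('key', 'endpoint', 'chat_name', 'chat_version', 'embedding_name',
# 'embedding_version'). A real app would cache the chain (e.g. in
# st.session_state) rather than re-embed the documents on every rerun.
if __name__ == "__main__":  # run with `streamlit run <this file>`
    st.title("Document chat demo")
    uploaded_files = st.file_uploader("Upload documents", accept_multiple_files=True)
    user_query = st.text_input("Ask a question")
    if user_query:
        if uploaded_files:
            # Retrieval chain: expects "question", answers under "answer".
            chain = config_retrieval_chain(uploaded_files)
            result = chain.invoke({"question": user_query})
            st.write(result["answer"])
        else:
            # Plain conversation chain: expects "input", answers under "response".
            chain = config_noretrieval_chain()
            result = chain.invoke({"input": user_query})
            st.write(result["response"])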