# app.py
"""Gradio app that embeds uploaded documents with a Jina model, stores them in
Chroma, reranks retrieved passages with JinaRerank and answers with yi-large."""
import os
import re
import uuid
from typing import Any, Dict, List, Tuple

import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModel
from openai import OpenAI
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_chroma import Chroma
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
import chromadb
import spaces
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_community.document_compressors.jina_rerank import JinaRerank
from langchain import hub
from langchain.chains import create_retrieval_chain
# `create_stuff_documents_chain` is exported by `combine_documents`; importing
# it from `langchain.chains.retrieval` is expected to raise ImportError.
from langchain.chains.combine_documents import create_stuff_documents_chain

from utils import load_env_variables, parse_and_route, escape_special_characters
from globalvars import API_BASE, intention_prompt, tasks, system_message, metadata_prompt, model_name

load_dotenv()

# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:180'
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['CUDA_CACHE_DISABLE'] = '1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hf_token, yi_token = load_env_variables()

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, trust_remote_code=True)
model = None


@spaces.GPU
def load_model():
    """Load the embedding model once, move it to `device`, and return it.

    The instance is cached in the module-level `model` global so repeated
    calls reuse it.
    """
    global model
    if model is None:
        model = AutoModel.from_pretrained(model_name, token=hf_token, trust_remote_code=True).to(device)
    return model


# Load model
jina_model = load_model()


def clear_cuda_cache():
    """Release cached CUDA memory held by PyTorch's allocator."""
    torch.cuda.empty_cache()


client = OpenAI(api_key=yi_token, base_url=API_BASE)

chroma_client = chromadb.Client(Settings())
# get_or_create keeps start-up working if this module is executed again in a
# process where the collection already exists.
chroma_collection = chroma_client.get_or_create_collection("all-my-documents")


class JinaEmbeddingFunction(EmbeddingFunction):
    """Embeds text with the Jina model and extracts metadata via the chat API.

    For every text, two `yi-large` chat completions are requested: one to pick
    the embedding task (routed through `parse_and_route` and `tasks`) and one
    to produce metadata, which is parsed with a regex.
    """

    # Matches `"key": "value"` pairs in the metadata completion; compiled once.
    _METADATA_PATTERN = re.compile(r'\"(\w+)\": \"([^\"]+)\"')

    def __init__(self, model, tokenizer, intention_client):
        self.model = model
        self.tokenizer = tokenizer
        self.intention_client = intention_client

    def __call__(self, input: Documents) -> Tuple[List[List[float]], List[Dict[str, Any]]]:
        """Embed every text in `input`.

        Returns:
            A `(embeddings, metadata)` pair of parallel lists, one entry per
            input text.
        """
        # NOTE(review): Chroma's EmbeddingFunction protocol is documented as
        # returning embeddings only; this tuple shape is kept for existing
        # callers -- confirm before handing this object to a Chroma collection.
        embeddings_with_metadata = [self.compute_embeddings(doc) for doc in input]
        embeddings = [item[0] for item in embeddings_with_metadata]
        metadata = [item[1] for item in embeddings_with_metadata]
        return embeddings, metadata

    @spaces.GPU
    def compute_embeddings(self, input_text: str):
        """Return `(embedding, metadata)` for a single text.

        `embedding` is the mean-pooled, L2-normalised model output as a list
        of floats; `metadata` is the dict parsed from the metadata completion.
        """
        escaped_input_text = escape_special_characters(input_text)

        # Get the intention
        intention_completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": escape_special_characters(intention_prompt)},
                {"role": "user", "content": escaped_input_text},
            ],
        )
        intention_output = intention_completion.choices[0].message.content
        parsed_task = parse_and_route(intention_output)
        # Unknown routes fall back to the "DEFAULT" entry of `tasks`.
        selected_task = parsed_task if parsed_task in tasks else "DEFAULT"
        task = tasks[selected_task]

        # Get the metadata
        metadata_completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": escape_special_characters(metadata_prompt)},
                {"role": "user", "content": escaped_input_text},
            ],
        )
        metadata_output = metadata_completion.choices[0].message.content
        metadata = self.extract_metadata(metadata_output)

        # Compute embeddings using Jina model
        encoded_input = self.tokenizer(
            escaped_input_text, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            model_output = self.model(**encoded_input, task=task)

        embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings.cpu().numpy().tolist()[0], metadata

    def extract_metadata(self, metadata_output: str) -> Dict[str, str]:
        """Collect every `"key": "value"` pair found in `metadata_output`.

        Later duplicates of a key overwrite earlier ones; text without any
        such pair yields an empty dict.
        """
        return dict(self._METADATA_PATTERN.findall(metadata_output))

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings, weighting each token by `attention_mask`."""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # The clamp avoids division by zero when a row of the mask is all zeros.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )


def load_documents(file_path: str, mode: str = "elements"):
    """Load `file_path` with UnstructuredFileLoader and return the page texts."""
    loader = UnstructuredFileLoader(file_path, mode=mode)
    docs = loader.load()
    return [doc.page_content for doc in docs]


def initialize_chroma(collection_name: str, embedding_function: JinaEmbeddingFunction):
    """Create a LangChain `Chroma` wrapper on the shared `chroma_client`."""
    db = Chroma(client=chroma_client, collection_name=collection_name, embedding_function=embedding_function)
    return db


@spaces.GPU
def add_documents_to_chroma(documents: list, embedding_function: JinaEmbeddingFunction):
    """Embed each text and add it to `chroma_collection` under a fresh UUID."""
    for doc in documents:
        embeddings, metadata = embedding_function.compute_embeddings(doc)
        chroma_collection.add(
            ids=[str(uuid.uuid1())],
            documents=[doc],
            embeddings=[embeddings],
            # Chroma releases that validate metadata reject an empty dict, so
            # omit the field when nothing was extracted.
            metadatas=[metadata] if metadata else None,
        )


@spaces.GPU
def rerank_documents(query: str, documents: List[str]) -> List[str]:
    """Reorder `documents` by relevance to `query` using JinaRerank.

    The passages supplied by the caller are the ones that get reranked (the
    previous code ignored them and searched a separate, never-populated
    collection). JinaRerank's own `top_n` setting limits how many come back.

    Returns:
        The retained passages, most relevant first; `[]` when there is
        nothing to rerank.
    """
    candidates = [doc for doc in documents if doc and doc.strip()]
    if not candidates:
        return []
    compressor = JinaRerank()
    ranked = compressor.rerank(candidates, query)
    return [candidates[item["index"]] for item in ranked]


def query_chroma(query_text: str, embedding_function: JinaEmbeddingFunction):
    """Embed `query_text` and return the 3 nearest entries of `chroma_collection`."""
    query_embeddings, _ = embedding_function.compute_embeddings(query_text)
    result_docs = chroma_collection.query(
        query_embeddings=[query_embeddings],
        n_results=3,
    )
    return result_docs


@spaces.GPU
def answer_query(message: str, chat_history: List[Tuple[str, str]], system_message: str, max_new_tokens: int, temperature: float, top_p: float):
    """Answer `message` with yi-large, grounded on retrieved and reranked text.

    Args:
        message: The user's question.
        chat_history: Prior turns; supplied by `gr.ChatInterface`, not used.
        system_message: System prompt for the completion.
        max_new_tokens: Forwarded as `max_tokens`.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.

    Returns:
        The assistant's reply text. `gr.ChatInterface` expects the reply
        string and maintains the history itself.
    """
    # Query Chroma for relevant documents
    results = query_chroma(message, embedding_function)
    # `results['documents'][0]` is already a list of strings, one per hit.
    retrieved_docs = results["documents"][0] if results["documents"] else []

    # Rerank the documents
    reranked_docs = rerank_documents(message, retrieved_docs)
    reranked_context = "\n\n".join(reranked_docs)

    # Generate response using YI model
    response = client.chat.completions.create(
        model="yi-large",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": f"Context: {reranked_context}\n\nHuman: {message}"},
        ],
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return response.choices[0].message.content


# Initialize clients
intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
embedding_function = JinaEmbeddingFunction(jina_model, tokenizer, intention_client)
chroma_db = initialize_chroma(collection_name="Jina-embeddings", embedding_function=embedding_function)


@spaces.GPU
def upload_documents(files):
    """Load every uploaded file and index its text in Chroma.

    Args:
        files: Value of the `gr.File` component; may be `None` when nothing
            was selected.

    Returns:
        A status message for the UI.
    """
    if not files:
        return "No documents were uploaded."
    for file in files:
        # Accept both objects exposing `.name` and plain path strings.
        path = getattr(file, "name", file)
        # mode="single" matches the loader default this function used before.
        add_documents_to_chroma(load_documents(path, mode="single"), embedding_function)
    return "Documents uploaded and processed successfully!"
@spaces.GPU
def query_documents(query):
    """Retrieve passages for `query`, rerank them, and return them as one
    string with a blank line between passages."""
    hits = query_chroma(query, embedding_function)
    candidates = list(hits['documents'][0])
    ranked = rerank_documents(query, candidates)
    return "\n\n".join(ranked)


# NOTE(review): the nesting below is reconstructed from a whitespace-collapsed
# source; confirm the query widgets belong inside the "Ask Questions" tab.
with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        document_upload = gr.File(file_count="multiple", file_types=["document"])
        upload_button = gr.Button("Upload and Process")
        upload_button.click(
            upload_documents,
            inputs=document_upload,
            outputs=gr.Text(),
        )

    with gr.Tab("Ask Questions"):
        with gr.Row():
            chat_interface = gr.ChatInterface(
                answer_query,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(
            query_documents,
            inputs=query_input,
            outputs=query_output,
        )

if __name__ == "__main__":
    demo.launch()