# YiJina / yijinaembed.py
# Author: Tonic
# Commit a6d437d (unverified): add jina embeddings and reranker
# File size at capture: 9.31 kB
# app.py
import os
import re
import uuid
import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from typing import List, Tuple, Dict, Any
from transformers import AutoTokenizer, AutoModel
from openai import OpenAI
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_chroma import Chroma
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
import chromadb
from utils import load_env_variables, parse_and_route, escape_special_characters
from globalvars import API_BASE, intention_prompt, tasks, system_message, metadata_prompt, model_name
import spaces
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_community.document_compressors.jina_rerank import JinaRerank
from langchain import hub
from langchain.chains import create_retrieval_chain
from langchain.chains.retrieval import create_stuff_documents_chain
# Load variables from a local .env file (if present) into os.environ.
# NOTE(review): presumably load_env_variables() below reads these — verify.
load_dotenv()
# Optional CUDA allocator / debugging knobs, currently disabled.
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:180'
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# os.environ['CUDA_CACHE_DISABLE'] = '1'
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# hf_token is passed to from_pretrained; yi_token is the API key for the
# OpenAI-compatible clients pointed at API_BASE.
hf_token, yi_token = load_env_variables()
# Tokenizer for the embedding model named by globalvars.model_name.
# trust_remote_code=True executes code shipped in the model repository.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, trust_remote_code=True)
# Cache slot filled by load_model(); None until the first call.
model = None
@spaces.GPU
def load_model():
    """Return the shared embedding model, loading it on the first call.

    The instance is cached in the module-level ``model`` global, so later
    calls hand back the same object instead of reloading the weights.
    """
    global model
    if model is not None:
        return model
    loaded = AutoModel.from_pretrained(model_name, token=hf_token, trust_remote_code=True)
    model = loaded.to(device)
    return model
# Load (or reuse the cached) embedding model once at import time; this
# instance is what JinaEmbeddingFunction is constructed with further down.
jina_model = load_model()
def clear_cuda_cache():
    """Release unoccupied memory cached by PyTorch's CUDA caching allocator."""
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()
# OpenAI-compatible client pointed at API_BASE; answer_query uses it for the
# final chat completion.
client = OpenAI(api_key=yi_token, base_url=API_BASE)
# With default Settings() this is an in-memory (ephemeral) Chroma client.
chroma_client = chromadb.Client(Settings())
# Collection that uploaded documents are written to and queried from.
# get_or_create_collection is idempotent: create_collection raises when the
# collection already exists, e.g. if this module is executed again in the same
# process (such as a Gradio reload).
chroma_collection = chroma_client.get_or_create_collection("all-my-documents")
class JinaEmbeddingFunction(EmbeddingFunction):
    """Embed text with a local Jina model, routed by a Yi chat model.

    For every input text, two chat completions are requested from
    ``intention_client``:

    1. The first reply is parsed to pick the embedding task, which is looked
       up in ``tasks``.
    2. The second reply is scraped for ``"key": "value"`` metadata pairs.

    The text is then encoded with the local model for that task, then
    mean-pooled and L2-normalised.

    Besides the chromadb-style ``__call__``, the class exposes
    ``embed_documents`` / ``embed_query``. Those are the methods a LangChain
    ``Chroma`` store calls on its ``embedding_function``, so instances can be
    passed to ``Chroma(...)`` as this module does.
    """

    # Matches "key": "value" pairs in the metadata model's reply. Whitespace
    # around the colon is optional, so both compact JSON ("key":"value") and
    # pretty-printed JSON are accepted. Only string values are captured.
    _METADATA_PATTERN = re.compile(r'"(\w+)"\s*:\s*"([^"]+)"')

    def __init__(self, model, tokenizer, intention_client):
        self.model = model
        self.tokenizer = tokenizer
        self.intention_client = intention_client

    def __call__(self, input: Documents) -> Tuple[List[List[float]], List[Dict[str, Any]]]:
        """Return ``(embeddings, metadatas)`` for every text in *input*.

        NOTE(review): chromadb's EmbeddingFunction contract expects only the
        list of embeddings. This returns a tuple, so it must not be attached
        to a chromadb collection as its embedding function without adapting.
        """
        embeddings_with_metadata = [self.compute_embeddings(doc) for doc in input]
        embeddings = [item[0] for item in embeddings_with_metadata]
        metadata = [item[1] for item in embeddings_with_metadata]
        return embeddings, metadata

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """LangChain ``Embeddings`` hook: embed each text, discarding metadata."""
        return [self.compute_embeddings(text)[0] for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """LangChain ``Embeddings`` hook: embed a single query string."""
        return self.compute_embeddings(text)[0]

    def _chat(self, system_prompt: str, user_text: str) -> str:
        """Send one system + user exchange to yi-large and return the reply text."""
        completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": escape_special_characters(system_prompt)},
                {"role": "user", "content": user_text},
            ],
        )
        return completion.choices[0].message.content

    @spaces.GPU
    def compute_embeddings(self, input_text: str):
        """Return ``(embedding, metadata)`` for a single text.

        The embedding is computed from the *escaped* text (see
        ``escape_special_characters``), not the raw input. It is returned as a
        plain list of floats. ``metadata`` is the dict produced by
        ``extract_metadata`` and may be empty.
        """
        escaped_input_text = escape_special_characters(input_text)

        # Ask the chat model which task this text corresponds to. Any answer
        # that is not a key of ``tasks`` falls back to "DEFAULT".
        parsed_task = parse_and_route(self._chat(intention_prompt, escaped_input_text))
        selected_task = parsed_task if parsed_task in tasks else "DEFAULT"
        # NOTE(review): presumably tasks maps names to the model's task/adapter
        # identifiers — confirm in globalvars.
        task = tasks[selected_task]

        # Ask the chat model for descriptive metadata about the text.
        metadata = self.extract_metadata(self._chat(metadata_prompt, escaped_input_text))

        encoded_input = self.tokenizer(
            escaped_input_text, padding=True, truncation=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            model_output = self.model(**encoded_input, task=task)
        embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        # One input text -> a batch of one row; return that row as a list.
        return embeddings[0].cpu().tolist(), metadata

    def extract_metadata(self, metadata_output: str) -> Dict[str, str]:
        """Collect ``"key": "value"`` string pairs from the model's reply.

        A later occurrence of a key overwrites an earlier one. Returns an empty
        dict when nothing matches, or when the reply is empty/None (the chat
        API's message content is optional).
        """
        if not metadata_output:
            return {}
        return dict(self._METADATA_PATTERN.findall(metadata_output))

    @staticmethod
    def mean_pooling(model_output, attention_mask):
        """Average token embeddings over positions where attention_mask is 1.

        Padding positions are zeroed by the mask before summing. The clamp
        keeps the division safe if a row has no unmasked tokens.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
def load_documents(file_path: str, mode: str = "elements"):
    """Load *file_path* with Unstructured and return each document's text.

    ``mode`` is forwarded unchanged to ``UnstructuredFileLoader``.
    """
    loaded = UnstructuredFileLoader(file_path, mode=mode).load()
    texts = [item.page_content for item in loaded]
    return texts
def initialize_chroma(collection_name: str, embedding_function: JinaEmbeddingFunction):
    """Wrap *collection_name* on the shared ``chroma_client`` in a LangChain Chroma store."""
    return Chroma(
        client=chroma_client,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
@spaces.GPU
def add_documents_to_chroma(documents: list, embedding_function: JinaEmbeddingFunction):
    """Embed each text in *documents* and store it in ``chroma_collection``.

    Each document is added under a fresh random id, together with its
    embedding and, when any was extracted, its metadata.

    Args:
        documents: Raw document texts.
        embedding_function: Object whose ``compute_embeddings(text)``
            returns ``(embedding, metadata)``.
    """
    for doc in documents:
        embeddings, metadata = embedding_function.compute_embeddings(doc)
        record = {
            # uuid4 is random; uuid1 would embed the host MAC and a timestamp.
            "ids": [str(uuid.uuid4())],
            "documents": [doc],
            "embeddings": [embeddings],
        }
        # The metadata model may yield no parsable pairs, and chromadb versions
        # reject an empty metadata dict ("Expected metadata to be a non-empty
        # dict"). Attach metadata only when there is some.
        if metadata:
            record["metadatas"] = [metadata]
        chroma_collection.add(**record)
@spaces.GPU
def rerank_documents(query: str, documents: List[str]) -> List[str]:
    """Rerank *documents* against *query* with the Jina reranker.

    Args:
        query: The user query to score documents against.
        documents: Candidate texts, e.g. the texts returned by query_chroma.

    Returns:
        The best-scoring texts, most relevant first. The count is limited by
        JinaRerank's configured ``top_n``. An empty input returns an empty
        list without calling the reranker API.
    """
    # Rerank the candidates we were given. Re-retrieving from the LangChain
    # store would search the "Jina-embeddings" collection, while uploads are
    # written to chroma_collection ("all-my-documents").
    if not documents:
        return []
    compressor = JinaRerank()
    ranked = compressor.rerank(documents, query)
    # Each result carries the index of the document it scored.
    return [documents[result["index"]] for result in ranked]
def query_chroma(query_text: str, embedding_function: JinaEmbeddingFunction, n_results: int = 3):
    """Embed *query_text* and return its nearest neighbours in ``chroma_collection``.

    Args:
        query_text: Text to search for.
        embedding_function: Used to embed the query. The metadata it also
            produces is not used here.
        n_results: Number of neighbours to request (default 3).

    Returns:
        The raw Chroma query result. Because a single query embedding is
        sent, ``result["documents"][0]`` holds the matched texts.
    """
    query_embedding, _ = embedding_function.compute_embeddings(query_text)
    return chroma_collection.query(
        query_embeddings=[query_embedding],
        n_results=n_results,
    )
@spaces.GPU
def answer_query(message: str, chat_history: List[Tuple[str, str]], system_message: str, max_new_tokens: int, temperature: float, top_p: float):
    """Answer *message* with yi-large, grounded in reranked Chroma results.

    This is the ``fn`` of the ``gr.ChatInterface`` in the UI. ChatInterface
    calls it with (message, history, *additional_inputs), keeps the history
    itself, and expects the reply text as the return value.

    Args:
        message: The user's question.
        chat_history: Prior turns supplied by ChatInterface. They are not
            used to build the prompt and are not modified.
        system_message: System prompt for the chat model.
        max_new_tokens: Forwarded as ``max_tokens``.
        temperature: Sampling temperature forwarded to the request.
        top_p: Nucleus-sampling value forwarded to the request.

    Returns:
        The assistant's reply text ("" if the API returned no content).
    """
    results = query_chroma(message, embedding_function)
    # Chroma returns one list of plain text strings per query embedding, and
    # exactly one was sent.
    retrieved_docs = results["documents"][0]
    # Hand the texts to the reranker as-is. Joining on "\n\n" and splitting
    # again would break any document containing a blank line into fragments.
    reranked_context = "\n\n".join(rerank_documents(message, retrieved_docs))

    response = client.chat.completions.create(
        model="yi-large",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": f"Context: {reranked_context}\n\nHuman: {message}"},
        ],
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return response.choices[0].message.content or ""
# Initialize clients
# A second OpenAI-compatible client (same key and base URL as `client`) that
# JinaEmbeddingFunction uses for its intention and metadata prompts.
intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
# Shared embedding function used for both indexing and querying.
embedding_function = JinaEmbeddingFunction(jina_model, tokenizer, intention_client)
# LangChain Chroma store over the "Jina-embeddings" collection.
# NOTE(review): add_documents_to_chroma writes to chroma_collection
# ("all-my-documents"), not to this store — confirm which one is intended.
chroma_db = initialize_chroma(collection_name="Jina-embeddings", embedding_function=embedding_function)
@spaces.GPU
def upload_documents(files):
    """Load each uploaded file with Unstructured and index its text in Chroma.

    Args:
        files: Value of the multi-file ``gr.File`` component, or ``None``
            when the button is clicked with nothing uploaded. Items are either
            plain path strings or objects exposing the path as ``.name``,
            depending on the Gradio version.

    Returns:
        A status message shown in the UI.
    """
    if not files:
        return "No files uploaded."
    for file in files:
        path = getattr(file, "name", file)
        documents = UnstructuredFileLoader(path).load()
        add_documents_to_chroma([doc.page_content for doc in documents], embedding_function)
    return "Documents uploaded and processed successfully!"
@spaces.GPU
def query_documents(query):
    """Retrieve documents for *query*, rerank them, and join them for display."""
    hits = query_chroma(query, embedding_function)
    candidates = list(hits["documents"][0])
    best = rerank_documents(query, candidates)
    return "\n\n".join(best)
# Two-tab UI: one tab uploads and indexes documents, the other chats with or
# queries the indexed documents.
with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        # Gradio's file_types accepts "file", "image", "video", "audio", "text"
        # or explicit extensions. "document" is not one of these categories,
        # so the document formats are listed explicitly.
        document_upload = gr.File(
            file_count="multiple",
            file_types=["text", ".txt", ".md", ".html", ".pdf", ".doc", ".docx", ".ppt", ".pptx"],
        )
        upload_button = gr.Button("Upload and Process")
        upload_status = gr.Text()
        upload_button.click(upload_documents, inputs=document_upload, outputs=upload_status)
    with gr.Tab("Ask Questions"):
        with gr.Row():
            # The additional inputs are passed to answer_query after
            # (message, history), in this order.
            chat_interface = gr.ChatInterface(
                answer_query,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
        # Plain retrieval + rerank, without generating an answer.
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(query_documents, inputs=query_input, outputs=query_output)
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()