Spaces:
Build error
Build error
# app.py | |
import spaces | |
from torch.nn import DataParallel | |
from torch import Tensor | |
from transformers import AutoTokenizer, AutoModel | |
from huggingface_hub import InferenceClient | |
from openai import OpenAI | |
from langchain_community.document_loaders import UnstructuredFileLoader | |
from langchain_chroma import Chroma | |
from chromadb import Documents, EmbeddingFunction, Embeddings | |
from chromadb.config import Settings | |
import chromadb #import HttpClient | |
import os | |
import tempfile | |
import re | |
import uuid | |
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
from dotenv import load_dotenv | |
from utils import load_env_variables, parse_and_route, escape_special_characters | |
from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name, metadata_prompt | |
from sentence_transformers import SentenceTransformer | |
# Pull variables from a local .env file into the process environment.
load_dotenv()

# PyTorch / CUDA environment tweaks, applied before this module creates any
# tensors of its own.
os.environ.update(
    {
        'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:30',
        'CUDA_LAUNCH_BLOCKING': '1',
        'CUDA_CACHE_DISABLE': '1',
    }
)

# Prefer the GPU whenever PyTorch reports one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directory used as Gradio's file cache; created up front so it always exists.
temp_dir = '/tmp/gradio/'
os.makedirs(temp_dir, exist_ok=True)
gr.components.file.GRADIO_CACHE = temp_dir
### Utils
# hf_token: passed to `from_pretrained` when loading the embedding model.
# yi_token: API key for the OpenAI-compatible client at API_BASE.
(hf_token, yi_token) = load_env_variables()
def clear_cuda_cache():
    """Ask PyTorch's CUDA caching allocator to release its unused cached blocks."""
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()
# OpenAI-compatible client pointed at API_BASE.
# NOTE(review): `client` is not used elsewhere in this file; `intention_client`
# (created further down with the same arguments) is the one actually used.
client = OpenAI(api_key=yi_token, base_url=API_BASE)

# In-process Chroma client with default settings.
chroma_client = chromadb.Client(Settings())
# Create the collection, or reuse it if it already exists: plain
# `create_collection` raises when the name is taken (e.g. when this module is
# executed again within the same process).
chroma_collection = chroma_client.get_or_create_collection("all-my-documents")
class EmbeddingGenerator:
    """Embed text with a Hugging Face model after LLM-based task routing.

    For every input, two chat completions are requested from
    ``intention_client`` (model ``"yi-large"``):

    * one classifies the input's intention; the result is routed through
      ``parse_and_route`` and mapped to an instruction from ``tasks``;
    * one extracts ``"key": "value"`` metadata pairs.

    The routed instruction is prepended to the text, which is then embedded,
    mean-pooled and L2-normalised.
    """

    # Matches `"key": "value"` pairs in the metadata completion. Whitespace
    # around the colon is optional so compact (`"a":"b"`) and pretty-printed
    # (`"a" : "b"`) JSON both match.
    _METADATA_PATTERN = re.compile(r'"(\w+)"\s*:\s*"([^"]+)"')

    def __init__(self, model_name: str, token: str, intention_client):
        """Load tokenizer and model for ``model_name`` and keep the chat client.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            token: Hugging Face access token.
            intention_client: OpenAI-compatible client used for routing and
                metadata extraction.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
        self.intention_client = intention_client

    def clear_cuda_cache(self):
        """Release cached, unused CUDA memory held by PyTorch's allocator."""
        torch.cuda.empty_cache()

    def _chat(self, system_prompt: str, user_text: str) -> str:
        """Send one system + user exchange to ``yi-large`` and return the reply text."""
        completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": escape_special_characters(system_prompt)},
                {"role": "user", "content": user_text},
            ],
        )
        return completion.choices[0].message.content

    @staticmethod
    def _mean_pool(hidden_states, attention_mask=None):
        """Average token vectors per sequence, ignoring padded positions.

        Assumes ``hidden_states`` is (batch, seq_len, hidden) and
        ``attention_mask`` is (batch, seq_len), the layout produced by the
        tokenizer/model calls in ``compute_embeddings``. Without a mask this is
        a plain mean over dim 1.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * mask).sum(dim=1)
        # clamp guards against division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def compute_embeddings(self, input_text: str) -> tuple[list[list[float]], dict[str, str]]:
        """Embed ``input_text`` and extract its metadata.

        Args:
            input_text: Raw text to embed.

        Returns:
            ``(embeddings, metadata)``: a list holding one normalised embedding
            vector (as Python floats) and the metadata dict parsed from the LLM.
        """
        escaped_input_text = escape_special_characters(input_text)

        # 1. Route the input to a task instruction.
        intention_output = self._chat(intention_prompt, escaped_input_text)
        selected_task = parse_and_route(intention_output)
        if selected_task in tasks:
            task_description = tasks[selected_task]
        else:
            task_description = tasks["DEFAULT"]
            print(f"Selected task not found: {selected_task}")
        # Prepend the routed instruction; previously this prefix was built and
        # then discarded, which made the routing call above have no effect.
        query_prefix = f"Instruct: {task_description}\nQuery: "
        queries = [query_prefix + escaped_input_text]

        # 2. Extract metadata.
        metadata_output = self._chat(metadata_prompt, escaped_input_text)
        metadata = self.extract_metadata(metadata_output)

        # 3. Embed, pool and normalise.
        with torch.no_grad():
            inputs = self.tokenizer(
                queries, return_tensors='pt', padding=True, truncation=True, max_length=4096
            ).to(self.device)
            outputs = self.model(**inputs)
            # Only `last_hidden_state` is read: the former
            # `outputs["sentence_embeddings"]` lookup was overwritten on the
            # next line and raised KeyError for models without that key.
            query_embeddings = self._mean_pool(outputs.last_hidden_state, inputs.get("attention_mask"))
            query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        # .float() makes the conversion safe for half/bfloat16 models too.
        embeddings_list = query_embeddings.detach().float().cpu().tolist()
        self.clear_cuda_cache()
        return embeddings_list, metadata

    def extract_metadata(self, metadata_output: str) -> dict[str, str]:
        """Parse ``"key": "value"`` string pairs out of an LLM reply.

        Later duplicates of a key overwrite earlier ones; non-string values
        are ignored. Returns an empty dict when nothing matches.
        """
        return dict(self._METADATA_PATTERN.findall(metadata_output))
class MyEmbeddingFunction(EmbeddingFunction):
    """Chroma-style embedding function backed by one lazily created EmbeddingGenerator.

    NOTE(review): Chroma's ``EmbeddingFunction`` contract is
    ``__call__(input) -> Embeddings``, but this ``__call__`` returns
    ``(embeddings, metadata)``. Do not hand it to Chroma as a collection's
    embedding function without adapting ``__call__``.
    """

    def __init__(self, model_name: str, token: str, intention_client):
        self.model_name = model_name
        self.token = token
        self.intention_client = intention_client
        # Built on first use and then reused. Constructing an EmbeddingGenerator
        # loads the tokenizer and model via `from_pretrained`, which must not
        # happen on every call.
        self._embedding_generator = None

    def create_embedding_generator(self):
        """Return the shared EmbeddingGenerator, creating it on the first call."""
        if self._embedding_generator is None:
            self._embedding_generator = EmbeddingGenerator(self.model_name, self.token, self.intention_client)
        return self._embedding_generator

    def __call__(self, input: Documents) -> tuple[Embeddings, list[dict]]:
        """Embed each document and collect its extracted metadata.

        Accepts plain strings (Chroma's ``Documents``) as well as objects that
        expose ``page_content`` (e.g. LangChain documents).

        Returns:
            ``(embeddings, metadata)``: the embedding vectors for all inputs,
            and one metadata dict per input document.
        """
        embedding_generator = self.create_embedding_generator()
        embeddings = []
        metadata = []
        for doc in input:
            text = getattr(doc, "page_content", doc)
            vectors, meta = embedding_generator.compute_embeddings(text)
            embeddings.extend(vectors)
            # Keep the dict intact: iterating it (as before) yielded its keys.
            metadata.append(meta)
        return embeddings, metadata
def load_documents(file_path: str, mode: str = "elements"):
    """Parse ``file_path`` with Unstructured and return the text of each loaded document.

    Args:
        file_path: Path of the file to parse.
        mode: Loading mode forwarded to ``UnstructuredFileLoader``.

    Returns:
        A list with the ``page_content`` string of every loaded document.
    """
    loaded = UnstructuredFileLoader(file_path, mode=mode).load()
    return [item.page_content for item in loaded]
def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
    """Wrap ``collection_name`` on the module's ``chroma_client`` in a LangChain ``Chroma`` store.

    NOTE(review): LangChain's ``Chroma`` is documented to take a LangChain
    ``Embeddings`` object (``embed_documents`` / ``embed_query``), while
    ``MyEmbeddingFunction`` is a Chroma-style embedding function. Confirm this
    pairing is intended.
    """
    return Chroma(
        client=chroma_client,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )
def add_documents_to_chroma(documents: list, embedding_function: MyEmbeddingFunction):
    """Embed each document and store it, with its extracted metadata, in ``chroma_collection``.

    Args:
        documents: Plain strings, or objects exposing ``page_content`` (e.g. the
            LangChain documents returned by ``UnstructuredFileLoader.load()``).
        embedding_function: Supplies the EmbeddingGenerator used for embedding.
    """
    # One generator for the whole batch instead of one (and a model load) per document.
    embedding_generator = embedding_function.create_embedding_generator()
    for doc in documents:
        # Chroma stores raw text, so unwrap document objects to their content.
        text = getattr(doc, "page_content", doc)
        embeddings, metadata = embedding_generator.compute_embeddings(text)
        # compute_embeddings returns a list of vectors but ONE metadata dict for
        # the whole text. Attach that dict to every vector; zipping over the
        # dict (as before) paired vectors with its key strings.
        for embedding in embeddings:
            record = {
                "ids": [str(uuid.uuid4())],
                "documents": [text],
                "embeddings": [embedding],
            }
            # Omit empty metadata rather than sending `{}`, which Chroma may
            # reject as an invalid metadata dict.
            if metadata:
                record["metadatas"] = [metadata]
            chroma_collection.add(**record)
def query_chroma(query_text: str, embedding_function: MyEmbeddingFunction, n_results: int = 2):
    """Return the ``n_results`` stored documents closest to ``query_text``.

    The query is embedded with the same EmbeddingGenerator used at ingestion
    time and those vectors are sent to Chroma. Sending ``query_texts`` instead
    (as before) made Chroma embed the query with the collection's default
    embedding function: ``chroma_collection`` is created without one, so those
    vectors come from a different model than the stored ones.

    Args:
        query_text: Text to search for.
        embedding_function: Supplies the EmbeddingGenerator used for embedding.
        n_results: Maximum number of matches to return.

    Returns:
        Chroma's query result mapping; ``result["documents"][0]`` holds the
        matched texts for this (single) query.
    """
    generator = embedding_function.create_embedding_generator()
    query_embeddings, _query_metadata = generator.compute_embeddings(query_text)
    return chroma_collection.query(
        query_embeddings=query_embeddings,
        n_results=n_results,
    )
# Initialize clients
intention_client = OpenAI(
    api_key=yi_token,
    base_url=API_BASE,
)
# NOTE(review): this module-level generator (and the model it loads) is not
# used elsewhere in this file; MyEmbeddingFunction builds its own instance.
embedding_generator = EmbeddingGenerator(model_name, hf_token, intention_client)
embedding_function = MyEmbeddingFunction(model_name, hf_token, intention_client)
# LangChain store over the "Tonic-instruct" collection (not used elsewhere in this file).
chroma_db = initialize_chroma("Tonic-instruct", embedding_function)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat answer grounded in documents retrieved for ``message``.

    Args:
        message: The user's latest message.
        history: Prior ``(user, assistant)`` turns; empty entries are skipped.
        system_message: System prompt for this conversation.
        max_tokens: Maximum number of tokens to generate (coerced to int,
            since it comes from a UI slider).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated response text after each non-empty streamed chunk.
    """
    retrieved_text = query_documents(message)
    messages = [{"role": "system", "content": escape_special_characters(system_message)}]
    for turn in history:
        user_turn, assistant_turn = turn[0], turn[1]
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": f"{retrieved_text}\n\n{escape_special_characters(message)}"})

    # `intention_client` is an OpenAI client, which exposes
    # `chat.completions.create` (it has no `chat_completion` method). The model
    # matches the one used for the other calls in this file.
    stream = intention_client.chat.completions.create(
        model="yi-large",
        messages=messages,
        max_tokens=int(max_tokens),
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    response = ""
    for chunk in stream:  # distinct name so the `message` parameter is not shadowed
        if not chunk.choices:
            continue
        # The delta content can be None (e.g. on the final chunk).
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response
def upload_documents(files):
    """Parse each uploaded file and index its contents in Chroma.

    Args:
        files: Uploads from the ``gr.File`` component. Each is an object
            exposing ``.name`` (its path on disk) or a plain path string;
            ``None`` or empty when nothing was uploaded.

    Returns:
        A status message for the UI.
    """
    if not files:
        return "No documents were uploaded."
    for file in files:
        file_path = getattr(file, "name", file)
        loader = UnstructuredFileLoader(file_path)
        # add_documents_to_chroma embeds and stores raw text, so hand it the
        # page contents rather than LangChain Document objects.
        texts = [doc.page_content for doc in loader.load()]
        add_documents_to_chroma(texts, embedding_function)
    return "Documents uploaded and processed successfully!"
def query_documents(query):
    """Return the texts of the documents most similar to ``query``, separated by blank lines.

    Returns an empty string when nothing matches.
    """
    results = query_chroma(query, embedding_function)
    # Chroma's query result is a mapping with one inner list per query vector.
    # Iterating it directly (as before) yielded key strings, not documents.
    matches = (results.get("documents") or [[]])[0]
    return "\n\n".join(text for text in matches if text)
with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        # `file_types` takes Gradio categories ("file", "image", "video",
        # "audio", "text") or extensions. "document" is neither, so it matched
        # no files; list common document extensions explicitly instead.
        document_upload = gr.File(
            file_count="multiple",
            file_types=[".pdf", ".docx", ".pptx", ".xlsx", ".txt", ".md", ".html", ".csv"],
        )
        upload_button = gr.Button("Upload and Process")
        upload_status = gr.Text()
        upload_button.click(upload_documents, inputs=document_upload, outputs=upload_status)
    with gr.Tab("Ask Questions"):
        with gr.Row():
            chat_interface = gr.ChatInterface(
                respond,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
        # Direct retrieval: show the raw matching documents for a query.
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(query_documents, inputs=query_input, outputs=query_output)
if __name__ == "__main__":
    # Serve the Gradio app. A standalone `chroma run` server was once started
    # here; the app now uses the in-process `chroma_client`.
    demo.launch()