OCR-Correction / app.py
Pclanglais's picture
Update app.py
208476f verified
raw
history blame
4 kB
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import json
import os
import shutil
import requests
import numpy as np
import pandas as pd
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from FlagEmbedding import BGEM3FlagModel
from sklearn.metrics.pairwise import cosine_similarity
# Define the device once; the generation model below is moved onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Retrieval side: BGE-M3 encoder plus the precomputed document embeddings and
# their texts. It gets its own name so that loading the generation model below
# does not shadow it (vector_search needs its BGE-M3 `.encode`).
embedding_model = BGEM3FlagModel(
    'BAAI/bge-m3',
    use_fp16=True,  # Setting use_fp16 to True speeds up computation with a slight performance degradation
)
# NOTE(review): assumes one row per document in the .npy file, aligned by index
# with the "text_with_context" column of the JSON file — confirm.
embeddings = np.load("embeddings_tchap.npy")
embeddings_data = pd.read_json("embeddings_tchap.json")
embeddings_text = embeddings_data["text_with_context"].tolist()

# Generation settings.
# NOTE(review): predict() passes its own hard-coded sampling values and does not
# read these four names; confirm which set is intended.
temperature = 0.2
max_new_tokens = 1000
top_p = 0.92
repetition_penalty = 1.7

# Generation side: the chat model that writes the answers.
model_name = "Pclanglais/Tchap"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to(device)
system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"


# Vector search over the database
def vector_search(sentence_query):
    """Return the stored text whose embedding is closest to the query.

    The query is encoded with the BGE-M3 model, compared with every row of
    ``embeddings`` by cosine similarity, and the entry of ``embeddings_text``
    at the best-scoring row index is returned.

    Args:
        sentence_query: The user's question.

    Returns:
        The text of the most similar stored document.
    """
    query_embedding = embedding_model.encode(
        sentence_query,
        batch_size=12,
        max_length=256,  # If you don't need such a long length, you can set a smaller value to speed up the encoding process.
    )['dense_vecs']
    # cosine_similarity expects 2-D input: a single row for the single query.
    similarities = cosine_similarity(query_embedding.reshape(1, -1), embeddings)
    # Only one query row, so the argmax over row 0 is the best document index.
    closest_doc_index = int(np.argmax(similarities[0]))
    return embeddings_text[closest_doc_index]
class StopOnTokens(StoppingCriteria):
    """Stopping criterion: halt when the last generated token is a stop id.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, so
    this is intended for single-prompt generation as done in ``predict``.

    Args:
        stop_ids: Token ids that end generation. Defaults to the ids the
            original implementation hard-coded.
    """

    # NOTE(review): what ids 29 and 0 decode to depends on the tokenizer of
    # model_name; confirm they are the intended stop markers, e.g. by checking
    # tokenizer.convert_ids_to_tokens([29, 0]).
    DEFAULT_STOP_IDS = (29, 0)

    def __init__(self, stop_ids=DEFAULT_STOP_IDS):
        super().__init__()
        # A set gives O(1) membership checks on every generation step.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of sequence 0 is a stop id."""
        return int(input_ids[0, -1]) in self.stop_ids
def _build_prompt(turns):
    """Serialize chat turns into one prompt string, prefixed by system_prompt.

    ``turns`` is a sequence of ``[user_message, assistant_reply]`` pairs; the
    last pair is the pending turn and normally has an empty reply. Every
    message is closed with ``<|eot_id|>`` except that final assistant turn,
    which is left open so the model continues after the assistant header.
    """
    parts = [system_prompt]
    last_index = len(turns) - 1
    for index, turn in enumerate(turns):
        user_message, assistant_reply = turn[0], turn[1]
        parts.append("<|start_header_id|>user<|end_header_id|>\n\n" + user_message + "<|eot_id|>")
        # A reply may be None in the chat history; treat it as empty text.
        parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n" + (assistant_reply or ""))
        if index != last_index:
            # Completed replies must be terminated too, otherwise each earlier
            # answer runs straight into the next user header.
            parts.append("<|eot_id|>")
    return "".join(parts)


def predict(message, history):
    """Stream a retrieval-augmented answer to ``message``.

    The closest stored document (``vector_search``) is appended to the user
    message under a ``### Source ###`` heading. The conversation is then
    serialized with ``_build_prompt``, and ``model.generate`` runs in a
    background thread while the text produced so far is yielded.

    Args:
        message: The new user message.
        history: Previous turns as ``[user, assistant]`` pairs. This presumably
            follows Gradio's tuple-style chat history; confirm for the Gradio
            version in use.

    Yields:
        The answer accumulated so far, growing with each streamed chunk.
    """
    from threading import Thread  # used below but never imported at file level

    source_text = vector_search(message)
    augmented_message = message + "\n\n### Source ###\n" + source_text
    turns = list(history or []) + [[augmented_message, ""]]
    prompt = _build_prompt(turns)

    # The prompt already starts with <|begin_of_text|>; don't let the tokenizer
    # prepend a second BOS token.
    model_inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    # NOTE(review): these sampling values differ from the module-level
    # temperature/top_p/max_new_tokens/repetition_penalty, which are unused.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=1.0,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # generate() blocks, so it runs in a worker while this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    partial_message = ""
    for new_token in streamer:
        if new_token != '<':
            partial_message += new_token
            yield partial_message
    worker.join()
# Define the Gradio interface
title = "Tchap"
description = "Le chatbot du service public"
# Each example keeps its [user_message, temperature] shape. predict() has no
# temperature input, so only the message is handed to the chat UI.
examples = [
    [
        "Qui peut bénéficier de l'AIP?",  # user_message
        0.7,  # temperature
    ]
]
demo = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    examples=[example[0] for example in examples],
    cache_examples=False,  # don't pre-compute example answers with the model at startup
)
demo.launch()