Spaces:
Runtime error
Runtime error
import json | |
import os | |
import time | |
import uuid | |
from datetime import datetime | |
import gradio as gr | |
import openai | |
from huggingface_hub import HfApi | |
from langchain.document_loaders import PyPDFLoader, \ | |
UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader | |
from knowledge.faiss_handler import create_faiss_index_from_zip, load_faiss_index_from_zip | |
from knowledge.img_handler import process_image, add_markup | |
from llms.chatbot import OpenAIChatBot | |
from llms.embeddings import EMBEDDINGS_MAPPING | |
from utils import make_archive | |
# Hugging Face dataset repo that receives uploaded indexes, and the token used for it.
UPLOAD_REPO_ID = os.getenv("UPLOAD_REPO_ID")
HF_TOKEN = os.getenv("HF_TOKEN")

openai.api_key = os.getenv("OPENAI_API_KEY")
# Override the endpoint only when one is configured, so an unset variable does not
# overwrite the client's existing base URL with None.
_OPENAI_API_BASE = os.getenv("OPENAI_API_BASE")
if _OPENAI_API_BASE:
    openai.api_base = _OPENAI_API_BASE

hf_api = HfApi(token=HF_TOKEN)

ALL_PDF_LOADERS = [PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader]
# Materialize the embedding names as a list so UI widgets get a plain sequence, not a dict view.
ALL_EMBEDDINGS = list(EMBEDDINGS_MAPPING.keys())
# The UI dropdown works with loader class names; map each name back to its class.
PDF_LOADER_MAPPING = {loader.__name__: loader for loader in ALL_PDF_LOADERS}
# TODO: host several vector databases side by side (planned, not implemented yet).

# Markdown rendered at the top of the page: the app title followed by a dated changelog.
INSTRUCTIONS = "\n".join([
    "# FAISS Chat: 和本地数据库聊天!",
    "***2023-06-06更新:***",
    "1. 支持读取图片格式的图表数据(目前支持JPG, PNG).",
    '2. 在"总结图表(Demo)"的标签页里提供了这个模块的测试.',
    "***2023-06-04更新:***",
    "1. 支持更多的Embedding Model (目前支持[text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), "
    "[text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese), "
    "和[distilbert-dot-tas_b-b256-msmarco](https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco) )",
    "2. 支持更多的文件格式(PDF, TXT, TEX, 和MD).",
    "3. 所有生成的数据库都可以在[这个数据集](https://huggingface.co/datasets/shaocongma/shared-faiss-vdb)里访问了!"
    "如果不希望文件被上传,可以在高级设置里关闭.",
    "",  # trailing empty item reproduces the literal's final newline
])
def load_zip_as_db(file_from_gradio,
                   pdf_loader,
                   embedding_model,
                   chunk_size=300,
                   chunk_overlap=20,
                   upload_to_cloud=True):
    """Build a FAISS knowledge base from an uploaded ``.zip`` of documents.

    Args:
        file_from_gradio: Uploaded file object whose ``.name`` is used as the zip path,
            or None when nothing was uploaded.
        pdf_loader: Loader class name; must be a key of ``PDF_LOADER_MAPPING``.
        embedding_model: Embedding model name, forwarded as ``embeddings``.
        chunk_size: Chunk size; must be strictly greater than ``chunk_overlap``.
        chunk_overlap: Overlap between consecutive chunks.
        upload_to_cloud: When true, also upload the archived index to ``UPLOAD_REPO_ID``.

    Returns:
        A 4-tuple ``(status_message, index_zip_path, db, db_meta)``, one value per Gradio
        output component. On failure the last three items are None.
    """
    # Every path returns four values: Gradio requires one value per output component.
    if chunk_size <= chunk_overlap:
        return "chunk_size小于chunk_overlap. 创建失败.", None, None, None
    if file_from_gradio is None:
        return "文件为空. 创建失败.", None, None, None

    loader_cls = PDF_LOADER_MAPPING[pdf_loader]
    zip_file_path = file_from_gradio.name
    project_name = uuid.uuid4().hex
    db, project_name, db_meta = create_faiss_index_from_zip(zip_file_path, embeddings=embedding_model,
                                                            pdf_loader=loader_cls, chunk_size=chunk_size,
                                                            chunk_overlap=chunk_overlap, project_name=project_name)

    # ``index_name`` already ends in ".zip"; do not append the suffix a second time below.
    index_name = project_name + ".zip"
    make_archive(project_name, index_name)

    if upload_to_cloud:
        date = datetime.today().strftime('%Y-%m-%d')
        hf_api.upload_file(path_or_fileobj=index_name,
                           path_in_repo=f"{date}/faiss_{index_name}",
                           repo_id=UPLOAD_REPO_ID,
                           repo_type="dataset")
    return "成功创建知识库. 可以开始聊天了!", index_name, db, db_meta
def load_local_db(file_from_gradio):
    """Load a previously built FAISS knowledge base from an uploaded ``.zip``.

    Returns ``(status_message, db)``; ``db`` is None when no file was supplied.
    """
    if file_from_gradio is not None:
        loaded = load_faiss_index_from_zip(file_from_gradio.name)
        return "成功读取知识库. 可以开始聊天了!", loaded
    return "文件为空. 创建失败.", None
def extract_image(image_path):
    """Convert a chart image on disk into a table.

    Args:
        image_path: Path of the image file to open with Pillow.

    Returns:
        ``(table, marked_up_table)``: the raw result of ``process_image`` and the same
        table passed through ``add_markup``.
    """
    from PIL import Image

    print("Image Path:", image_path)
    # Open the image in a context manager so its file handle is released as soon as
    # processing ends, instead of whenever the object happens to be collected.
    with Image.open(image_path) as im:
        table = process_image(im)
    print(f"Success in processing the image. Table: {table}")
    return table, add_markup(table)
def describe(image):
    """Ask the chat model for a detailed Chinese summary of a chart image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` component, or None if empty.

    Returns:
        The model's reply text, or a short notice when no image was provided.
    """
    if image is None:
        return "图片为空."
    # Keep the raw table here; markup is applied exactly once when building the prompt.
    table = process_image(image)
    _INSTRUCTION = 'Read the table below to answer the following questions.'
    question = "Please refer to the above table, and write a summary of no less than 200 words based on it in Chinese, ensuring that your response is detailed and precise. "
    prompt_0shot = _INSTRUCTION + "\n" + add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
    # The prompt is a request to the model, so it is sent as a "user" turn.
    messages = [{"role": "user", "content": prompt_0shot}]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.7,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    return response.choices[0].message['content']
with gr.Blocks() as demo:
    # Per-session handle to the currently loaded FAISS vector store (None until one is built/loaded).
    local_db = gr.State(None)

    def get_augmented_message(message, local_db, query_count, preprocessing, meta):
        """Build the prompt for the chat model: retrieved context followed by the user input.

        If ``meta`` lists image files (png/jpg/jpeg) whose base name appears in ``message``,
        the first such image is converted to a table and placed before the retrieved chunks.

        Args:
            message: Raw user input.
            local_db: Vector store exposing ``similarity_search(query, k=...)``.
            query_count: Number of chunks to retrieve.
            preprocessing: Currently unused; kept for interface compatibility.
            meta: Knowledge-base metadata with ``files`` and ``source_path`` keys, or None
                (for example after loading a saved index, which does not refresh it).

        Returns:
            The augmented prompt string.
        """
        print(f"Receiving message: {message}")
        print("Detecting if the user need to read image from the local database...")
        # Without metadata there is no file list to match, so skip the image lookup
        # instead of failing on ``None["files"]``.
        meta = meta or {}
        files = meta.get("files") or []
        source_path = meta.get("source_path", "")

        # An image is selected when its base name occurs anywhere in the user's message.
        img_files = [f for f in files if os.path.splitext(f)[1].lower() in (".png", ".jpg", ".jpeg")]
        target_file = next((f for f in img_files if os.path.splitext(f)[0] in message), None)

        image_content = None
        if target_file is not None:
            print("The user needs to read image from the local database. Extract image ... ")
            target_path = os.path.join(source_path, target_file)
            _, image_info = extract_image(target_path)
            if len(image_info) > 0:
                image_content = {"content": image_info, "source": os.path.basename(target_path)}

        print("Querying references from the local database...")
        contents = []
        query_count = int(query_count or 0)
        if query_count > 0:
            try:
                docs = local_db.similarity_search(message, k=query_count)
            except Exception as err:  # best effort: answer without references
                print("Failed to query from the local database. ", err)
            else:
                # Flatten each chunk onto one line; the store may return fewer than k docs.
                contents = [doc.page_content.replace('\n', ' ') for doc in docs[:query_count]]
                print("Success in querying references: {}".format(contents))

        references = "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
        if image_content is not None:
            references = f"{image_content}\n\n---\n\n" + references
        return references + "\n\n" + f"'user_input': {message}"

    def respond(message, local_db, chat_history, meta, query_count=5, test_mode=False, response_delay=5,
                preprocessing=False):
        """Answer ``message`` and append the exchange to ``chat_history``.

        With a loaded knowledge base and ``query_count`` > 0 the prompt is augmented with
        retrieved context; otherwise the message is sent to the bot unchanged.

        Returns:
            ``("", chat_history)``: clears the input box and refreshes the chat window.
        """
        if chat_history is None:
            chat_history = []
        gpt_chatbot = OpenAIChatBot()
        print("Chat History: ", chat_history)
        print("Local DB: ", local_db is None)
        # Feed earlier turns to the bot (presumably restoring conversation context).
        for chat in chat_history:
            gpt_chatbot.load_chat(chat)

        if local_db is None or query_count == 0:
            bot_message = gpt_chatbot(message)
            print(bot_message)
            print(message)
            chat_history.append((message, bot_message))
            return "", chat_history

        augmented_message = get_augmented_message(message, local_db, query_count, preprocessing, meta)
        bot_message = gpt_chatbot(augmented_message, original_message=message)
        print(message)
        print(augmented_message)
        print(bot_message)
        # Test mode shows the full augmented prompt in the chat window for debugging.
        chat_history.append((augmented_message if test_mode else message, bot_message))
        time.sleep(response_delay)  # throttle consecutive API requests
        return "", chat_history

    with gr.Row():
        with gr.Column():
            gr.Markdown(INSTRUCTIONS)
            with gr.Row():
                with gr.Tab("从本地PDF文件创建知识库"):
                    zip_file = gr.File(file_types=[".zip"], label="本地PDF文件(.zip)")
                    create_db = gr.Button("创建知识库", variant="primary")
                    with gr.Accordion("高级设置", open=False):
                        embedding_selector = gr.Dropdown(ALL_EMBEDDINGS,
                                                         value="distilbert-dot-tas_b-b256-msmarco",
                                                         label="Embedding Models")
                        pdf_loader_selector = gr.Dropdown([loader.__name__ for loader in ALL_PDF_LOADERS],
                                                          value=PyPDFLoader.__name__, label="PDF Loader")
                        chunk_size_slider = gr.Slider(minimum=50, maximum=2000, step=50, value=500,
                                                      label="Chunk size (tokens)")
                        chunk_overlap_slider = gr.Slider(minimum=0, maximum=500, step=1, value=50,
                                                         label="Chunk overlap (tokens)")
                        save_to_cloud_checkbox = gr.Checkbox(value=False, label="把数据库上传到云端")
                    file_dp_output = gr.File(file_types=[".zip"], label="(输出)知识库文件(.zip)")
                with gr.Tab("读取本地知识库文件"):
                    file_local = gr.File(file_types=[".zip"], label="本地知识库文件(.zip)")
                    load_db = gr.Button("读取已创建知识库", variant="primary")
                with gr.Tab("总结图表(Demo)"):
                    gr.Markdown(r"代码来源于: https://huggingface.co/spaces/fl399/deplot_plus_llm")
                    input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                    extract = gr.Button("总结", variant="primary")
                    output_text = gr.Textbox(lines=8, label="Output")
        with gr.Column():
            status = gr.Textbox(label="用来显示程序运行状态的Textbox")
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            submit = gr.Button("Submit", variant="primary")
            with gr.Accordion("高级设置", open=False):
                json_output = gr.JSON()
                with gr.Row():
                    query_count_slider = gr.Slider(minimum=0, maximum=10, step=1, value=3,
                                                   label="Query counts")
                    test_mode_checkbox = gr.Checkbox(label="Test mode")

    # Pressing Enter and clicking Submit share the same handler and inputs.
    chat_inputs = [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox]
    msg.submit(respond, chat_inputs, [msg, chatbot])
    submit.click(respond, chat_inputs, [msg, chatbot])
    create_db.click(load_zip_as_db,
                    [zip_file, pdf_loader_selector, embedding_selector, chunk_size_slider,
                     chunk_overlap_slider, save_to_cloud_checkbox],
                    [status, file_dp_output, local_db, json_output])
    load_db.click(load_local_db, [file_local], [status, local_db])
    extract.click(describe, [input_image], [output_text])

demo.launch(show_api=False)