import json
import os
import time
import uuid
from datetime import datetime
import gradio as gr
import openai
from huggingface_hub import HfApi
from langchain.document_loaders import (PyPDFLoader, UnstructuredPDFLoader,
                                        PyPDFium2Loader, PyMuPDFLoader,
                                        PDFPlumberLoader)
from knowledge.faiss_handler import create_faiss_index_from_zip, load_faiss_index_from_zip
from knowledge.img_handler import process_image, add_markup
from llms.chatbot import OpenAIChatBot
from llms.embeddings import EMBEDDINGS_MAPPING
from utils import make_archive
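# Required environment variables:
#   UPLOAD_REPO_ID   - Hugging Face dataset repo that receives generated databases
#   HF_TOKEN         - Hugging Face access token used for uploads
#   OPENAI_API_KEY   - OpenAI API key for chat completions
#   OPENAI_API_BASE  - optional custom base URL for the OpenAI API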
UPLOAD_REPO_ID = os.getenv("UPLOAD_REPO_ID")
HF_TOKEN = os.getenv("HF_TOKEN")
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")
hf_api = HfApi(token=HF_TOKEN)
ALL_PDF_LOADERS = [PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader]
ALL_EMBEDDINGS = list(EMBEDDINGS_MAPPING.keys())
PDF_LOADER_MAPPING = {loader.__name__: loader for loader in ALL_PDF_LOADERS}
#######################################################################################################################
# Host multiple vector databases for use
#######################################################################################################################
# todo: add this feature in the future
INSTRUCTIONS = '''# FAISS Chat: chat with a local vector database!
***Update 2023-06-06:***
1. Added support for reading chart data from images (currently JPG and PNG).
2. A demo of this module is available in the "Summarize Chart (Demo)" tab.
***Update 2023-06-04:***
1. Added support for more embedding models (currently [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), [text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese), and [distilbert-dot-tas_b-b256-msmarco](https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco)).
2. Added support for more file formats (PDF, TXT, TEX, and MD).
3. All generated databases are now accessible in [this dataset](https://huggingface.co/datasets/shaocongma/shared-faiss-vdb)! If you do not want your files to be uploaded, you can disable this in the advanced settings.
'''
def load_zip_as_db(file_from_gradio,
pdf_loader,
embedding_model,
chunk_size=300,
chunk_overlap=20,
upload_to_cloud=True):
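    """Build a FAISS knowledge base from a zip archive uploaded via Gradio.

    Returns a status message, the path to the zipped index, the in-memory
    database, and its metadata (the four outputs wired to the create button).
    """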
    if chunk_size <= chunk_overlap:
        return "chunk_size must be larger than chunk_overlap. Creation failed.", None, None, None
    if file_from_gradio is None:
        return "No file was uploaded. Creation failed.", None, None, None
pdf_loader = PDF_LOADER_MAPPING[pdf_loader]
zip_file_path = file_from_gradio.name
project_name = uuid.uuid4().hex
db, project_name, db_meta = create_faiss_index_from_zip(zip_file_path, embeddings=embedding_model,
pdf_loader=pdf_loader, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, project_name=project_name)
index_name = project_name + ".zip"
make_archive(project_name, index_name)
date = datetime.today().strftime('%Y-%m-%d')
if upload_to_cloud:
hf_api.upload_file(path_or_fileobj=index_name,
path_in_repo=f"{date}/faiss_{index_name}.zip",
repo_id=UPLOAD_REPO_ID,
repo_type="dataset")
return "成功创建知识库. 可以开始聊天了!", index_name, db, db_meta
def load_local_db(file_from_gradio):
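    """Load a previously created FAISS knowledge base from a local zip file."""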
    if file_from_gradio is None:
        return "No file was uploaded. Loading failed.", None
zip_file_path = file_from_gradio.name
db = load_faiss_index_from_zip(zip_file_path)
return "成功读取知识库. 可以开始聊天了!", db
def extract_image(image_path):
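    """Extract a table from a chart image; returns the raw table and its marked-up form."""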
from PIL import Image
print("Image Path:", image_path)
im = Image.open(image_path)
table = process_image(im)
print(f"Success in processing the image. Table: {table}")
return table, add_markup(table)
def describe(image):
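    """Summarize a chart image: convert it to a table, then ask gpt-3.5-turbo for a detailed Chinese summary."""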
    table = process_image(image)
_INSTRUCTION = 'Read the table below to answer the following questions.'
question = "Please refer to the above table, and write a summary of no less than 200 words based on it in Chinese, ensuring that your response is detailed and precise. "
prompt_0shot = _INSTRUCTION + "\n" + add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
messages = [{"role": "assistant", "content": prompt_0shot}]
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=messages,
temperature=0.7,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
)
ret = response.choices[0].message['content']
return ret
with gr.Blocks() as demo:
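    # Session state holding the loaded FAISS database (None until one is created or loaded).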
local_db = gr.State(None)
def get_augmented_message(message, local_db, query_count, preprocessing, meta):
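        """Build an augmented prompt: if the user's message names an image file in the
        database, extract and append its table; then append up to query_count similar
        chunks retrieved from the local FAISS database.
        """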
print(f"Receiving message: {message}")
print("Detecting if the user need to read image from the local database...")
        # read the image file list and source path from the database metadata
        img_files = []
        source_path = ""
        if meta is not None:
            source_path = meta.get("source_path", "")
            for file in meta.get("files", []):
                if os.path.splitext(file)[1].lower() in [".png", ".jpg"]:
                    img_files.append(file)
# scan user's input to see if it contains images' name
do_extract_image = False
target_file = None
for file in img_files:
img = os.path.splitext(file)[0]
if img in message:
do_extract_image = True
target_file = file
break
# extract image to tables
image_info = ""
if do_extract_image:
print("The user needs to read image from the local database. Extract image ... ")
target_file = os.path.join(source_path, target_file)
_, image_info = extract_image(target_file)
if len(image_info)>0:
image_content = {"content": image_info, "source": os.path.basename(target_file)}
else:
image_content = None
print("Querying references from the local database...")
contents = []
try:
if query_count > 0:
docs = local_db.similarity_search(message, k=query_count)
                for doc in docs:
                    # pre-process each chunk: flatten newlines into spaces
                    content = doc.page_content.replace('\n', ' ')
                    contents.append(content)
        except Exception as e:
            print(f"Failed to query the local database: {e}")
# generate augmented_message
print("Success in querying references: {}".format(contents))
if image_content is not None:
augmented_message = f"{image_content}\n\n---\n\n" + "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
else:
augmented_message = "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
return augmented_message + "\n\n" + f"'user_input': {message}"
def respond(message, local_db, chat_history, meta, query_count=5, test_mode=False, response_delay=5, preprocessing=False):
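        """Main chat handler: send the message directly when no database is loaded,
        otherwise augment it with retrieved references first.
        """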
gpt_chatbot = OpenAIChatBot()
print("Chat History: ", chat_history)
print("Local DB: ", local_db is None)
for chat in chat_history:
gpt_chatbot.load_chat(chat)
if local_db is None or query_count == 0:
bot_message = gpt_chatbot(message)
print(bot_message)
print(message)
chat_history.append((message, bot_message))
return "", chat_history
else:
augmented_message = get_augmented_message(message, local_db, query_count, preprocessing, meta)
bot_message = gpt_chatbot(augmented_message, original_message=message)
print(message)
print(augmented_message)
print(bot_message)
if test_mode:
chat_history.append((augmented_message, bot_message))
else:
chat_history.append((message, bot_message))
            time.sleep(response_delay)  # pause to avoid hitting the API rate limit
return "", chat_history
with gr.Row():
with gr.Column():
gr.Markdown(INSTRUCTIONS)
with gr.Row():
with gr.Tab("从本地PDF文件创建知识库"):
zip_file = gr.File(file_types=[".zip"], label="本地PDF文件(.zip)")
create_db = gr.Button("创建知识库", variant="primary")
with gr.Accordion("高级设置", open=False):
embedding_selector = gr.Dropdown(ALL_EMBEDDINGS,
value="distilbert-dot-tas_b-b256-msmarco",
label="Embedding Models")
pdf_loader_selector = gr.Dropdown([loader.__name__ for loader in ALL_PDF_LOADERS],
value=PyPDFLoader.__name__, label="PDF Loader")
chunk_size_slider = gr.Slider(minimum=50, maximum=2000, step=50, value=500,
label="Chunk size (tokens)")
chunk_overlap_slider = gr.Slider(minimum=0, maximum=500, step=1, value=50,
label="Chunk overlap (tokens)")
                save_to_cloud_checkbox = gr.Checkbox(value=False, label="Upload the database to the cloud")
            file_dp_output = gr.File(file_types=[".zip"], label="(Output) Knowledge base file (.zip)")
        with gr.Tab("Load a Local Knowledge Base"):
            file_local = gr.File(file_types=[".zip"], label="Local knowledge base file (.zip)")
            load_db = gr.Button("Load Knowledge Base", variant="primary")
        with gr.Tab("Summarize Chart (Demo)"):
            gr.Markdown("Code adapted from: https://huggingface.co/spaces/fl399/deplot_plus_llm")
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
extract = gr.Button("总结", variant="primary")
output_text = gr.Textbox(lines=8, label="Output")
with gr.Column():
status = gr.Textbox(label="用来显示程序运行状态的Textbox")
chatbot = gr.Chatbot()
msg = gr.Textbox()
submit = gr.Button("Submit", variant="primary")
with gr.Accordion("高级设置", open=False):
json_output = gr.JSON()
with gr.Row():
query_count_slider = gr.Slider(minimum=0, maximum=10, step=1, value=3,
label="Query counts")
test_mode_checkbox = gr.Checkbox(label="Test mode")
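    # Wire UI events: chat submission, knowledge base creation/loading, and chart summarization.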
msg.submit(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])
submit.click(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])
create_db.click(load_zip_as_db, [zip_file, pdf_loader_selector, embedding_selector, chunk_size_slider, chunk_overlap_slider, save_to_cloud_checkbox],
[status, file_dp_output, local_db, json_output])
load_db.click(load_local_db, [file_local], [status, local_db])
extract.click(describe, [input_image], [output_text])
demo.launch(show_api=False)