Spaces:
Runtime error
Runtime error
import fastapi as api | |
from typing import Annotated | |
from fastapi.security import OAuth2PasswordBearer, OAuth2AuthorizationCodeBearer, OAuth2PasswordRequestForm | |
from model.document import Document, PlainTextDocument, JsonDocument | |
import sys | |
from model.user import User | |
from fastapi import FastAPI, File, UploadFile | |
from di import initialize_di_for_app | |
import gradio as gr | |
import os | |
import json | |
# Resolve the application's collaborators through the project's DI helper.
# In the code visible here STORAGE persists raw payloads, EMBEDDING is handed
# to every document, and INDEX answers queries; SETTINGS is not referenced.
SETTINGS, STORAGE, EMBEDDING, INDEX = initialize_di_for_app()

# The single known user is persisted as JSON under 'user.json'.
user_json_str = STORAGE.load('user.json')
# The rest of this module uses the pydantic v2 API (model_dump /
# model_dump_json), so parse with its v2 counterpart rather than the
# deprecated v1-style `parse_raw`.
USER = User.model_validate_json(user_json_str)

# Bearer-token scheme; tokens are issued by the URL below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/token")

app = api.FastAPI()
# Pin the OpenAPI version reported in the generated schema.
app.openapi_version = "3.0.0"

# Users that get_current_user matches bearer tokens against.
users = [USER]
async def get_current_user(token: str = api.Depends(oauth2_scheme)):
    '''
    Resolve a bearer token to one of the known users.

    The token is compared with the ``user_name`` of each entry in the
    module-level ``users`` list and the first match is returned.

    Raises:
        HTTPException: 401 when no user matches the token.
    '''
    matched = next(
        (candidate for candidate in users if candidate.user_name == token),
        None,
    )
    if matched is None:
        raise api.HTTPException(status_code=401, detail="Invalid authentication credentials")
    return matched
async def login(form_data: Annotated[OAuth2PasswordRequestForm, api.Depends()]):
    '''
    Login to get a token.

    The submitted username is echoed back as the access token; the password
    is not checked in this function.

    Returns:
        dict: ``access_token`` (the username) and ``token_type`` set to
        ``"bearer"``, the field OAuth2 bearer clients expect next to the token.
    '''
    return {"access_token": form_data.username, "token_type": "bearer"}
def create_upload_file(file: UploadFile = api.File(...)) -> Document:
    '''
    Store an uploaded file and wrap it in a document object.

    Only ``.txt`` uploads are supported. They are saved under the key
    ``<user name>-<file name>`` and returned as a PlainTextDocument in the
    'uploading' state whose ``url`` is that key.

    Raises:
        HTTPException: 400 when the upload is not a ``.txt`` file.
    '''
    # filename may be missing on a malformed upload; treat it as unsupported.
    file_name = file.filename or ''

    # Reject unsupported types before touching storage so that a refused
    # upload does not leave an orphaned blob behind.
    if not file_name.endswith('.txt'):
        raise api.HTTPException(status_code=400, detail="File type not supported")

    fileUrl = f'{USER.user_name}-{file_name}'
    # UploadFile.read() is a coroutine; calling it from this synchronous
    # function would hand an un-awaited coroutine object to storage. Read from
    # the underlying file object instead.
    # NOTE(review): this stores bytes while gradio_upload_document stores str
    # -- confirm STORAGE.save accepts both.
    STORAGE.save(fileUrl, file.file.read())

    return PlainTextDocument(
        name=file_name,
        status='uploading',
        url=fileUrl,
        embedding=EMBEDDING,
        storage=STORAGE,
    )
### /api/v1/.well-known | |
#### Get /openapi.json | |
# Get the openapi json file | |
async def get_openapi():
    '''
    Build a trimmed OpenAPI document that exposes only the search endpoint.

    The title is set to 'DocumentSearch', and the current user's document
    names and descriptions are embedded both in the top-level description and
    in the description of GET /api/v1/search. All other paths are dropped and
    the components section is emptied.

    Returns:
        dict: the customised OpenAPI document.
    '''
    import copy

    # get a list of document names + description
    document_list = [[doc.name, doc.description] for doc in USER.documents]

    # app.openapi() returns the schema FastAPI caches on the app. A shallow
    # .copy() shares the nested dicts, so edits made here leaked into the
    # cache and the "Available documents" text was appended again on every
    # call. Work on a deep copy instead.
    openapi = copy.deepcopy(app.openapi())

    info = openapi['info']
    info['title'] = 'DocumentSearch'
    info['description'] = (
        'Search documents with a query.\n'
        '## Documents\n'
        f'{document_list}\n'
    )

    # update description in /api/v1/search
    search_path = openapi['paths']['/api/v1/search']
    search_get = search_path['get']
    search_get['description'] = (
        search_get.get('description', '')
        + f'\nAvailable documents:\n{document_list}\n'
    )

    # filter out unnecessary endpoints
    openapi['paths'] = {'/api/v1/search': search_path}
    # remove components
    openapi['components'] = {}
    return openapi
### /api/v1/document | |
#### Get /list | |
# Get the list of documents | |
# async def get_document_list(user: Annotated[User, api.Depends(get_current_user)]) -> list[Document]: | |
async def get_document_list() -> list[Document]:
    '''
    Return every document currently registered on the module-level USER.
    '''
    documents = USER.documents
    return documents
#### Post /upload | |
# Upload a document | |
# def upload_document(user: Annotated[User, api.Depends(get_current_user)], document: Annotated[Document, api.Depends(create_upload_file)]): | |
def upload_document(document: Annotated[Document, api.Depends(create_upload_file)]):
    '''
    Index a freshly uploaded document and register it with the user.

    The document (produced by create_upload_file) is marked 'processing'
    while the index loads it, then 'done', and is finally appended to
    ``USER.documents``.
    '''
    document.status = 'processing'
    # The call used to pass an undefined `progress` name, which raised
    # NameError on every upload; use the same two-argument form that
    # gradio_upload_document uses.
    INDEX.load_or_update_document(USER, document)
    document.status = 'done'
    USER.documents.append(document)
#### Get /delete | |
# Delete a document | |
# async def delete_document(user: Annotated[User, api.Depends(get_current_user)], document_name: str): | |
async def delete_document(document_name: str):
    '''
    Remove the first document whose name equals ``document_name``.

    The stored payload is deleted, the document is dropped from the index,
    and it is removed from ``USER.documents``.

    Raises:
        HTTPException: 404 when no document has that name.
    '''
    target = next(
        (entry for entry in USER.documents if entry.name == document_name),
        None,
    )
    if target is None:
        raise api.HTTPException(status_code=404, detail="Document not found")

    STORAGE.delete(target.url)
    INDEX.remove_document(USER, target)
    USER.documents.remove(target)
# Query the index | |
def search(
        # user: Annotated[User, api.Depends(get_current_user)],
        query: str,
        document_name: str = None,
        top_k: int = 10,
        threshold: float = 0.5):
    '''
    Search documents with a query. It will return [top_k] results with a score higher than [threshold].
    query: the query string, required
    document_name: the document name, optional. You can provide this parameter to search in a specific document.
    top_k: the number of results to return, optional. Default to 10.
    threshold: the threshold of the results, optional. Default to 0.5.
    '''
    # NOTE(review): the docstring above is kept verbatim -- get_openapi reads
    # this endpoint's description, which presumably originates from it.

    # Without a document name, query across the user's whole index.
    if not document_name:
        return INDEX.query_index(USER, query, top_k, threshold)

    # Otherwise restrict the query to the first document with that name.
    selected = next(
        (candidate for candidate in USER.documents if candidate.name == document_name),
        None,
    )
    if selected is None:
        raise api.HTTPException(status_code=404, detail="Document not found")
    return INDEX.query_document(USER, selected, query, top_k, threshold)
def receive_signal(signalNumber, frame):
    '''
    Signal handler: print the received signal number, then exit the process.
    '''
    print(f'Received: {signalNumber}')
    raise SystemExit
async def startup_event():
    '''
    Install receive_signal as the handler for SIGINT.
    '''
    from signal import SIGINT, signal as install_handler

    install_handler(SIGINT, receive_signal)
    # startup tasks
def exit_event():
    '''
    Persist the in-memory USER to 'user.json' and print 'exit'.
    '''
    # save USER
    serialized_user = USER.model_dump_json()
    STORAGE.save('user.json', serialized_user)
    print('exit')
# Lowercase alias of USER, kept for code that refers to `user`.
user = USER


def gradio_upload_document(file: File):
    '''
    Gradio handler: store an uploaded file, index it and register it.

    ``file`` must expose the path of the uploaded temp file as ``file.name``.
    ``.txt`` files become a PlainTextDocument and ``.json`` files a
    JsonDocument. The file content is read as UTF-8 text and saved under
    ``<user name>-<file name>``.

    Returns:
        str: ``'uploaded <file name>'``.

    Raises:
        HTTPException: 400 when the extension is neither ``.txt`` nor ``.json``.
    '''
    file_temp_path = file.name
    file_name = os.path.basename(file_temp_path)

    # Choose the document class first so that an unsupported file is rejected
    # before anything is written to storage.
    if file_name.endswith('.txt'):
        document_type = PlainTextDocument
    elif file_name.endswith('.json'):
        document_type = JsonDocument
    else:
        raise api.HTTPException(status_code=400, detail="File type not supported")

    # load file
    fileUrl = f'{USER.user_name}-{file_name}'
    with open(file_temp_path, 'r', encoding='utf-8') as f:
        STORAGE.save(fileUrl, f.read())

    doc = document_type(
        name=file_name,
        status='uploading',
        url=fileUrl,
        embedding=EMBEDDING,
        storage=STORAGE,
    )
    doc.status = 'processing'
    INDEX.load_or_update_document(USER, doc)
    doc.status = 'done'
    USER.documents.append(doc)
    return f'uploaded {file_name}'
def gradio_query(query: str, document_name: str = None, top_k: int = 10, threshold: float = 0.5):
    '''
    Gradio handler: run search() and return the hits as a JSON string.

    Args:
        query: the query string.
        document_name: optional document to restrict the search to.
        top_k: number of results requested; coerced to int.
        threshold: score threshold forwarded to search().

    Returns:
        str: the records dumped as JSON with an indent of 4.

    Raises:
        HTTPException: propagated from search() (404 for an unknown document).
    '''
    # A numeric UI input may arrive as a float (e.g. 10.0) or be cleared to
    # None, while search() declares an int; normalise before forwarding.
    top_k = 10 if top_k is None else int(top_k)

    # search() raises HTTPException rather than returning it, so the former
    # isinstance check on the result could never fire; the exception simply
    # propagates from the call below.
    results = search(query, document_name, top_k, threshold)

    # convert to json string
    records = [record.model_dump(mode='json') for record in results]
    return json.dumps(records, indent=4)
# Gradio front-end: a header with the user name and the OpenAPI URL, plus a
# "search" tab wired to gradio_query.
# NOTE(review): the nesting of the layout below was reconstructed -- confirm
# it matches the intended layout.
with gr.Blocks() as ui:
    # A Markdown heading needs a space after '#'; "#llm-memory" rendered as
    # literal text.
    gr.Markdown("# llm-memory")

    with gr.Column():
        gr.Markdown("\n## LLM Memory\n")
        with gr.Row():
            user_name = gr.Label(label="User name", value=USER.user_name)
            # url to .well-known/openapi.json
            gr.Label(label=".wellknown/openapi.json", value="/api/v1/.well-known/openapi.json")

    # with gr.Tab("avaiable documents"):
    #     available_documents = gr.Label(label="avaiable documents", value="avaiable documents")
    #     refresh_btn = gr.Button(label="refresh", type="button")
    #     refresh_btn.click(lambda: '\r\n'.join([doc.name for doc in USER.documents]), None, available_documents)
    #     documents = USER.documents
    #     for document in documents:
    #         gr.Label(label=document.name, value=document.name)
    # with gr.Tab("upload document"):
    #     with gr.Tab("upload .txt document"):
    #         file = gr.File(label="upload document", type="file", file_types=[".txt"])
    #         output = gr.Label(label="output", value="output")
    #         upload_btn = gr.Button("upload document", type="button")
    #         upload_btn.click(gradio_upload_document, file, output)
    #     with gr.Tab("upload .json document"):
    #         gr.Markdown(
    #             """
    #             The json document should be a list of objects, each object should have a `content` field. If you want to add more fields, you can add them in the `meta_data` field.
    #             For example:
    #             ```json
    #             [
    #                 {
    #                     "content": "hello world",
    #                     "meta_data": {
    #                         "title": "hello world",
    #                         "author": "llm-memory"
    #                     }
    #                 },
    #                 {
    #                     "content": "hello world",
    #                     "meta_data": {
    #                         "title": "hello world",
    #                         "author": "llm-memory"
    #                     }
    #                 }
    #             ]
    #             ```
    #             ## Note
    #             - The `meta_data` should be a dict which both keys and values are strings.
    #             """)
    #         file = gr.File(label="upload document", type="file", file_types=[".json"])
    #         output = gr.Label(label="output", value="output")
    #         upload_btn = gr.Button("upload document", type="button")
    #         upload_btn.click(gradio_upload_document, file, output)

    with gr.Tab("search"):
        query = gr.Textbox(label="search", placeholder="Query")
        # `placeholder` is not a documented parameter of Dropdown or Number:
        # it is ignored at best and rejected as an unexpected keyword by
        # releases that do not accept extra kwargs, so it is dropped here.
        # The None choice means "search all documents".
        document = gr.Dropdown(
            label="document",
            choices=[None] + [doc.name for doc in USER.documents],
        )
        # precision=0 makes the component yield an int, matching the
        # `top_k: int` parameter of the handler.
        top_k = gr.Number(label="top_k", value=10, precision=0)
        threshold = gr.Number(label="threshold", value=0.5)
        output = gr.Code(label="output", language="json", value="output")
        query_btn = gr.Button("Query")
        query_btn.click(gradio_query, [query, document, top_k, threshold], output, api_name="search")
# Serve the Gradio UI from the root path of the FastAPI app.
gradio_app = gr.routes.App.create_app(ui)
app.mount("/", gradio_app)

# Start Gradio's own server only when this file is executed as a script.
# Importing the module (for example so an ASGI server can serve `app`) must
# not block on launch().
if __name__ == "__main__":
    ui.launch()