Spaces:
Running
Running
from fastapi import FastAPI, UploadFile, File | |
from fastapi.responses import FileResponse | |
from datasets import load_dataset | |
from fastapi.middleware.cors import CORSMiddleware | |
import pdfplumber | |
import pytesseract | |
# Loading | |
import os | |
import zipfile | |
import shutil | |
from os import makedirs,getcwd | |
from os.path import join,exists,dirname | |
import torch | |
import json | |
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore | |
app = FastAPI()

# Permissive CORS: any origin, any HTTP method and any request header may
# call this API from a browser.
# NOTE(review): the FastAPI CORS docs state that credentials cannot be
# combined with a wildcard origin list; with allow_credentials=True and
# allow_origins=["*"] any site may end up making credentialed requests.
# Confirm this is intended, or list explicit origins / drop credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Number of CPUs reported by the OS (os.cpu_count() may return None).
# NOTE(review): not referenced in the code visible here — confirm it is used.
NUM_PROC = os.cpu_count()

# Uploaded and extracted files are staged in a "temp" folder that sits next
# to (not inside) the current working directory.
parent_path = dirname(getcwd())
temp_path = join(parent_path, 'temp')

# exist_ok=True replaces the former exists()/makedirs() pair, which could
# raise FileExistsError if another worker created the folder in between.
makedirs(temp_path, exist_ok=True)
# Choose the compute device: the GPU when PyTorch can see one, else the CPU.
# NOTE(review): `device` is not referenced by the code visible here.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
import logging

# Root logging: WARNING and above, formatted as "LEVEL - logger - message".
logging.basicConfig(
    level=logging.WARNING,
    format="%(levelname)s - %(name)s - %(message)s",
)
# Let Haystack's own pipeline/component messages through at INFO level.
logging.getLogger("haystack").setLevel(logging.INFO)
# Local, on-disk Qdrant store kept in ./database (relative to the working
# directory), holding both dense and sparse vectors for hybrid retrieval.
# NOTE(review): embedding_dim must match the dense embedder used for
# indexing/querying (presumably 384 for BAAI/bge-small-en-v1.5) — keep in sync.
# NOTE(review): recreate_index=True asks the store to recreate the collection
# when this module is loaded, so previously indexed data presumably does not
# survive a restart — confirm that is intended.
document_store = QdrantDocumentStore(
    embedding_dim=384,
    use_sparse_embeddings=True,
    recreate_index=True,
    path="database",
)
def extract_zip(zip_path, target_folder):
    """
    Extract every member of a ZIP archive and return the paths written.

    Args:
        zip_path (str): Path to the ZIP file.
        target_folder (str): Folder where the members are extracted.

    Returns:
        List[str]: One normalized path per archive member (files and
        directory entries), in archive order, pointing at where the member
        actually landed on disk.
    """
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # ZipFile.extract strips absolute prefixes and '..' components from
        # member names and returns the path it really wrote. Joining the raw
        # namelist() entries onto target_folder (as before) yielded paths
        # outside the folder for names such as '../x.pdf' or '/etc/x.pdf',
        # which callers would then open and read.
        return [zip_ref.extract(member, target_folder) for member in zip_ref.infolist()]
def extract_text_from_pdf(pdf_path, separator=""):
    """
    Extract the embedded text layer of a PDF with pdfplumber.

    Args:
        pdf_path (str): Path to the PDF file.
        separator (str): String inserted between consecutive pages. The
            default "" concatenates pages directly (the previous behaviour);
            pass "\\n" to keep the last word of one page from fusing with
            the first word of the next.

    Returns:
        str: Text of all pages, in page order.
    """
    with pdfplumber.open(pdf_path) as pdf:
        # Some pdfplumber versions return None for pages without a text
        # layer (e.g. scanned images); treat those as empty instead of
        # failing on str + None. join() also avoids quadratic '+='.
        return separator.join(page.extract_text() or "" for page in pdf.pages)
def extract_ocr_text_from_pdf(pdf_path, lang='vie'):
    """
    OCR every page of a PDF with Tesseract and return the recognized text.

    Args:
        pdf_path (str): Path to the PDF file.
        lang (str): Tesseract language code(s) passed to
            pytesseract.image_to_string. Defaults to 'vie' (Vietnamese),
            the value that used to be hard-coded.

    Returns:
        str: Recognized text of all pages, concatenated in page order with
        no separator between pages.
    """
    # Imported inside the function (as before), so pdf2image is only
    # needed when this code path actually runs.
    from pdf2image import convert_from_path

    # Renders every page to an image up front before OCR.
    images = convert_from_path(pdf_path)
    return "".join(pytesseract.image_to_string(image, lang=lang) for image in images)
def _upload_save_path(filename):
    """Return where an uploaded file called ``filename`` is stored in ``temp_path``.

    ``filename`` is client-controlled, so only its last path component is
    kept; otherwise names such as ``../../app.py`` or ``/etc/passwd`` would
    make ``join`` resolve outside ``temp_path`` and overwrite arbitrary files.

    Raises:
        ValueError: if no usable file name remains.
    """
    name = os.path.basename(filename or "")
    if name in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {filename!r}")
    return join(temp_path, name)


def _is_inside(path, root):
    """Return True if ``path`` resolves to a location strictly inside ``root``."""
    root = os.path.realpath(root)
    return os.path.realpath(path).startswith(root + os.sep)


def _documents_from_jsonl(path, text_field):
    """Build one haystack Document per non-blank line of a JSON Lines file.

    Each line must be a JSON object; ``obj[text_field]`` becomes the content
    and the whole object becomes the metadata.
    """
    from haystack import Document

    documents = []
    with open(path, encoding="utf-8") as fd:
        for line in fd:
            if not line.strip():
                continue  # tolerate blank lines (e.g. extra trailing newlines)
            obj = json.loads(line)
            documents.append(Document(content=obj[text_field], meta=obj))
    return documents


def _documents_from_pdf_zip(zip_path, text_field, ocr):
    """Extract a ZIP into ``temp_path`` and build one Document per PDF member."""
    from haystack import Document

    documents = []
    # NOTE(review): every upload extracts into the shared temp_path, so
    # concurrent uploads with same-named members overwrite each other.
    for file_path in extract_zip(zip_path, temp_path):
        # Real suffix check (the old substring test also matched names like
        # 'x.pdf.txt'); skip directory entries and anything that does not
        # resolve inside temp_path (defence against crafted member names).
        if not file_path.lower().endswith(".pdf"):
            continue
        if not os.path.isfile(file_path) or not _is_inside(file_path, temp_path):
            continue
        text = extract_ocr_text_from_pdf(file_path) if ocr else extract_text_from_pdf(file_path)
        # Previously {text_field: text, file_path: file_path}: the path's
        # *value* was used as the key, so no stable "file_path" key existed.
        meta = {text_field: text, "file_path": file_path}
        documents.append(Document(content=text, meta=meta))
    return documents


async def create_upload_file(text_field: str, file: UploadFile = File(...), ocr: bool = False):
    """
    Index an uploaded file into the Qdrant document store.

    Supported uploads (chosen by file extension, case-insensitive):
      * ``.json`` / ``.jsonl`` -- JSON Lines; each line is an object whose
        ``text_field`` key holds the document text.
      * ``.zip`` -- archive of PDFs; text is read with pdfplumber, or with
        Tesseract OCR when ``ocr`` is True.

    Each document receives a sparse (prithvida/Splade_PP_en_v1) and a dense
    (BAAI/bge-small-en-v1.5) embedding and is written to ``document_store``
    with DuplicatePolicy.OVERWRITE.

    NOTE(review): a second ``create_upload_file`` is defined later in this
    file; unless each definition is registered with its own route decorator
    (none are visible here), the later one replaces this one at module level.
    NOTE(review): this coroutine does blocking file I/O, PDF parsing and model
    inference without awaiting, which stalls the event loop during an upload.

    Args:
        text_field: Key holding the text (JSON input) / metadata key the
            extracted PDF text is stored under (ZIP input).
        file: The uploaded file.
        ocr: OCR PDF pages instead of reading their text layer.

    Returns:
        dict: ``filename``, ``message`` ("Done") and ``execution_time``
        (elapsed seconds, float).

    Raises:
        NotImplementedError: for any other file type.
        ValueError: if the upload has no usable file name.
    """
    import time
    from haystack import Pipeline
    from haystack.components.writers import DocumentWriter
    from haystack.document_stores.types import DuplicatePolicy
    from haystack_integrations.components.embedders.fastembed import (
        FastembedDocumentEmbedder,
        FastembedSparseDocumentEmbedder,
    )

    start_time = time.perf_counter()

    save_path = _upload_save_path(file.filename)
    with open(save_path, 'wb') as f:
        shutil.copyfileobj(file.file, f)

    # Dispatch on the real extension; the old substring test ('.json' in the
    # full path) also fired for names like 'a.json.zip' or for directories
    # whose names happened to contain '.json' / '.zip'.
    extension = os.path.splitext(save_path)[1].lower()
    if extension in ('.json', '.jsonl'):
        documents = _documents_from_jsonl(save_path, text_field)
    elif extension == '.zip':
        documents = _documents_from_pdf_zip(save_path, text_field, ocr)
    else:
        raise NotImplementedError("This feature is not supported yet")

    # Indexing: sparse embedding -> dense embedding -> write to the store.
    indexing = Pipeline()
    indexing.add_component("sparse_doc_embedder", FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1"))
    indexing.add_component("dense_doc_embedder", FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5"))
    indexing.add_component("writer", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")
    indexing.run({"sparse_doc_embedder": {"documents": documents}})

    elapsed_time = time.perf_counter() - start_time
    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
def search(prompt: str):
    """
    Run a hybrid (sparse + dense) search over ``document_store`` and rerank it.

    The prompt is embedded with SPLADE (prithvida/Splade_PP_en_v1) and with
    BAAI/bge-small-en-v1.5. QdrantHybridRetriever fetches candidates using
    both embeddings, and BAAI/bge-reranker-base rescores them against the
    prompt.

    Args:
        prompt: The user's query text.

    Returns:
        list: haystack Documents produced by the reranker.
    """
    import time
    from haystack import Pipeline
    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
    from haystack_integrations.components.embedders.fastembed import (
        FastembedTextEmbedder,
        FastembedSparseTextEmbedder,
    )
    from haystack.components.rankers import TransformersSimilarityRanker

    start_time = time.perf_counter()

    querying = Pipeline()
    querying.add_component("sparse_text_embedder", FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1"))
    querying.add_component("dense_text_embedder", FastembedTextEmbedder(
        model="BAAI/bge-small-en-v1.5", prefix="Represent this sentence for searching relevant passages: ")
    )
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.add_component("ranker", TransformersSimilarityRanker(model="BAAI/bge-reranker-base"))
    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")
    # A DocumentJoiner used to sit here with a single input; it only passed
    # the retriever's documents through, so the retriever feeds the ranker.
    querying.connect("retriever", "ranker")

    results = querying.run(
        {
            # Previously a hard-coded test question was used and `prompt`
            # was ignored.
            "dense_text_embedder": {"text": prompt},
            "sparse_text_embedder": {"text": prompt},
            # The ranker's mandatory `query` input was never supplied, so the
            # pipeline could not run.
            "ranker": {"query": prompt},
        }
    )

    elapsed_time = time.perf_counter() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    # Pipeline.run returns outputs of components whose outputs are not
    # connected onward; the retriever feeds the ranker, so its output was
    # never present in `results` (the old results["retriever"] lookup).
    return results["ranker"]["documents"]
async def download_database():
    """
    Zip the local ``database`` directory and return it as a file download.

    The archive is written to ``<cwd>/database.zip``, overwriting any
    previous one, and served with media type application/zip.

    NOTE(review): the archive is built while the Qdrant store may still hold
    the database open, and concurrent requests write the same output file —
    confirm both are acceptable.

    Returns:
        FileResponse: the archive, offered to the client as ``database.zip``.
    """
    import time

    start_time = time.perf_counter()
    cwd = os.getcwd()
    database_dir = join(cwd, 'database')
    # make_archive appends ".zip" itself and returns the archive's path. The
    # old zip_path.replace('.zip', '') removed *every* '.zip' in the absolute
    # path, so a working directory like '/srv/app.zip.d' produced the wrong
    # base name and the response pointed at a file that was never created.
    zip_path = shutil.make_archive(join(cwd, 'database'), 'zip', database_dir)
    elapsed_time = time.perf_counter() - start_time
    print(f"Execution time: {elapsed_time:.6f} seconds")
    return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
async def create_upload_file(file: UploadFile = File(...)):
    """
    OCR an uploaded PDF with Tesseract and return the recognized text.

    The upload is saved into ``temp_path`` under the final component of its
    client-supplied name. Every page is rendered to an image with pdf2image
    and OCR'd with the 'vie' language model. Each page's text is followed by
    a newline.

    NOTE(review): this name is already used by an earlier function in this
    file; unless each is registered with its own route decorator (none are
    visible here), this definition replaces the earlier one at module level.

    Returns:
        str: Concatenated OCR text, one trailing newline per page.

    Raises:
        ValueError: if the upload has no usable file name.
    """
    from pdf2image import convert_from_path

    # Keep only the last component of the client-controlled file name so
    # values like '../../x' or '/etc/x' cannot escape temp_path.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {file.filename!r}")
    file_savePath = join(temp_path, safe_name)
    with open(file_savePath, 'wb') as f:
        shutil.copyfileobj(file.file, f)

    # convert PDF to images, then OCR each page (pytesseract is imported at
    # module level, so the former local re-import is unnecessary).
    images = convert_from_path(file_savePath)
    return "".join(pytesseract.image_to_string(image, lang='vie') + '\n' for image in images)
def api_home():
    """Return the static welcome payload as a one-key dict."""
    greeting = 'Welcome to FastAPI Qdrant importer!'
    return dict(detail=greeting)