IProject-10 committed on
Commit
744ae18
·
verified ·
1 Parent(s): 920b5d6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +125 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from rank_bm25 import BM25Okapi
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForQuestionAnswering
4
+ import torch
5
+ import gradio as gr
6
+ from docx import Document
7
+ import pdfplumber
8
+
9
# Load the fine-tuned question-answering model and its tokenizer.
model_name = "IProject-10/roberta-base-finetuned-squad2"  # Replace with your model name
qa_model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
qa_model.to(device)
qa_model.eval()

# Create a pipeline for retrieval-augmented QA.
# torch.device("cuda") is built without an explicit index, so device.index is
# None; the device ordinal is therefore spelled out here (0 = first GPU,
# -1 = CPU) instead of being derived from device.index.
retrieval_qa_pipeline = pipeline(
    "question-answering",
    model=qa_model,
    tokenizer=tokenizer,
    device=0 if device.type == "cuda" else -1,
)
26
+
27
def extract_text_from_file(file):
    """Return the text content of a .txt, .docx or .pdf file.

    Args:
        file: Either a filesystem path (``str`` / ``os.PathLike``) or an
            object whose ``name`` attribute holds the path of the file.

    Returns:
        The extracted text. If the extension is not supported or reading the
        file fails, the error message is returned as the text instead; no
        exception is raised.
    """
    text = ""

    try:
        # Accept plain paths as well as file wrappers exposing ``.name``.
        # Path objects are checked first because their own ``.name`` is only
        # the final path component, not the full path.
        if isinstance(file, (str, os.PathLike)):
            path = os.fspath(file)
        else:
            path = file.name

        # Determine the file extension
        file_extension = os.path.splitext(path)[1].lower()

        if file_extension == ".txt":
            # Explicit encoding keeps the result independent of the platform
            # default; undecodable bytes are replaced rather than aborting.
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                text = f.read()
        elif file_extension == ".docx":
            doc = Document(path)
            text = "".join(para.text + "\n" for para in doc.paragraphs)
        elif file_extension == ".pdf":
            with pdfplumber.open(path) as pdf:
                # extract_text() may yield None for a page without a text
                # layer; treat such a page as empty instead of failing.
                text = "".join(
                    (page.extract_text() or "") + "\n" for page in pdf.pages
                )
        else:
            raise ValueError("Unsupported file format: {}".format(file_extension))
    except Exception as e:
        # Best effort: report the problem as the returned text.
        text = str(e)
    return text
49
+
50
def load_passages(files):
    """Extract one text passage per uploaded file.

    Args:
        files: Iterable of uploaded files, or ``None`` when nothing was
            uploaded.

    Returns:
        A list with the text extracted from each file, in input order. An
        empty list is returned when ``files`` is ``None`` or empty.
    """
    # ``files or []`` lets a missing upload behave like an empty one.
    return [extract_text_from_file(file) for file in files or []]
56
+
57
def highlight_answer(context, answer):
    """Mark the first occurrence of ``answer`` inside ``context``.

    Args:
        context: The passage text.
        answer: The answer string to look for.

    Returns:
        ``context`` with the first occurrence of ``answer`` wrapped as
        ``_________<<answer>>_________``. ``context`` is returned unchanged
        when ``answer`` is empty or does not occur in it.
    """
    # An empty answer would "match" at index 0 and insert empty markers.
    if not answer:
        return context

    start_index = context.find(answer)
    if start_index == -1:
        return context

    end_index = start_index + len(answer)
    return (
        f"{context[:start_index]}_________<<{context[start_index:end_index]}>>"
        f"_________{context[end_index:]}"
    )
65
+
66
def answer_question(question, files):
    """Answer ``question`` from the content of the uploaded ``files``.

    The files are converted to text passages, ranked against the question
    with BM25, and the QA pipeline is run on the top three passages. The
    answer from the passage with the highest BM25 score is returned; the
    pipeline's own confidence is used only to break BM25 ties.

    Args:
        question: The question text.
        files: Uploaded files (may be ``None`` or empty).

    Returns:
        A tuple ``(answer, highlighted_context, bm25_score)``. On invalid
        input or any error the tuple is ``(message, "", "")``.
    """
    try:
        # Reject inputs that would otherwise surface as obscure errors from
        # BM25 (empty corpus) or the QA pipeline (empty question).
        if not question or not question.strip():
            return "Please enter a question.", "", ""
        if not files:
            return "Please upload at least one file.", "", ""

        # Load passages from the uploaded files, dropping empty ones because
        # the QA pipeline cannot be run on an empty context.
        passages = [p for p in load_passages(files) if p and p.strip()]
        if not passages:
            return "No text could be extracted from the uploaded files.", "", ""

        # Create an index using BM25 (lower-cased so matching is not
        # case-sensitive).
        bm25 = BM25Okapi([passage.lower().split() for passage in passages])

        # Score every passage against the question and keep the best three.
        tokenized_query = question.lower().split()
        bm25_scores = bm25.get_scores(tokenized_query)
        top_indices = sorted(
            range(len(passages)),
            key=lambda i: bm25_scores[i],
            reverse=True,
        )[:3]

        # Extract answer using the pipeline for each candidate passage
        answers_with_context = []
        for i in top_indices:
            passage = passages[i]
            answer = retrieval_qa_pipeline(question=question, context=passage)
            answers_with_context.append(
                {
                    "context": passage,
                    "answer": answer["answer"],
                    "BM25-score": bm25_scores[i],  # BM25 score for this passage
                    "model-score": answer["score"],  # QA pipeline confidence
                }
            )

        # Choose the answer from the passage with the highest BM25 score;
        # equal BM25 scores are resolved by the higher model confidence.
        best_answer = max(
            answers_with_context,
            key=lambda x: (x["BM25-score"], x["model-score"]),
        )

        # Highlight the answer in the context
        highlighted_context = highlight_answer(best_answer["context"], best_answer["answer"])

        return best_answer["answer"], highlighted_context, best_answer["BM25-score"]
    except Exception as e:
        # Report the failure in the answer field of the UI.
        return str(e), "", ""
100
+
101
# Define Gradio interface
# Wires answer_question to the UI: the two inputs map to its (question, files)
# parameters and the three outputs map to its returned
# (answer, highlighted context, BM25 score) tuple, in that order.
iface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question"),
        gr.Files(label="Upload text, Word, or PDF files")
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Textbox(label="Context"),
        gr.Textbox(label="BM25 Score")
    ],
    title="Question Answering Model",
    description="Upload a text document and ask a question from the content",
    # Custom CSS passed to the interface for layout and typography.
    css="""
    .container { max-width: 800px; margin: auto; }
    .interface-title { font-family: Arial, sans-serif; font-size: 24px; font-weight: bold; }
    .interface-description { font-family: Arial, sans-serif; font-size: 16px; margin-bottom: 20px; }
    .input-textbox, .output-textbox { font-family: Arial, sans-serif; font-size: 14px; }
    .error { color: red; font-family: Arial, sans-serif; font-size: 14px; }
    """
)

# Launch the interface
# Runs at import time: executing this module starts the app.
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ torch
3
+ transformers
4
+ gradio
5
+ python-docx
6
+ pdfplumber
7
+ rank-bm25
8
+