dammy committed
Commit a01ca04 · 1 Parent(s): 7a3625d

Update app.py

Files changed (1):
  1. app.py +64 -36
app.py CHANGED
@@ -14,22 +14,27 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import transformers
 import torch
 
-
+# load the model
 model_name = 'google/flan-t5-base'
-model = T5ForConditionalGeneration.from_pretrained(model_name, device_map='auto', offload_folder="offload")
+model = T5ForConditionalGeneration.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
-
+# to calculate text embeddings
 ST_name = 'sentence-transformers/sentence-t5-base'
 st_model = SentenceTransformer(ST_name)
-print('sentence read')
 
+# to store our embeddings and search
 client = chromadb.Client()
-collection = client.create_collection("test_db")
+collection = client.create_collection("my_db")
 
 
 def get_context(query_text):
+    '''
+    Given the query text, find its embedding
+    Search in Chroma DB
+    and return the results
+    '''
 
     query_emb = st_model.encode(query_text)
     query_response = collection.query(query_embeddings=query_emb.tolist(), n_results=4)
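
The setup in this hunk is self-contained and can be smoke-tested outside the app. A minimal sketch, using the model names from the diff (the embedding round-trip at the end is illustrative, not part of app.py):

    from transformers import T5ForConditionalGeneration, AutoTokenizer
    from sentence_transformers import SentenceTransformer
    import chromadb

    # same models as in the commit
    model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')
    tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')
    st_model = SentenceTransformer('sentence-transformers/sentence-t5-base')

    client = chromadb.Client()                      # in-memory Chroma instance
    collection = client.create_collection("my_db")  # same name as in the new version

    emb = st_model.encode("hello world")  # sentence-t5-base -> 768-dim vector
    print(emb.shape)                      # (768,)

Dropping device_map='auto' and the offload folder means the model now loads fully on the default device, which is the simpler choice for a model as small as flan-t5-base.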
@@ -37,13 +42,23 @@ def get_context(query_text):
     context = context.replace('\n', ' ').replace('  ', ' ')
     return context
 
+
+
 def local_query(query, context):
-    t5query = """Using the available context, please answer the question.
-    If you aren't sure please say i don't know.
+    '''
+    Given query (user response)
+    Construct LLM query adding context to it
+    Return response of LLM
+    '''
+
+    t5query = """Please answer the question based on the given context.
+    If you are not sure about your response, say I am not sure.
     Context: {}
     Question: {}
     """.format(context, query)
 
+    # tokenize the constructed query
     inputs = tokenizer(t5query, return_tensors="pt")
 
     outputs = model.generate(**inputs, max_new_tokens=20)
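
The decode-and-return tail of local_query falls outside the visible hunks, so it is not shown above. Assuming the usual transformers pattern (batch_decode is a guess, consistent with result[0] being indexed later in run_query), the full function would look roughly like this:

    def local_query(query, context):
        # build the prompt with the retrieved context (wording as in the new version)
        t5query = """Please answer the question based on the given context.
        If you are not sure about your response, say I am not sure.
        Context: {}
        Question: {}
        """.format(context, query)

        inputs = tokenizer(t5query, return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=20)
        # the actual return lives in the elided lines; batch_decode yields a
        # list of strings, which matches the result[0] access in run_query
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)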
@@ -55,48 +70,59 @@ def local_query(query, context):
 
 
 
-def run_query(btn, history, query):
+def run_query(history, query):
+    '''
+    Run Gradio ChatInterface
+    Given user response (query), find the most similar/related part to the question from the uploaded document
+    Using Chroma search
+    Update the query with context, and ask the question to LLM
+    '''
 
-    context = get_context(query)
-
-    print('calling local query')
-    result = local_query(query, context)
-
-
-    print('printing result after call back')
-    print(result)
+    context = get_context(query)  # find the related part from the pdf
+    result = local_query(query, context)  # add context to model query
 
-    history.append((query, str(result[0])))
+    history.append((query, str(result[0])))  # append result to chatInterface history
 
 
-    print('printing history')
-    print(history)
     return history, ""
 
 
 
 def upload_pdf(file):
+    '''
+    Upload a PDF
+    Split into chunks
+    Encode each chunk into embeddings
+    Assign a unique ID for each chunk embedding
+    Construct Chroma DB
+    Update your global Chroma DB collection
+    '''
     try:
         if file is not None:
 
             global collection
 
             file_name = file.name
-
+
+            # Upload pdf document
             loader = PDFMinerLoader(file_name)
             doc = loader.load()
-
-            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+
+            # extract chunks
+            text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
             texts = text_splitter.split_documents(doc)
 
             texts = [i.page_content for i in texts]
-
+
+            # find embedding for each chunk
             doc_emb = st_model.encode(texts)
             doc_emb = doc_emb.tolist()
-
+
+            # index the embeddings
             ids = [str(uuid.uuid1()) for _ in doc_emb]
 
-
+            # add each chunk embedding to ChromaDB
             collection.add(
                 embeddings=doc_emb,
                 documents=texts,
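
For reference, the indexing flow in upload_pdf condenses to a few lines. A sketch assuming the langchain imports used at the top of app.py (PDFMinerLoader, CharacterTextSplitter) and an ids= argument in the part of the collection.add() call that the hunk cuts off; index_pdf is a hypothetical helper, not a function in the commit:

    import uuid

    def index_pdf(file_name):  # hypothetical helper mirroring upload_pdf
        doc = PDFMinerLoader(file_name).load()
        splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
        texts = [t.page_content for t in splitter.split_documents(doc)]
        doc_emb = st_model.encode(texts).tolist()   # one embedding per chunk
        ids = [str(uuid.uuid1()) for _ in doc_emb]  # unique id per chunk
        collection.add(embeddings=doc_emb, documents=texts, ids=ids)

The chunk_overlap change from 0 to 10 gives consecutive chunks a small shared margin, so text cut at a chunk boundary still appears intact in one of the two chunks.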
@@ -116,26 +142,28 @@ def upload_pdf(file):
 
 
 with gr.Blocks() as demo:
-
+    '''
+    Frontend for our tool
+    '''
+
+    # Upload a PDF document
     btn = gr.UploadButton("Upload a PDF", file_types=[".pdf"])
-    output = gr.Textbox(label="Output Box")
-    chatbot = gr.Chatbot(height=240)
+    output = gr.Textbox(label="Output Box")  # to put a message indicating the status of the upload
+    chatbot = gr.Chatbot(height=240)  # our chatbot interface
 
     with gr.Row():
        with gr.Column(scale=0.70):
            txt = gr.Textbox(
                show_label=False,
-               placeholder="Enter a question",
+               placeholder="Type a question",
            )
 
-
-    # Event handler for uploading a PDF
+
+    # Backend for our tool
+    # Event handlers
     btn.upload(fn=upload_pdf, inputs=[btn], outputs=[output])
-    txt.submit(run_query, [btn, chatbot, txt], [chatbot, txt])
-    #.then(
-    #    generate_response, inputs =[chatbot,],outputs = chatbot,)
+    txt.submit(run_query, [chatbot, txt], [chatbot, txt])
 
 
 gr.close_all()
-# demo.launch(share=True)
-demo.queue().launch()
+demo.queue().launch()  # use queue for better performance
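
One note on the wiring fix in this hunk: Gradio passes one positional argument per component in the inputs list, so the old txt.submit(run_query, [btn, chatbot, txt], ...) fed the upload button into run_query as its first argument. With the new run_query(history, query), the inputs must be exactly [chatbot, txt], and the two return values (history, "") map onto the [chatbot, txt] outputs, the empty string clearing the textbox. A minimal sketch of the corrected wiring:

    with gr.Blocks() as demo:
        btn = gr.UploadButton("Upload a PDF", file_types=[".pdf"])
        output = gr.Textbox(label="Output Box")
        chatbot = gr.Chatbot(height=240)
        txt = gr.Textbox(show_label=False, placeholder="Type a question")

        btn.upload(fn=upload_pdf, inputs=[btn], outputs=[output])
        # two input components -> run_query(history, query)
        # two outputs <- (history, "")
        txt.submit(run_query, [chatbot, txt], [chatbot, txt])

    demo.queue().launch()

Also note that the triple-quoted 'Frontend for our tool' string under gr.Blocks() is a bare string expression inside the with block: harmless, but it is not attached to anything the way a function docstring is.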
 