dammy committed on
Commit
9035153
·
1 Parent(s): 301614f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -2
app.py CHANGED
@@ -1,11 +1,39 @@
1
  import gradio as gr
2
  from langchain.document_loaders import PDFMinerLoader, PyMuPDFLoader
3
  from langchain.text_splitter import CharacterTextSplitter
4
-
 
 
5
 
6
 
7
  import gradio as gr
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  def upload_pdf(file):
10
  # Save the uploaded file
11
  file_name = file.name
@@ -20,7 +48,42 @@ def upload_pdf(file):
20
 
21
  texts = [i.page_content for i in texts]
22
 
23
- return texts[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  iface = gr.Interface(
26
  fn=upload_pdf,
 
1
  import gradio as gr
2
  from langchain.document_loaders import PDFMinerLoader, PyMuPDFLoader
3
  from langchain.text_splitter import CharacterTextSplitter
4
+ import chromadb
5
+ import chromadb.config
6
+ from chromadb.config import Settings
7
 
8
 
9
  import gradio as gr
10
 
11
def get_context(query_text):
    """Retrieve the most relevant indexed document chunk for *query_text*.

    Embeds the query with the module-level SentenceTransformer, queries the
    Chroma collection for the 4 nearest chunks, and returns the single best
    match with its whitespace normalized to single spaces.

    NOTE(review): this reads a global ``collection``, but the visible code
    only creates ``collection`` as a local inside ``upload_pdf`` — confirm it
    is actually in scope here, otherwise this raises NameError.
    """
    query_emb = st_model.encode(query_text)
    query_response = collection.query(query_embeddings=query_emb.tolist(), n_results=4)
    # Only the top-ranked document of the first (only) query is used.
    context = query_response['documents'][0][0]
    # Collapse newlines and runs of whitespace into single spaces.
    # (The original chained .replace(' ', ' '), which was a no-op.)
    context = ' '.join(context.split())
    return context
17
+
18
def local_query(query, context):
    """Answer *query* with the local Flan-T5 model, grounded on *context*.

    Builds a simple instruction prompt, runs greedy generation (at most 20
    new tokens), and returns the decoded output as a list of strings.
    """
    prompt = f"""Using the available context, please answer the question.
    If you aren't sure please say i don't know.
    Context: {context}
    Question: {query}
    """

    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(**encoded, max_new_tokens=20)

    return tokenizer.batch_decode(generated, skip_special_tokens=True)
29
+
30
def run_query(query):
    """Full RAG step: fetch the best context chunk, then answer *query* with it."""
    return local_query(query, get_context(query))
34
+
35
+
36
+
37
  def upload_pdf(file):
38
  # Save the uploaded file
39
  file_name = file.name
 
48
 
49
  texts = [i.page_content for i in texts]
50
 
51
+ doc_emb = st_model.encode(texts)
52
+ doc_emb = doc_emb.tolist()
53
+
54
+ ids = [str(uuid.uuid1()) for _ in doc_emb]
55
+
56
+ client = chromadb.Client()
57
+ # Create collection. get_collection, get_or_create_collection, delete_collection also available!
58
+ collection = client.create_collection("test_db")
59
+
60
+ collection.add(
61
+ embeddings=doc_emb,
62
+ documents=texts,
63
+ ids=ids,
64
+ metadata = ["Page": 1, "Section": "diagnosis/prognosis"]
65
+ )
66
+
67
+ return run_query("how to reduce waste?")
68
+
69
+
70
# --- Module-level model setup ---------------------------------------------
# These globals are read by local_query (model, tokenizer) and by
# get_context / the upload pipeline (st_model).
from transformers import T5ForConditionalGeneration, AutoTokenizer
import torch


# Generator model: Flan-T5 base, sharded automatically across available
# devices; weights that don't fit are offloaded to the "offload" directory.
model_name = 'google/flan-t5-base'

model = T5ForConditionalGeneration.from_pretrained(model_name, device_map='auto', offload_folder="offload")
tokenizer = AutoTokenizer.from_pretrained(model_name)

import uuid
from sentence_transformers import SentenceTransformer

# Embedding model used both to index document chunks and to embed queries,
# so both live in the same vector space.
ST_name = 'sentence-transformers/sentence-t5-base'

st_model = SentenceTransformer(ST_name)
87
 
88
  iface = gr.Interface(
89
  fn=upload_pdf,