rtabrizi commited on
Commit
cbd01e9
·
1 Parent(s): 7e4f428

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -67,7 +67,7 @@ class Retriever:
67
  def load_chunks(self):
68
  self.text = self.extract_text_from_pdf(self.file_path)
69
  text_splitter = RecursiveCharacterTextSplitter(
70
- chunk_size=300,
71
  chunk_overlap=20,
72
  length_function=self.token_len,
73
  separators=["Section", "\n\n", "\n", ".", " ", ""]
@@ -86,7 +86,7 @@ class Retriever:
86
  self.index.add(self.token_embeddings)
87
 
88
  def retrieve_top_k(self, query_prompt, k=10):
89
- encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", max_length=300, truncation=True, padding=True).to(device)
90
 
91
  with torch.no_grad():
92
  model_output = self.question_model(**encoded_query)
@@ -99,6 +99,7 @@ class Retriever:
99
 
100
  return retrieved_texts
101
 
 
102
  class RAG:
103
  def __init__(self,
104
  file_path,
@@ -134,7 +135,7 @@ class RAG:
134
  return answer
135
 
136
  def extractive_query(self, question):
137
- context = self.retriever.retrieve_top_k(question, k=4)
138
 
139
  inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=150, padding="max_length")
140
  with torch.no_grad():
 
67
  def load_chunks(self):
68
  self.text = self.extract_text_from_pdf(self.file_path)
69
  text_splitter = RecursiveCharacterTextSplitter(
70
+ chunk_size=150,
71
  chunk_overlap=20,
72
  length_function=self.token_len,
73
  separators=["Section", "\n\n", "\n", ".", " ", ""]
 
86
  self.index.add(self.token_embeddings)
87
 
88
  def retrieve_top_k(self, query_prompt, k=10):
89
+ encoded_query = self.question_tokenizer(query_prompt, return_tensors="pt", truncation=True, padding=True).to(device)
90
 
91
  with torch.no_grad():
92
  model_output = self.question_model(**encoded_query)
 
99
 
100
  return retrieved_texts
101
 
102
+
103
  class RAG:
104
  def __init__(self,
105
  file_path,
 
135
  return answer
136
 
137
  def extractive_query(self, question):
138
+ context = self.retriever.retrieve_top_k(question, k=7)
139
 
140
  inputs = self.generator_tokenizer(question, ". ".join(context), return_tensors="pt", truncation=True, max_length=150, padding="max_length")
141
  with torch.no_grad():