8bitnand committed
Commit 174deaa · Parent: aae4036

pulled code from gpus

Files changed (2):
  1. app.py +21 -15
  2. model.py +13 -10
app.py CHANGED
@@ -4,28 +4,32 @@ from model import RAGModel, load_configs
 
 
 def run_on_start():
-    global r
-    global configs
-    configs = load_configs(config_file="rag.configs.yml")
-    r = RAGModel(configs)
+
+    if "configs" not in st.session_state:
+        st.session_state.configs = configs = load_configs(config_file="rag.configs.yml")
+    if "model" not in st.session_state:
+        st.session_state.model = RAGModel(configs)
+
+run_on_start()
 
 
 def search(query):
     g = GoogleSearch(query)
     data = g.all_page_data
-    d = Document(data, min_char_len=configs["document"]["min_char_length"])
-    st.session_state.doc = d.doc()[0]
+    d = Document(data, min_char_len=st.session_state.configs["document"]["min_char_length"])
+    st.session_state.doc = d.doc()
 
 
-st.title("LLM powred Google search")
+st.title("Search Here Instead of Google")
 
 if "messages" not in st.session_state:
-    run_on_start()
     st.session_state.messages = []
 
 if "doc" not in st.session_state:
     st.session_state.doc = None
 
+if "refresh" not in st.session_state:
+    st.session_state.refresh = True
 
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
@@ -36,15 +40,17 @@ if prompt := st.chat_input("Search Here insetad of Google"):
     st.chat_message("user").markdown(prompt)
     st.session_state.messages.append({"role": "user", "content": prompt})
 
-    search(prompt)
-    s, u = SemanticSearch(
-        prompt,
+    if st.session_state.refresh:
+        st.session_state.refresh = False
+        search(prompt)
+
+    s = SemanticSearch(
         st.session_state.doc,
-        configs["model"]["embeding_model"],
-        configs["model"]["device"],
+        st.session_state.configs["model"]["embeding_model"],
+        st.session_state.configs["model"]["device"],
     )
-    topk = s.semantic_search(query=prompt, k=32)
-    output = r.answer_query(query=prompt, topk_items=topk)
+    topk, u = s.semantic_search(query=prompt, k=32)
+    output = st.session_state.model.answer_query(query=prompt, topk_items=topk)
     response = output
     with st.chat_message("assistant"):
         st.markdown(response)
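
Why the session-state change matters: Streamlit re-executes the entire script on every chat interaction, so the old module-level globals re-read rag.configs.yml and rebuilt RAGModel on each message. st.session_state persists across those reruns within one browser session, so the expensive load now happens once. A minimal sketch of the pattern, with a stand-in object in place of RAGModel:

import streamlit as st

# st.session_state survives the rerun Streamlit triggers on every widget
# interaction, so expensive objects are constructed at most once per session.
if "model" not in st.session_state:
    st.session_state.model = object()  # stand-in for RAGModel(configs)

st.write("model id:", id(st.session_state.model))  # stable across reruns

The new refresh flag uses the same mechanism: search(prompt) runs only while the flag is True, so the Google scrape happens once per session rather than on every rerun.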
model.py CHANGED
@@ -4,7 +4,7 @@ from transformers import BitsAndBytesConfig
 from transformers.utils import is_flash_attn_2_available
 import yaml
 import torch
-
+import nltk
 
 def load_configs(config_file: str) -> dict:
     with open(config_file, "r") as f:
@@ -35,13 +35,16 @@ class RAGModel:
 
     def create_prompt(self, query, topk_items: list[str]):
 
-        context = "_ " + "\n-".join(c for c in topk_items)
+        context = "\n-".join(c for c in topk_items)
 
-        base_prompt = f"""Give time for yourself to read the context and then answer the query.
+        base_prompt = f"""You are an alternate to goole search. Your job is to answer the user query in as detailed manner as possible.
+        you have access to the internet and other relevent data related to the user's question.
+        Give time for yourself to read the context and user query and extract relevent data and then answer the query.
+        make sure your answers is as detailed as posssbile.
         Do not return thinking process, just return the answer.
-        If you do not find the answer, or if the query is offesnsive or in any other way harmfull just return "I'm not aware of it"
-        Now use the following context items to answer the user query.
-        context: {context}.
+        Give the output structured as a Wikipedia article.
+        Now use the following context items to answer the user query
+        context: {context}
         user query : {query}
         """
 
@@ -56,16 +59,16 @@ class RAGModel:
 
         prompt = self.create_prompt(query, topk_items)
         input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
-        output = self.model.generate(**input_ids, max_new_tokens=512)
+        output = self.model.generate(**input_ids, temperature=0.7, max_new_tokens=512, do_sample=True)
         text = self.tokenizer.decode(output[0])
+        text = text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "")
 
-        return text
 
+        return text
 
 if __name__ == "__main__":
-
     configs = load_configs(config_file="rag.configs.yml")
-    query = "what is computer vision"
+    query = "Explain F1 racing for a beginer"
     g = GoogleSearch(query)
     data = g.all_page_data
     d = Document(data, 512)
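
Why the decode cleanup is needed: a causal LM's generate() returns the prompt tokens followed by the newly generated ones, so tokenizer.decode(output[0]) reproduces the whole prompt, plus markers like <bos>/<eos>, before the answer. The commit strips these with string replace; below is a sketch of the same cleanup done by slicing token ids instead, which also survives whitespace changes in the tokenizer round-trip. The checkpoint name is an assumption — the real one comes from rag.configs.yml, which this diff does not show.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint; the <bos>/<eos> markers suggest a Gemma-style model.
name = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

prompt = "context: ...\nuser query : what is computer vision\n"
inputs = tokenizer(prompt, return_tensors="pt")
# do_sample=True is what makes temperature take effect; without it,
# generation is greedy and the temperature argument is ignored.
output = model.generate(**inputs, do_sample=True, temperature=0.7, max_new_tokens=64)

# output[0] = prompt ids + generated ids: slice off the prompt instead of
# replacing the decoded string, and let decode drop the special tokens.
new_tokens = output[0][inputs["input_ids"].shape[1]:]
answer = tokenizer.decode(new_tokens, skip_special_tokens=True)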