8bitnand committed · Commit 871255a
1 Parent: 87d5c64

Added support for streamlit and rag model
Files changed:
- .gitignore +1 -0
- __init__.py +1 -0
- __pycache__/google.cpython-39.pyc +0 -0
- app.py +32 -3
- google.py +10 -7
- model.py +71 -8
- rag.configs.yml +3 -3
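
Taken together, the commit wires a Streamlit chat front end (app.py) to a small RAG pipeline: google.py scrapes and chunks Google results, model.py wraps a Hugging Face causal LM, and rag.configs.yml carries the settings. A minimal end-to-end sketch using only names introduced in this diff (it follows the SemanticSearch signature from google.py, which takes the (chunks, urls) tuple returned by Document.doc(); app.py itself calls it slightly differently):

    from google import GoogleSearch, Document, SemanticSearch
    from model import RAGModel, load_configs

    configs = load_configs(config_file="rag.configs.yml")
    query = "what is LLM"
    g = GoogleSearch(query)  # scrape result pages for the query
    d = Document(g.all_page_data, min_char_len=configs["document"]["min_char_length"])
    s = SemanticSearch(d.doc(), configs["model"]["embeding_model"], configs["model"]["device"])
    topk = s.semantic_search(query=query, k=32)  # best-matching chunks
    r = RAGModel(configs)  # loads google/gemma-2b-it per the config
    print(r.answer_query(query=query, topk_items=topk))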
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/
__init__.py ADDED
@@ -0,0 +1 @@
+from google import GoogleSearch, Document, SemanticSearch
__pycache__/google.cpython-39.pyc DELETED
Binary file (5.39 kB)
app.py CHANGED
@@ -1,10 +1,33 @@
+import sys
 import streamlit as st
+from google import SemanticSearch, GoogleSearch, Document
+from model import RAGModel, load_configs
+
+
+def run_on_start():
+    global r
+    global configs
+    configs = load_configs(config_file="rag.configs.yml")
+    r = RAGModel(configs)
+
+
+def search(query):
+    g = GoogleSearch(query)
+    data = g.all_page_data
+    d = Document(data, min_char_len=configs["document"]["min_char_length"])
+    st.session_state.doc = d.doc()[0]
+
 
 st.title("LLM powred Google search")
 
 if "messages" not in st.session_state:
+    run_on_start()
     st.session_state.messages = []
 
+if "doc" not in st.session_state:
+    st.session_state.doc = None
+
+
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.markdown(message["content"])
@@ -14,10 +37,16 @@ if prompt := st.chat_input("Search Here insetad of Google"):
     st.chat_message("user").markdown(prompt)
     st.session_state.messages.append({"role": "user", "content": prompt})
 
-
-
+    search(prompt)
+    s = SemanticSearch(
+        prompt,
+        st.session_state.doc,
+        configs["model"]["embeding_model"],
+        configs["model"]["device"],
     )
-
+    topk = s.semantic_search(query=prompt, k=32)
+    output = r.answer_query(query=prompt, topk_items=topk)
+    response = output
     with st.chat_message("assistant"):
         st.markdown(response)
 
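One pattern worth flagging in the new app.py: run_on_start() stores the model in module globals, but Streamlit re-executes the whole script on every interaction, so globals created only on the first run are gone on later reruns. A common alternative (a sketch, not what this commit does) is st.cache_resource, which keeps one RAGModel instance alive across reruns:

    import streamlit as st
    from model import RAGModel, load_configs

    @st.cache_resource  # cache the heavyweight model object across script reruns
    def get_model():
        configs = load_configs(config_file="rag.configs.yml")
        return RAGModel(configs), configs

    r, configs = get_model()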
google.py CHANGED
@@ -13,7 +13,7 @@ class GoogleSearch:
         escaped_query = urllib.parse.quote_plus(query)
         self.URL = f"https://www.google.com/search?q={escaped_query}"
 
-        self.headers =
+        self.headers = {
             "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3538.102 Safari/537.36"
         }
         self.links = self.get_initial_links()
@@ -46,7 +46,7 @@ class GoogleSearch:
         """
         scrape google for the query with keyword based search
         """
-
+        print("Searching Google...")
         response = requests.get(self.URL, headers=self.headers)
         soup = BeautifulSoup(response.text, "html.parser")
         anchors = soup.find_all("a", href=True)
@@ -95,6 +95,7 @@ class Document:
         return min_len_chunks
 
     def doc(self) -> tuple[list[str], list[str]]:
+        print("Creating Document...")
         chunked_data: list[str] = []
         urls: list[str] = []
         for url, dataitem in self.data:
@@ -108,16 +109,17 @@ class Document:
 
 class SemanticSearch:
     def __init__(
-        self,
+        self, doc_chunks: tuple[list, list], model_path: str, device: str
     ) -> None:
         query = query
-        self.doc_chunks, self.urls =
+        self.doc_chunks, self.urls = doc_chunks
         self.st = SentenceTransformer(
             model_path,
             device,
         )
 
-    def
+    def semantic_search(self, query: str, k: int = 10):
+        print("Searhing Top k in document...")
         query_embeding = self.get_embeding(query)
         doc_embeding = self.get_embeding(self.doc_chunks)
         scores = util.dot_score(a=query_embeding, b=doc_embeding)[0]
@@ -136,8 +138,9 @@ if __name__ == "__main__":
     g = GoogleSearch(query)
     data = g.all_page_data
     d = Document(data, 333)
-
-
+
+    s = SemanticSearch("all-mpnet-base-v2", "mps")
+    print(len(s.semantic_search(query, k=64)))
 
     # g = GoogleSearch("what is LLM")
     # d = Document(g.all_page_data)
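The semantic_search hunk ends at the util.dot_score call; the top-k selection happens below the visible context. For reference, a sketch of how dot-product scores from sentence-transformers are typically reduced to the k best chunks (the encode and torch.topk calls are assumptions, since get_embeding is not shown in this diff):

    import torch
    from sentence_transformers import SentenceTransformer, util

    st_model = SentenceTransformer("all-mpnet-base-v2", device="cpu")
    chunks = ["LLMs are large language models.", "Bananas are yellow."]
    query_embeding = st_model.encode("what is an LLM", convert_to_tensor=True)
    doc_embeding = st_model.encode(chunks, convert_to_tensor=True)
    scores = util.dot_score(a=query_embeding, b=doc_embeding)[0]  # one score per chunk
    top = torch.topk(scores, k=min(2, len(chunks)))
    best_chunks = [chunks[i] for i in top.indices]  # highest-scoring chunks first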
model.py CHANGED
@@ -1,15 +1,78 @@
-from google import SemanticSearch
-from transformers import AutoTokenizer,
+from google import SemanticSearch, GoogleSearch, Document
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import BitsAndBytesConfig
+from transformers.utils import is_flash_attn_2_available
+import yaml
+import torch
+
+
+def load_configs(config_file: str) -> dict:
+    with open(config_file, "r") as f:
+        configs = yaml.safe_load(f)
+
+    return configs
 
 
 class RAGModel:
     def __init__(self, configs) -> None:
         self.configs = configs
-
-
-
+        self.device = configs["model"]["device"]
+        model_url = configs["model"]["genration_model"]
+        # quantization_config = BitsAndBytesConfig(
+        #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
+        # )
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_url,
+            torch_dtype=torch.float16,
+            # quantization_config=quantization_config,
+            low_cpu_mem_usage=False,
+            attn_implementation="sdpa",
+        ).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_url,
+        )
+
+    def create_prompt(self, query, topk_items: list[str]):
+
+        context = "_ " + "\n-".join(c for c in topk_items)
+
+        base_prompt = f"""Based on the follwing context items, please answer the query.
+        Give time for yourself to read the context and then answer the query.
+        Do not return thinking process, just return the answer.
+        If you do not find the answer, or if the query is offesnsive or in any other way harmfull just return "I'm not aware of it"
+        Now use the following context items to answer the user query.
+        {context}.
+        user query : {query}
+        """
+
+        dialog_template = [{"role": "user", "content": base_prompt}]
+
+        prompt = self.tokenizer.apply_chat_template(
+            conversation=dialog_template, tokenize=False, add_feneration_prompt=True
+        )
+        return prompt
+
+    def answer_query(self, query: str, topk_items: list[str]):
+
+        prompt = self.create_prompt(query, topk_items)
+        print(prompt)
+        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
+        output = self.model.generate(**input_ids, max_new_tokens=512)
+        text = self.tokenizer.decode(output[0])
+
+        return text
+
 
-
+if __name__ == "__main__":
 
-
-
+    configs = load_configs(config_file="rag.configs.yml")
+    query = "what is LLM"
+    # g = GoogleSearch(query)
+    # data = g.all_page_data
+    # d = Document(data, 512)
+    # s = SemanticSearch( "all-mpnet-base-v2", "mps")
+    # topk = s.semantic_search(query=query, k=32)
+    r = RAGModel(configs)
+    output = r.answer_query(query=query, topk_items=[""])
+    print(output)
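Two small notes on the new model.py. First, add_feneration_prompt in the apply_chat_template call looks like a typo for add_generation_prompt; the misspelled keyword is most likely ignored, leaving add_generation_prompt at its default of False, so no generation prompt is appended. Second, the 4-bit path is imported but commented out; enabling it would look roughly like the sketch below (bitsandbytes is CUDA-only, so this conflicts with the mps device set in rag.configs.yml):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
    )
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b-it",
        quantization_config=quantization_config,  # replaces torch_dtype + .to(device)
        low_cpu_mem_usage=True,
    )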
rag.configs.yml CHANGED
@@ -1,8 +1,8 @@
 document:
     min_char_length: 333
 
-
+model:
     embeding_model: all-mpnet-base-v2
-    genration_model:
-    device:
+    genration_model: google/gemma-2b-it
+    device: mps
 
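For reference, load_configs in model.py reads this file with yaml.safe_load, so the keys used across the commit resolve as a plain nested dict:

    from model import load_configs

    configs = load_configs(config_file="rag.configs.yml")
    # {'document': {'min_char_length': 333},
    #  'model': {'embeding_model': 'all-mpnet-base-v2',
    #            'genration_model': 'google/gemma-2b-it',
    #            'device': 'mps'}}
    print(configs["model"]["genration_model"])  # -> google/gemma-2b-it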