Spaces:
Sleeping
Sleeping
abhishekdileep
committed on
Commit
·
8bfdeed
1
Parent(s):
c575b59
testing done on app.py and model.py
Browse files
- app.py +21 -15
- model.py +13 -10
- rag.configs.yml +3 -3
- requirments.txt +76 -1
app.py
CHANGED
@@ -4,28 +4,32 @@ from model import RAGModel, load_configs
|
|
4 |
|
5 |
|
6 |
def run_on_start():
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
def search(query):
|
14 |
g = GoogleSearch(query)
|
15 |
data = g.all_page_data
|
16 |
-
d = Document(data, min_char_len=configs["document"]["min_char_length"])
|
17 |
-
st.session_state.doc = d.doc()
|
18 |
|
19 |
|
20 |
-
st.title("
|
21 |
|
22 |
if "messages" not in st.session_state:
|
23 |
-
run_on_start()
|
24 |
st.session_state.messages = []
|
25 |
|
26 |
if "doc" not in st.session_state:
|
27 |
st.session_state.doc = None
|
28 |
|
|
|
|
|
29 |
|
30 |
for message in st.session_state.messages:
|
31 |
with st.chat_message(message["role"]):
|
@@ -36,15 +40,17 @@ if prompt := st.chat_input("Search Here insetad of Google"):
|
|
36 |
st.chat_message("user").markdown(prompt)
|
37 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
prompt
|
|
|
|
|
42 |
st.session_state.doc,
|
43 |
-
configs["model"]["embeding_model"],
|
44 |
-
configs["model"]["device"],
|
45 |
)
|
46 |
-
topk = s.semantic_search(query=prompt, k=32)
|
47 |
-
output =
|
48 |
response = output
|
49 |
with st.chat_message("assistant"):
|
50 |
st.markdown(response)
|
|
|
4 |
|
5 |
|
6 |
def run_on_start():
|
7 |
+
|
8 |
+
if "configs" not in st.session_state:
|
9 |
+
st.session_state.configs = configs = load_configs(config_file="rag.configs.yml")
|
10 |
+
if "model" not in st.session_state:
|
11 |
+
st.session_state.model = RAGModel(configs)
|
12 |
+
|
13 |
+
run_on_start()
|
14 |
|
15 |
|
16 |
def search(query):
|
17 |
g = GoogleSearch(query)
|
18 |
data = g.all_page_data
|
19 |
+
d = Document(data, min_char_len=st.session_state.configs["document"]["min_char_length"])
|
20 |
+
st.session_state.doc = d.doc()
|
21 |
|
22 |
|
23 |
+
st.title("LLeUUNDd Google search")
|
24 |
|
25 |
if "messages" not in st.session_state:
|
|
|
26 |
st.session_state.messages = []
|
27 |
|
28 |
if "doc" not in st.session_state:
|
29 |
st.session_state.doc = None
|
30 |
|
31 |
+
if "refresh" not in st.session_state:
|
32 |
+
st.session_state.refresh = True
|
33 |
|
34 |
for message in st.session_state.messages:
|
35 |
with st.chat_message(message["role"]):
|
|
|
40 |
st.chat_message("user").markdown(prompt)
|
41 |
st.session_state.messages.append({"role": "user", "content": prompt})
|
42 |
|
43 |
+
if st.session_state.refresh:
|
44 |
+
st.session_state.refresh = False
|
45 |
+
search(prompt)
|
46 |
+
|
47 |
+
s = SemanticSearch(
|
48 |
st.session_state.doc,
|
49 |
+
st.session_state.configs["model"]["embeding_model"],
|
50 |
+
st.session_state.configs["model"]["device"],
|
51 |
)
|
52 |
+
topk, u = s.semantic_search(query=prompt, k=32)
|
53 |
+
output = st.session_state.model.answer_query(query=prompt, topk_items=topk)
|
54 |
response = output
|
55 |
with st.chat_message("assistant"):
|
56 |
st.markdown(response)
|
model.py
CHANGED
@@ -4,7 +4,7 @@ from transformers import BitsAndBytesConfig
|
|
4 |
from transformers.utils import is_flash_attn_2_available
|
5 |
import yaml
|
6 |
import torch
|
7 |
-
|
8 |
|
9 |
def load_configs(config_file: str) -> dict:
|
10 |
with open(config_file, "r") as f:
|
@@ -35,13 +35,16 @@ class RAGModel:
|
|
35 |
|
36 |
def create_prompt(self, query, topk_items: list[str]):
|
37 |
|
38 |
-
context =
|
39 |
|
40 |
-
base_prompt = f"""
|
|
|
|
|
|
|
41 |
Do not return thinking process, just return the answer.
|
42 |
-
|
43 |
-
Now use the following context items to answer the user query
|
44 |
-
context: {context}
|
45 |
user query : {query}
|
46 |
"""
|
47 |
|
@@ -56,16 +59,16 @@ class RAGModel:
|
|
56 |
|
57 |
prompt = self.create_prompt(query, topk_items)
|
58 |
input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
59 |
-
output = self.model.generate(**input_ids, max_new_tokens=512)
|
60 |
text = self.tokenizer.decode(output[0])
|
|
|
61 |
|
62 |
-
return text
|
63 |
|
|
|
64 |
|
65 |
if __name__ == "__main__":
|
66 |
-
|
67 |
configs = load_configs(config_file="rag.configs.yml")
|
68 |
-
query = "
|
69 |
g = GoogleSearch(query)
|
70 |
data = g.all_page_data
|
71 |
d = Document(data, 512)
|
|
|
4 |
from transformers.utils import is_flash_attn_2_available
|
5 |
import yaml
|
6 |
import torch
|
7 |
+
import nltk
|
8 |
|
9 |
def load_configs(config_file: str) -> dict:
|
10 |
with open(config_file, "r") as f:
|
|
|
35 |
|
36 |
def create_prompt(self, query, topk_items: list[str]):
|
37 |
|
38 |
+
context = "\n-".join(c for c in topk_items)
|
39 |
|
40 |
+
base_prompt = f"""You are an alternate to goole search. Your job is to answer the user query in as detailed manner as possible.
|
41 |
+
you have access to the internet and other relevent data related to the user's question.
|
42 |
+
Give time for yourself to read the context and user query and extract relevent data and then answer the query.
|
43 |
+
make sure your answers is as detailed as posssbile.
|
44 |
Do not return thinking process, just return the answer.
|
45 |
+
Give the output structured as a Wikipedia article.
|
46 |
+
Now use the following context items to answer the user query
|
47 |
+
context: {context}
|
48 |
user query : {query}
|
49 |
"""
|
50 |
|
|
|
59 |
|
60 |
prompt = self.create_prompt(query, topk_items)
|
61 |
input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
|
62 |
+
output = self.model.generate(**input_ids, temperature=0.7, max_new_tokens=512, do_sample=True)
|
63 |
text = self.tokenizer.decode(output[0])
|
64 |
+
text = text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "")
|
65 |
|
|
|
66 |
|
67 |
+
return text
|
68 |
|
69 |
if __name__ == "__main__":
|
|
|
70 |
configs = load_configs(config_file="rag.configs.yml")
|
71 |
+
query = "Explain F1 racing for a beginer"
|
72 |
g = GoogleSearch(query)
|
73 |
data = g.all_page_data
|
74 |
d = Document(data, 512)
|
rag.configs.yml
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
document:
|
2 |
-
min_char_length:
|
3 |
|
4 |
model:
|
5 |
embeding_model: all-mpnet-base-v2
|
6 |
-
genration_model: google/gemma-
|
7 |
-
device:
|
8 |
|
|
|
1 |
document:
|
2 |
+
min_char_length: 512
|
3 |
|
4 |
model:
|
5 |
embeding_model: all-mpnet-base-v2
|
6 |
+
genration_model: google/gemma-7b-it
|
7 |
+
device : cuda
|
8 |
|
requirments.txt
CHANGED
@@ -1 +1,76 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
beautifulsoup4==4.12.3
|
3 |
+
accelerate==0.29.2
|
4 |
+
altair==5.3.0
|
5 |
+
attrs==23.2.0
|
6 |
+
beautifulsoup4==4.12.3
|
7 |
+
bitsandbytes==0.42.0
|
8 |
+
blinker==1.7.0
|
9 |
+
Brotli @ file:///Users/runner/miniforge3/conda-bld/brotli-split_1625213545710/work
|
10 |
+
cachetools==5.3.3
|
11 |
+
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
|
12 |
+
cffi @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_b4nang6w_y/croot/cffi_1700254307954/work
|
13 |
+
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
|
14 |
+
click==8.1.7
|
15 |
+
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
|
16 |
+
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1711394622191/work
|
17 |
+
fsspec @ file:///home/conda/feedstock_root/build_artifacts/fsspec_1710808267764/work
|
18 |
+
gitdb==4.0.11
|
19 |
+
GitPython==3.1.43
|
20 |
+
huggingface_hub @ file:///home/conda/feedstock_root/build_artifacts/huggingface_hub_1711986612800/work
|
21 |
+
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1701026962277/work
|
22 |
+
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/importlib-metadata_1710971335535/work
|
23 |
+
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1704966972576/work
|
24 |
+
joblib==1.3.2
|
25 |
+
jsonschema==4.21.1
|
26 |
+
jsonschema-specifications==2023.12.1
|
27 |
+
markdown-it-py==3.0.0
|
28 |
+
MarkupSafe @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_a84ni4pci8/croot/markupsafe_1704206002077/work
|
29 |
+
mdurl==0.1.2
|
30 |
+
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
|
31 |
+
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
|
32 |
+
nltk==3.8.1
|
33 |
+
numpy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_a51i_mbs7m/croot/numpy_and_numpy_base_1708638620867/work/dist/numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl#sha256=829e20a6c33ce51c1a93497d06cb4af22d84caa54a431ea062765da3134e5287
|
34 |
+
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1710075952259/work
|
35 |
+
pandas==2.2.1
|
36 |
+
pillow @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_e02b4k5qik/croot/pillow_1707233036487/work
|
37 |
+
protobuf==4.25.3
|
38 |
+
psutil @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_1310b568-21f4-4cb0-b0e3-2f3d31e39728k9coaga5/croots/recipe/psutil_1656431280844/work
|
39 |
+
pyarrow==15.0.2
|
40 |
+
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1711811537435/work
|
41 |
+
pydeck==0.8.1b0
|
42 |
+
Pygments==2.17.2
|
43 |
+
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
|
44 |
+
python-dateutil==2.9.0.post0
|
45 |
+
pytz==2024.1
|
46 |
+
PyYAML==5.4.1
|
47 |
+
referencing==0.34.0
|
48 |
+
regex==2023.12.25
|
49 |
+
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
|
50 |
+
rich==13.7.1
|
51 |
+
rpds-py==0.18.0
|
52 |
+
safetensors @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_09qdt_s9t7/croot/safetensors_1708633848061/work
|
53 |
+
scikit-learn @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_60ynh176wd/croot/scikit-learn_1694789615217/work
|
54 |
+
scipy @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_41w43uybvr/croot/scipy_1710947318888/work/dist/scipy-1.12.0-cp39-cp39-macosx_11_0_arm64.whl#sha256=73d83606c8528425eb69a034da182c70ebf79b1a85019adc1f5f32a1329c830c
|
55 |
+
sentence-transformers @ file:///home/conda/feedstock_root/build_artifacts/sentence-transformers_1711454085860/work
|
56 |
+
sentencepiece @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_dek8463j1w/croot/sentencepiece_1684523571928/work/python
|
57 |
+
six==1.16.0
|
58 |
+
smmap==5.0.1
|
59 |
+
soupsieve==2.5
|
60 |
+
streamlit==1.33.0
|
61 |
+
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1684180539862/work
|
62 |
+
tenacity==8.2.3
|
63 |
+
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1710943558485/work
|
64 |
+
tokenizers @ file:///private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_77bzam0w9g/croot/tokenizers_1708633828244/work
|
65 |
+
toml==0.10.2
|
66 |
+
toolz==0.12.1
|
67 |
+
torch==2.2.2
|
68 |
+
torchaudio==2.2.2
|
69 |
+
torchvision @ file:///private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_cfzx6ndngz/croot/torchvision_1689077985227/work
|
70 |
+
tornado==6.4
|
71 |
+
tqdm==4.66.2
|
72 |
+
transformers==4.39.3
|
73 |
+
typing_extensions==4.10.0
|
74 |
+
tzdata==2024.1
|
75 |
+
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1708239446578/work
|
76 |
+
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1695255097490/work
|