manjunathshiva commited on
Commit
fe0e9f4
·
verified ·
1 Parent(s): 18cc60c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -0
app.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llama_index.core import (
2
+ VectorStoreIndex
3
+ )
4
+ from llama_index.core import Settings
5
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
6
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
7
+ from qdrant_client import QdrantClient
8
+ from typing import Any, List, Tuple
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
11
+ import streamlit as st
12
+ from llama_index.llms.huggingface import (
13
+ HuggingFaceInferenceAPI
14
+ )
15
+ import os
16
+ HUGGINGFACEHUB_API_TOKEN = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
17
+ Q_END_POINT = os.environ.get("Q_END_POINT")
18
+ Q_API_KEY = os.environ.get("Q_API_KEY")
19
+
20
+
21
+ #DOC
22
+ #https://docs.llamaindex.ai/en/stable/examples/vector_stores/qdrant_hybrid.html
23
+
24
+ doc_tokenizer = AutoTokenizer.from_pretrained(
25
+ "naver/efficient-splade-VI-BT-large-doc"
26
+ )
27
+ doc_model = AutoModelForMaskedLM.from_pretrained(
28
+ "naver/efficient-splade-VI-BT-large-doc"
29
+ )
30
+
31
+ query_tokenizer = AutoTokenizer.from_pretrained(
32
+ "naver/efficient-splade-VI-BT-large-query"
33
+ )
34
+ query_model = AutoModelForMaskedLM.from_pretrained(
35
+ "naver/efficient-splade-VI-BT-large-query"
36
+ )
37
+
38
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
39
+
40
+ doc_model = doc_model.to(device)
41
+ query_model = query_model.to(device)
42
+
43
+
44
+ def sparse_doc_vectors(
45
+ texts: List[str],
46
+ ) -> Tuple[List[List[int]], List[List[float]]]:
47
+ """
48
+ Computes vectors from logits and attention mask using ReLU, log, and max operations.
49
+ """
50
+ tokens = doc_tokenizer(
51
+ texts, truncation=True, padding=True, return_tensors="pt"
52
+ )
53
+ if torch.cuda.is_available():
54
+ tokens = tokens.to("cuda:1")
55
+
56
+ output = doc_model(**tokens)
57
+ logits, attention_mask = output.logits, tokens.attention_mask
58
+ relu_log = torch.log(1 + torch.relu(logits))
59
+ weighted_log = relu_log * attention_mask.unsqueeze(-1)
60
+ tvecs, _ = torch.max(weighted_log, dim=1)
61
+
62
+ # extract the vectors that are non-zero and their indices
63
+ indices = []
64
+ vecs = []
65
+ for batch in tvecs:
66
+ indices.append(batch.nonzero(as_tuple=True)[0].tolist())
67
+ vecs.append(batch[indices[-1]].tolist())
68
+
69
+ return indices, vecs
70
+
71
+
72
+ def sparse_query_vectors(
73
+ texts: List[str],
74
+ ) -> Tuple[List[List[int]], List[List[float]]]:
75
+ """
76
+ Computes vectors from logits and attention mask using ReLU, log, and max operations.
77
+ """
78
+ # TODO: compute sparse vectors in batches if max length is exceeded
79
+ tokens = query_tokenizer(
80
+ texts, truncation=True, padding=True, return_tensors="pt"
81
+ )
82
+ if torch.cuda.is_available():
83
+ tokens = tokens.to("cuda:1")
84
+
85
+
86
+ output = query_model(**tokens)
87
+ logits, attention_mask = output.logits, tokens.attention_mask
88
+ relu_log = torch.log(1 + torch.relu(logits))
89
+ weighted_log = relu_log * attention_mask.unsqueeze(-1)
90
+ tvecs, _ = torch.max(weighted_log, dim=1)
91
+
92
+ # extract the vectors that are non-zero and their indices
93
+ indices = []
94
+ vecs = []
95
+ for batch in tvecs:
96
+ indices.append(batch.nonzero(as_tuple=True)[0].tolist())
97
+ vecs.append(batch[indices[-1]].tolist())
98
+
99
+ return indices, vecs
100
+
101
+ st.header("Chat with the Bhagavad Gita docs 💬 📚"")
102
+
103
+ if "messages" not in st.session_state.keys(): # Initialize the chat message history
104
+ st.session_state.messages = [
105
+ {"role": "assistant", "content": "Ask me a question about Gita!"}
106
+ ]
107
+
108
+
109
+ # creates a persistant index to disk
110
+ client = QdrantClient(
111
+ Q_END_POINT,
112
+ api_key=Q_API_KEY,
113
+ )
114
+ # create our vector store with hybrid indexing enabled
115
+ # batch_size controls how many nodes are encoded with sparse vectors at once
116
+ vector_store = QdrantVectorStore(
117
+ "bhagavad_gita", client=client, enable_hybrid=True, batch_size=20,force_disable_check_same_thread=True,
118
+ sparse_doc_fn=sparse_doc_vectors,
119
+ sparse_query_fn=sparse_query_vectors,
120
+ )
121
+
122
+
123
+ llm = HuggingFaceInferenceAPI(
124
+ model_name="mistralai/Mistral-7B-Instruct-v0.2",
125
+ token=HUGGINGFACEHUB_API_TOKEN,
126
+ context_window=8096,
127
+ )
128
+ Settings.llm = llm
129
+ Settings.tokenzier = AutoTokenizer.from_pretrained(
130
+ "mistralai/Mistral-7B-Instruct-v0.2"
131
+ )
132
+
133
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5", device="cpu")
134
+ Settings.embed_model = embed_model
135
+
136
+ index = VectorStoreIndex.from_vector_store(vector_store=vector_store,embed_model=embed_model)
137
+
138
+ from llama_index.core.memory import ChatMemoryBuffer
139
+ memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
140
+
141
+ chat_engine = index.as_chat_engine(chat_mode="condense_question",
142
+ verbose=True,
143
+ memory=memory,
144
+ sparse_top_k=10,
145
+ vector_store_query_mode="hybrid",
146
+ similarity_top_k=3,
147
+ )
148
+
149
+ if prompt := st.chat_input("Your question"): # Prompt for user input and save to chat history
150
+ st.session_state.messages.append({"role": "user", "content": prompt})
151
+
152
+ for message in st.session_state.messages: # Display the prior chat messages
153
+ with st.chat_message(message["role"]):
154
+ st.write(message["content"])
155
+
156
+ # If last message is not from assistant, generate a new response
157
+ if st.session_state.messages[-1]["role"] != "assistant":
158
+ with st.chat_message("assistant"):
159
+ with st.spinner("Thinking..."):
160
+ response = chat_engine.chat(prompt)
161
+ st.write(response.response)
162
+ message = {"role": "assistant", "content": response.response}
163
+ st.session_state.messages.append(message) # Add response to message history
164
+
165
+
166
+
167
+