fix
app/rag.py  +14 -14  CHANGED
@@ -22,35 +22,38 @@ QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
 
 
 class ChatPDF:
-    text_chunks = []
-    doc_ids = []
-    nodes = []
     hyde_query_engine = None
+    text_parser = None
+    vector_store = None
+    embed_model = None
     logger = None
 
     def __init__(self):
         logging.basicConfig(level=logging.INFO)
         self.logger = logging.getLogger(__name__)
 
-        text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)
+        self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)
 
         self.logger.info("initializing the vector store related objects")
         client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
-        vector_store = QdrantVectorStore(client=client, collection_name="rag_documents")
+        self.vector_store = QdrantVectorStore(client=client, collection_name="rag_documents")
 
         self.logger.info("initializing the OllamaEmbedding")
-        embed_model = OllamaEmbedding(model_name='mxbai-embed-large')
+        self.embed_model = OllamaEmbedding(model_name='mxbai-embed-large')
         self.logger.info("initializing the global settings")
-        Settings.embed_model = embed_model
+        Settings.embed_model = self.embed_model
         Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
-        Settings.transformations = [text_parser]
+        Settings.transformations = [self.text_parser]
 
     def ingest(self, dir_path: str):
         docs = SimpleDirectoryReader(input_dir=dir_path).load_data()
+        text_chunks = []
+        doc_ids = []
+        nodes = []
 
         self.logger.info("enumerating docs")
         for doc_idx, doc in enumerate(docs):
-            curr_text_chunks = text_parser.split_text(doc.text)
+            curr_text_chunks = self.text_parser.split_text(doc.text)
             text_chunks.extend(curr_text_chunks)
             doc_ids.extend([doc_idx] * len(curr_text_chunks))
 
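Why this hunk matters: in the old code `text_parser`, `vector_store`, and `embed_model` were plain locals inside `__init__`, so the later references in `ingest` (for example `text_parser.split_text(...)`) would raise `NameError` at runtime, and the class-level `text_chunks` / `doc_ids` / `nodes` lists were shared by every instance and kept growing across calls. The change keeps the shared objects on `self` and makes the per-ingest lists local to the call. A minimal, self-contained sketch of both pitfalls, using hypothetical `Broken` / `Fixed` classes rather than the real ChatPDF:

```python
# Minimal sketch of the two pitfalls this hunk fixes (hypothetical classes,
# no llama_index involved).

class Broken:
    chunks = []                          # class attribute: shared by all instances

    def __init__(self):
        parser = str.split               # local name, discarded when __init__ returns

    def ingest(self, text: str):
        self.chunks.extend(parser(text))  # NameError: 'parser' is not defined


class Fixed:
    def __init__(self):
        self.parser = str.split          # kept on the instance, visible to other methods

    def ingest(self, text: str):
        chunks = []                      # fresh list per call, no state leaks between ingests
        chunks.extend(self.parser(text))
        return chunks


print(Fixed().ingest("a b c"))           # ['a', 'b', 'c']
# Broken().ingest("a b c") would raise NameError before ever touching the shared list.
```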
@@ -63,13 +66,13 @@ class ChatPDF:
 
         self.logger.info("enumerating nodes")
         for node in nodes:
-            node_embedding = embed_model.get_text_embedding(
+            node_embedding = self.embed_model.get_text_embedding(
                 node.get_content(metadata_mode=MetadataMode.ALL)
             )
             node.embedding = node_embedding
 
         self.logger.info("initializing the storage context")
-        storage_context = StorageContext.from_defaults(vector_store=vector_store)
+        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
         self.logger.info("indexing the nodes in VectorStoreIndex")
         index = VectorStoreIndex(
             nodes=nodes,
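With the heavy objects held on the instance, the embedding-and-indexing step in this hunk reads like a small standalone routine. The sketch below reuses only calls already visible in the diff (`get_text_embedding`, `StorageContext.from_defaults`, `VectorStoreIndex`); the import paths are an assumption based on the llama_index >= 0.10 package layout, since the file's import block is outside this diff, and inside the class the two arguments would be `self.embed_model` and `self.vector_store`.

```python
# Sketch of the embed-and-index step; import paths are assumed (llama_index >= 0.10),
# the actual imports in app/rag.py are not part of this diff.
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.schema import MetadataMode, TextNode


def embed_and_index(embed_model, vector_store, nodes: list[TextNode]) -> VectorStoreIndex:
    # Embed each node's full content (metadata included), as the loop in the hunk does.
    for node in nodes:
        node.embedding = embed_model.get_text_embedding(
            node.get_content(metadata_mode=MetadataMode.ALL)
        )

    # Back the index with the Qdrant vector store via a storage context,
    # then build the index over the pre-embedded nodes.
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    return VectorStoreIndex(nodes=nodes, storage_context=storage_context)
```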
@@ -99,7 +102,4 @@ class ChatPDF:
         return response
 
     def clear(self):
-        self.text_chunks = []
-        self.doc_ids = []
-        self.nodes = []
         self.hyde_query_engine = None
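After the refactor, `clear()` only has to drop the cached HyDE query engine, since the chunk and node lists no longer live on the instance. A rough usage sketch of the refactored class; the query method is not visible in this diff, so `ask` below is a hypothetical name:

```python
# Hedged usage sketch of the refactored ChatPDF. ingest() and clear() come from the
# hunks above; the query method is not shown in this diff, so `ask` is hypothetical.
from app.rag import ChatPDF

chat = ChatPDF()                      # wires up Qdrant, Ollama embeddings, and Settings
chat.ingest("./docs")                 # safe to re-run: chunk/node lists are local to the call
answer = chat.ask("What does the document say about pricing?")  # hypothetical method name
print(answer)
chat.clear()                          # now only resets the HyDE query engine
```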