mitulagr2 commited on
Commit
5c1d000
·
1 Parent(s): 76b90b9
Files changed (1) hide show
  1. app/rag.py +14 -14
app/rag.py CHANGED
@@ -22,35 +22,38 @@ QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
22
 
23
 
24
  class ChatPDF:
25
- text_chunks = []
26
- doc_ids = []
27
- nodes = []
28
  hyde_query_engine = None
 
 
 
29
  logger = None
30
 
31
  def __init__(self):
32
  logging.basicConfig(level=logging.INFO)
33
  self.logger = logging.getLogger(__name__)
34
 
35
- text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)
36
 
37
  self.logger.info("initializing the vector store related objects")
38
  client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
39
- vector_store = QdrantVectorStore(client=client, collection_name="rag_documents")
40
 
41
  self.logger.info("initializing the OllamaEmbedding")
42
- embed_model = OllamaEmbedding(model_name='mxbai-embed-large')
43
  self.logger.info("initializing the global settings")
44
- Settings.embed_model = embed_model
45
  Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
46
- Settings.transformations = [text_parser]
47
 
48
  def ingest(self, dir_path: str):
49
  docs = SimpleDirectoryReader(input_dir=dir_path).load_data()
 
 
 
50
 
51
  self.logger.info("enumerating docs")
52
  for doc_idx, doc in enumerate(docs):
53
- curr_text_chunks = text_parser.split_text(doc.text)
54
  text_chunks.extend(curr_text_chunks)
55
  doc_ids.extend([doc_idx] * len(curr_text_chunks))
56
 
@@ -63,13 +66,13 @@ class ChatPDF:
63
 
64
  self.logger.info("enumerating nodes")
65
  for node in nodes:
66
- node_embedding = embed_model.get_text_embedding(
67
  node.get_content(metadata_mode=MetadataMode.ALL)
68
  )
69
  node.embedding = node_embedding
70
 
71
  self.logger.info("initializing the storage context")
72
- storage_context = StorageContext.from_defaults(vector_store=vector_store)
73
  self.logger.info("indexing the nodes in VectorStoreIndex")
74
  index = VectorStoreIndex(
75
  nodes=nodes,
@@ -99,7 +102,4 @@ class ChatPDF:
99
  return response
100
 
101
  def clear(self):
102
- self.text_chunks = []
103
- self.doc_ids = []
104
- self.nodes = []
105
  self.hyde_query_engine = None
 
22
 
23
 
24
  class ChatPDF:
 
 
 
25
  hyde_query_engine = None
26
+ text_parser = None
27
+ vector_store = None
28
+ embed_model = None
29
  logger = None
30
 
31
  def __init__(self):
32
  logging.basicConfig(level=logging.INFO)
33
  self.logger = logging.getLogger(__name__)
34
 
35
+ self.text_parser = SentenceSplitter(chunk_size=512, chunk_overlap=100)
36
 
37
  self.logger.info("initializing the vector store related objects")
38
  client = QdrantClient(url=QDRANT_API_URL, api_key=QDRANT_API_KEY)
39
+ self.vector_store = QdrantVectorStore(client=client, collection_name="rag_documents")
40
 
41
  self.logger.info("initializing the OllamaEmbedding")
42
+ self.embed_model = OllamaEmbedding(model_name='mxbai-embed-large')
43
  self.logger.info("initializing the global settings")
44
+ Settings.embed_model = self.embed_model
45
  Settings.llm = Ollama(model="qwen:1.8b", request_timeout=1000000)
46
+ Settings.transformations = [self.text_parser]
47
 
48
  def ingest(self, dir_path: str):
49
  docs = SimpleDirectoryReader(input_dir=dir_path).load_data()
50
+ text_chunks = []
51
+ doc_ids = []
52
+ nodes = []
53
 
54
  self.logger.info("enumerating docs")
55
  for doc_idx, doc in enumerate(docs):
56
+ curr_text_chunks = self.text_parser.split_text(doc.text)
57
  text_chunks.extend(curr_text_chunks)
58
  doc_ids.extend([doc_idx] * len(curr_text_chunks))
59
 
 
66
 
67
  self.logger.info("enumerating nodes")
68
  for node in nodes:
69
+ node_embedding = self.embed_model.get_text_embedding(
70
  node.get_content(metadata_mode=MetadataMode.ALL)
71
  )
72
  node.embedding = node_embedding
73
 
74
  self.logger.info("initializing the storage context")
75
+ storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
76
  self.logger.info("indexing the nodes in VectorStoreIndex")
77
  index = VectorStoreIndex(
78
  nodes=nodes,
 
102
  return response
103
 
104
  def clear(self):
 
 
 
105
  self.hyde_query_engine = None