XThomasBU
initial commit
d92c997
raw
history blame
3.49 kB
import json
import os
from typing import Any, List

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from ragatouille import RAGPretrainedModel

from modules.vectorstore.base import VectorStoreBase
class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
    """LangChain retriever over a RAGatouille model that keeps relevance scores.

    Each search hit becomes a ``Document`` whose ``page_content`` is the hit's
    ``"content"`` and whose metadata is the hit's ``"document_metadata"`` (if
    any) plus a ``"score"`` entry holding the hit's ``"score"``.
    """

    # Model to query; only its ``search(query, **kwargs)`` method is used.
    model: Any
    # Extra keyword arguments forwarded verbatim to ``model.search`` (e.g. ``k``).
    kwargs: dict = {}

    @staticmethod
    def _hit_to_document(hit: dict) -> Document:
        """Convert one search hit into a ``Document`` carrying its score."""
        # `or {}` also tolerates hits whose metadata key is present but None.
        metadata = dict(hit.get("document_metadata") or {})
        metadata["score"] = hit["score"]
        return Document(page_content=hit["content"], metadata=metadata)

    def _search_as_documents(self, query: str) -> List[Document]:
        """Run ``model.search`` for ``query`` and convert every hit."""
        hits = self.model.search(query, **self.kwargs)
        return [self._hit_to_document(hit) for hit in hits]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,  # noqa
    ) -> List[Document]:
        """Get documents relevant to a query, with scores in their metadata."""
        return self._search_as_documents(query)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,  # noqa
    ) -> List[Document]:
        """Async variant of ``_get_relevant_documents``.

        ``model.search`` is synchronous and is called directly, so this
        coroutine blocks the event loop for the duration of the search.
        """
        return self._search_as_documents(query)
class RAGPretrainedModel(RAGPretrainedModel):
    """RAGatouille ``RAGPretrainedModel`` with ``len()`` support and scored retrieval.

    This subclass shadows the imported base name, so later references to
    ``RAGPretrainedModel`` in this module resolve to it. The value reported
    by ``len()`` is not read from the index; callers record it explicitly
    through ``set_document_count``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Zero until a caller records the real size.
        self._document_count = 0

    def set_document_count(self, count):
        """Record the size that ``len()`` will report."""
        self._document_count = count

    def __len__(self):
        """Return the size last recorded via ``set_document_count`` (0 by default)."""
        return self._document_count

    def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Wrap this model in a LangChain retriever whose documents carry scores.

        ``kwargs`` are stored on the retriever and forwarded to ``search``
        on every query.
        """
        retriever = RAGatouilleLangChainRetrieverWithScore(kwargs=kwargs, model=self)
        return retriever
class ColbertVectorStore(VectorStoreBase):
    """ColBERT vector store built on RAGatouille.

    Reads ``config["vectorstore"]["db_path"]`` and
    ``config["vectorstore"]["db_option"]``. The index is stored under
    ``<db_path>/db_<db_option>/colbert/indexes/new_idx``.
    """

    # Single index name shared by create_database() and load_database().
    _INDEX_NAME = "new_idx"
    # Pretrained checkpoint used when building a new index.
    _PRETRAINED_MODEL = "colbert-ir/colbertv2.0"

    def __init__(self, config):
        """Store ``config`` and eagerly load the pretrained ColBERT model."""
        self.config = config
        self._init_vector_db()

    def _index_root(self):
        """Return ``<db_path>/db_<db_option>`` as configured (not made absolute)."""
        vs_config = self.config["vectorstore"]
        return os.path.join(vs_config["db_path"], "db_" + vs_config["db_option"])

    def _init_vector_db(self):
        """Load the pretrained checkpoint that create_database() indexes with."""
        self.colbert = RAGPretrainedModel.from_pretrained(
            self._PRETRAINED_MODEL,
            index_root=self._index_root(),
        )

    def create_database(self, documents, document_names, document_metadata):
        """Build the ColBERT index from ``documents``.

        Args:
            documents: Passed to ``index`` as ``collection``.
            document_names: Passed as ``document_ids``. Their count becomes
                the value ``len()`` reports for the indexing model.
            document_metadata: Passed as ``document_metadatas``.
        """
        index_path = self.colbert.index(
            index_name=self._INDEX_NAME,
            collection=documents,
            document_ids=document_names,
            document_metadatas=document_metadata,
        )
        print(f"Index created at {index_path}")
        self.colbert.set_document_count(len(document_names))

    def load_database(self):
        """Load the saved index from disk, store it on ``self.vectorstore`` and return it.

        The index root is resolved against the current working directory.
        Opening ``0.metadata.json`` raises FileNotFoundError if it is missing.
        """
        root = os.path.join(os.getcwd(), self._index_root())
        index_path = os.path.join(root, "colbert", "indexes", self._INDEX_NAME)
        self.vectorstore = RAGPretrainedModel.from_index(index_path)
        # NOTE(review): this counts indexed *passages*, whereas create_database()
        # counts input documents; the two differ if documents were split into
        # several passages. Confirm which one len() is meant to report.
        self.vectorstore.set_document_count(self._count_indexed_passages(index_path))
        return self.vectorstore

    @staticmethod
    def _count_indexed_passages(index_path):
        """Sum ``num_passages`` over the ``<i>.metadata.json`` files in ``index_path``.

        ``0.metadata.json`` is required. Later chunk files (1, 2, ...) are read
        while they exist, so an index written in several chunks is counted in
        full rather than by its first chunk alone.
        """
        total = 0
        chunk_idx = 0
        while True:
            meta_path = os.path.join(index_path, f"{chunk_idx}.metadata.json")
            if chunk_idx > 0 and not os.path.exists(meta_path):
                break
            # Context manager closes the handle (the original leaked it).
            with open(meta_path, encoding="utf-8") as fh:
                total += json.load(fh)["num_passages"]
            chunk_idx += 1
        return total

    def _active_model(self):
        """Return the loaded index model, or the indexing model if none is loaded.

        ``self.vectorstore`` is only set by load_database(). Before that
        (e.g. right after create_database()), fall back to ``self.colbert``.
        An explicit ``is None`` test is used because the model defines
        ``__len__`` and an empty one would be falsy.
        """
        loaded = getattr(self, "vectorstore", None)
        return self.colbert if loaded is None else loaded

    def as_retriever(self, **kwargs):
        """Return a LangChain retriever over the active index.

        The model class exposes ``as_langchain_retriever`` rather than
        ``as_retriever``. ``kwargs`` (e.g. ``k``) are forwarded to ``search``.
        """
        return self._active_model().as_langchain_retriever(**kwargs)

    def __len__(self):
        """Return the document count recorded on the active model."""
        return len(self._active_model())