# New demo setup with langchain retriever (commit 2cc87ec, ArneBinder)
import abc
import logging
from copy import copy
from typing import Iterator, List, Optional, Sequence, Tuple
import pandas as pd
from langchain_core.documents import Document as LCDocument
from langchain_core.stores import BaseStore
from pytorch_ie.documents import TextBasedDocument
from .serializable_store import SerializableStore
logger = logging.getLogger(__name__)
class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
    """Abstract base class for document stores specialized in storing and retrieving pie
    documents.

    Pie documents are carried inside langchain ``LCDocument`` objects so they can flow
    through langchain storage and retrieval components: the pie document is stashed in
    the langchain document's metadata under ``METADATA_KEY_PIE_DOCUMENT`` (see ``wrap`` /
    ``unwrap``). Concrete subclasses must implement the key-value primitives ``mget``,
    ``mset``, ``mdelete`` and ``yield_keys``.
    """

    METADATA_KEY_PIE_DOCUMENT: str = "pie_document"
    """Key for the pie document in the (langchain) document metadata."""

    def wrap(self, pie_document: TextBasedDocument, **metadata) -> LCDocument:
        """Wrap the pie document in an LCDocument.

        Args:
            pie_document: The pie document to wrap. Its ``id`` becomes the LCDocument id.
            **metadata: Additional metadata entries to store alongside the pie document.

        Returns:
            An LCDocument with empty page content whose metadata holds the pie document
            under ``METADATA_KEY_PIE_DOCUMENT`` plus the extra metadata.
        """
        return LCDocument(
            id=pie_document.id,
            # page_content is intentionally empty: the payload lives in the metadata
            page_content="",
            metadata={self.METADATA_KEY_PIE_DOCUMENT: pie_document, **metadata},
        )

    def unwrap(self, document: LCDocument) -> TextBasedDocument:
        """Get the pie document from the langchain document."""
        return document.metadata[self.METADATA_KEY_PIE_DOCUMENT]

    def unwrap_with_metadata(self, document: LCDocument) -> Tuple[TextBasedDocument, dict]:
        """Get the pie document and the remaining metadata from the langchain document.

        Returns:
            A tuple of (pie document, metadata dict without the pie-document entry).
        """
        # shallow-copy so popping the pie document does not mutate the stored document
        metadata = copy(document.metadata)
        pie_document = metadata.pop(self.METADATA_KEY_PIE_DOCUMENT)
        return pie_document, metadata

    @abc.abstractmethod
    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Get the documents for the given keys, in the same order as ``keys``."""
        pass

    @abc.abstractmethod
    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs."""
        pass

    @abc.abstractmethod
    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents for the given keys."""
        pass

    @abc.abstractmethod
    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Iterate over all keys, optionally restricted to those starting with ``prefix``."""
        pass

    def __len__(self) -> int:
        """Number of documents in the store."""
        # count lazily instead of materializing all keys in a throwaway list
        return sum(1 for _ in self.yield_keys())

    def overview(self, layer_captions: dict, use_predictions: bool = False) -> pd.DataFrame:
        """Get an overview of the document store, including the number of items in each layer
        for each document in the store.

        Args:
            layer_captions: A dictionary mapping layer names to captions.
            use_predictions: Whether to also count the predictions of each layer.

        Returns:
            DataFrame: A pandas DataFrame with one row per document ("doc_id" plus one
            "num_<caption>s" column per layer).
        """
        # fetch all documents with a single batched mget call instead of one
        # backend round-trip per key; mget returns values in key order
        keys = list(self.yield_keys())
        rows = []
        for doc_id, document in zip(keys, self.mget(keys)):
            pie_document = self.unwrap(document)
            layers = {
                caption: pie_document[layer_name] for layer_name, caption in layer_captions.items()
            }
            layer_sizes = {
                f"num_{caption}s": len(layer) + (len(layer.predictions) if use_predictions else 0)
                for caption, layer in layers.items()
            }
            rows.append({"doc_id": doc_id, **layer_sizes})
        return pd.DataFrame(rows)

    def as_dict(self, document: LCDocument) -> dict:
        """Convert the langchain document to a dictionary.

        Returns:
            A dict with the serialized pie document under ``METADATA_KEY_PIE_DOCUMENT``
            and the remaining metadata under "metadata".
        """
        pie_document, metadata = self.unwrap_with_metadata(document)
        return {self.METADATA_KEY_PIE_DOCUMENT: pie_document.asdict(), "metadata": metadata}