|
import abc |
|
import logging |
|
from copy import copy |
|
from typing import Iterator, List, Optional, Sequence, Tuple |
|
|
|
import pandas as pd |
|
from langchain_core.documents import Document as LCDocument |
|
from langchain_core.stores import BaseStore |
|
from pytorch_ie.documents import TextBasedDocument |
|
|
|
from .serializable_store import SerializableStore |
|
|
|
# Module-level logger, namespaced to this module per the standard logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
class PieDocumentStore(SerializableStore, BaseStore[str, LCDocument], abc.ABC):
    """Abstract base class for document stores specialized in storing and retrieving pie documents.

    Pie documents are transported inside LangChain ``LCDocument`` objects: the pie
    document is stashed in the metadata under :attr:`METADATA_KEY_PIE_DOCUMENT` and the
    page content is left empty, since the pie document itself carries the text.
    Concrete subclasses provide the actual storage backend by implementing
    :meth:`mget`, :meth:`mset`, :meth:`mdelete`, and :meth:`yield_keys`.
    """

    METADATA_KEY_PIE_DOCUMENT: str = "pie_document"
    """Key for the pie document in the (langchain) document metadata."""

    def wrap(self, pie_document: TextBasedDocument, **metadata) -> LCDocument:
        """Wrap the pie document in an LCDocument.

        Args:
            pie_document: The pie document to wrap.
            **metadata: Additional entries to place in the langchain document metadata.
                NOTE(review): an entry named "pie_document" would overwrite the wrapped
                document because ``**metadata`` is expanded last — presumably callers
                never pass that key; verify.

        Returns:
            An LCDocument with the same id, empty page content, and the pie document
            (plus any extra metadata) in its metadata.
        """
        return LCDocument(
            id=pie_document.id,
            page_content="",
            metadata={self.METADATA_KEY_PIE_DOCUMENT: pie_document, **metadata},
        )

    def unwrap(self, document: LCDocument) -> TextBasedDocument:
        """Get the pie document from the langchain document.

        Args:
            document: A langchain document previously produced by :meth:`wrap`.

        Returns:
            The wrapped pie document.

        Raises:
            KeyError: If the document metadata does not contain a pie document.
        """
        return document.metadata[self.METADATA_KEY_PIE_DOCUMENT]

    def unwrap_with_metadata(self, document: LCDocument) -> Tuple[TextBasedDocument, dict]:
        """Get the pie document and the remaining metadata from the langchain document.

        Args:
            document: A langchain document previously produced by :meth:`wrap`.

        Returns:
            A tuple of (pie document, metadata dict without the pie-document entry).
            The input document's metadata is not modified (a shallow copy is popped from).
        """
        # Copy so that popping the pie document does not mutate the caller's document.
        metadata = copy(document.metadata)
        pie_document = metadata.pop(self.METADATA_KEY_PIE_DOCUMENT)
        return pie_document, metadata

    @abc.abstractmethod
    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Get the documents for the given keys, in the same order as ``keys``."""
        pass

    @abc.abstractmethod
    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given (key, document) pairs, overwriting existing entries."""
        pass

    @abc.abstractmethod
    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the entries for the given keys."""
        pass

    @abc.abstractmethod
    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the store, optionally restricted to those with ``prefix``."""
        pass

    def __len__(self) -> int:
        """Return the number of documents in the store."""
        return len(list(self.yield_keys()))

    def overview(self, layer_captions: dict, use_predictions: bool = False) -> pd.DataFrame:
        """Get an overview of the document store, including the number of items in each layer for each document
        in the store.

        Args:
            layer_captions: A dictionary mapping layer names to captions. Captions are used
                to build the result columns as ``num_{caption}s``.
            use_predictions: Whether to also count the predictions of each layer.

        Returns:
            DataFrame: A pandas DataFrame with one row per document ("doc_id" plus one
            count column per layer).
        """
        doc_ids = list(self.yield_keys())
        # Fetch all documents with a single batched call instead of one mget per key.
        documents = self.mget(doc_ids)
        rows = []
        for doc_id, document in zip(doc_ids, documents):
            pie_document = self.unwrap(document)
            row = {"doc_id": doc_id}
            for layer_name, caption in layer_captions.items():
                layer = pie_document[layer_name]
                num = len(layer)
                if use_predictions:
                    num += len(layer.predictions)
                row[f"num_{caption}s"] = num
            rows.append(row)
        return pd.DataFrame(rows)

    def as_dict(self, document: LCDocument) -> dict:
        """Convert the langchain document to a dictionary.

        Args:
            document: A langchain document previously produced by :meth:`wrap`.

        Returns:
            A dict with the serialized pie document under "pie_document" and the
            remaining metadata under "metadata".
        """
        pie_document, metadata = self.unwrap_with_metadata(document)
        return {self.METADATA_KEY_PIE_DOCUMENT: pie_document.asdict(), "metadata": metadata}
|
|