from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from langchain_core.embeddings import Embeddings  # type: ignore[import-not-found]
from langchain_core.runnables.config import run_in_executor


class SpanEmbeddings(Embeddings, ABC):
    """Interface for models that embed text spans within documents.

    Subclasses implement :meth:`embed_document_spans`,
    :meth:`embed_query_span` and the :attr:`embedding_dim` property. The
    standard ``Embeddings`` entry points (:meth:`embed_documents`,
    :meth:`embed_query`) are derived from them by embedding the span that
    starts at ``0`` and ends at ``len(text)``, i.e. the whole text.
    """

    @abstractmethod
    def embed_document_spans(
        self,
        texts: list[str],
        starts: Union[list[int], list[list[int]]],
        ends: Union[list[int], list[list[int]]],
    ) -> list[Optional[list[float]]]:
        """Embed a span (or several spans) of each document.

        Args:
            texts: Documents to embed.
            starts: One start index per document, or one list of start
                indices per document (multi-span). ``starts[i]`` refers to
                ``texts[i]``.
            ends: One end index per document, or one list of end indices per
                document (multi-span), parallel to ``starts``. The whole-text
                default in :meth:`embed_documents` passes ``len(text)``, so
                ends are presumably exclusive -- confirm in implementations.

        Returns:
            One embedding per document. Entries may be ``None``, as permitted
            by the return annotation.
        """

    @abstractmethod
    def embed_query_span(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Embed a span (or several spans) of a query text.

        Args:
            text: Query text to embed.
            start: Start index, or list of start indices (multi-span).
            end: End index, or list of end indices (multi-span), parallel to
                ``start``.

        Returns:
            The embedding, or ``None`` as permitted by the return annotation.
        """

    def embed_documents(self, texts: list[str]) -> list[Optional[list[float]]]:
        """Embed whole documents.

        Delegates to :meth:`embed_document_spans` with a single span
        ``[0, len(text))`` per document.

        Args:
            texts: Documents to embed.

        Returns:
            One embedding (or ``None``) per document.
        """
        return self.embed_document_spans(
            texts, [0] * len(texts), [len(text) for text in texts]
        )

    def embed_query(self, text: str) -> Optional[list[float]]:
        """Embed a whole query text.

        Delegates to :meth:`embed_query_span` with the span ``[0, len(text))``.

        Args:
            text: Query text to embed.

        Returns:
            The embedding, or ``None``.
        """
        return self.embed_query_span(text, 0, len(text))

    async def aembed_document_spans(
        self,
        texts: list[str],
        starts: Union[list[int], list[list[int]]],
        ends: Union[list[int], list[list[int]]],
    ) -> list[Optional[list[float]]]:
        """Asynchronously embed a span (or several spans) of each document.

        The default implementation runs the synchronous
        :meth:`embed_document_spans` via ``run_in_executor``. Subclasses with
        a native async backend may override it.

        Args:
            texts: Documents to embed.
            starts: One start index per document, or one list of start
                indices per document (multi-span).
            ends: One end index per document, or one list of end indices per
                document (multi-span), parallel to ``starts``.

        Returns:
            One embedding (or ``None``) per document.
        """
        return await run_in_executor(
            None, self.embed_document_spans, texts, starts, ends
        )

    async def aembed_query_span(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Asynchronously embed a span (or several spans) of a query text.

        This is the async counterpart of :meth:`embed_query_span`, named to
        match it. The default implementation runs the synchronous method via
        ``run_in_executor``.

        Args:
            text: Query text to embed.
            start: Start index, or list of start indices (multi-span).
            end: End index, or list of end indices (multi-span).

        Returns:
            The embedding, or ``None``.
        """
        return await run_in_executor(None, self.embed_query_span, text, start, end)

    async def aembed_query_spans(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Backward-compatible alias of :meth:`aembed_query_span`.

        The original async method was published under this plural name,
        which does not match the sync :meth:`embed_query_span`. It is kept so
        existing callers continue to work.

        Args:
            text: Query text to embed.
            start: Start index, or list of start indices (multi-span).
            end: End index, or list of end indices (multi-span).

        Returns:
            The embedding, or ``None``.
        """
        return await self.aembed_query_span(text, start, end)

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Get the embedding dimension."""
        ...