|
from abc import ABC, abstractmethod |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
from langchain_core.embeddings import Embeddings |
|
from langchain_core.runnables.config import run_in_executor |
|
|
|
|
|
class SpanEmbeddings(Embeddings, ABC):
    """Interface for models that embed text spans within documents.

    A span is described by a start and an end index. Every span-taking method
    accepts either a single start/end per text, or a list of starts/ends per
    text (multi-span).

    Subclasses must implement ``embed_document_spans``, ``embed_query_span``
    and ``embedding_dim``. The whole-text methods (``embed_documents``,
    ``embed_query``) and the asynchronous span methods are derived from them.
    """

    @abstractmethod
    def embed_document_spans(
        self,
        texts: list[str],
        starts: Union[list[int], list[list[int]]],
        ends: Union[list[int], list[list[int]]],
    ) -> list[Optional[list[float]]]:
        """Embed spans of search docs.

        Args:
            texts: List of text to embed.
            starts: List of start indices or list of lists of start indices
                (multi-span).
            ends: List of end indices or list of lists of end indices
                (multi-span).

        Returns:
            List of embeddings. The return type allows an entry to be
            ``None``; when that happens is up to the implementation.
        """

    @abstractmethod
    def embed_query_span(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Embed a span of a query text.

        Args:
            text: Text to embed.
            start: Start index or list of start indices (multi-span).
            end: End index or list of end indices (multi-span).

        Returns:
            Embedding. The return type allows ``None``; when that happens is
            up to the implementation.
        """

    def embed_documents(self, texts: list[str]) -> list[Optional[list[float]]]:
        """Embed search docs.

        Each text is embedded as one span with start ``0`` and end
        ``len(text)``.

        Args:
            texts: List of text to embed.

        Returns:
            List of embeddings, as returned by ``embed_document_spans``.
        """
        starts = [0] * len(texts)
        ends = [len(text) for text in texts]
        return self.embed_document_spans(texts, starts, ends)

    def embed_query(self, text: str) -> Optional[list[float]]:
        """Embed query text.

        The text is embedded as one span with start ``0`` and end
        ``len(text)``.

        Args:
            text: Text to embed.

        Returns:
            Embedding, as returned by ``embed_query_span``.
        """
        return self.embed_query_span(text, 0, len(text))

    async def aembed_document_spans(
        self,
        texts: list[str],
        starts: Union[list[int], list[list[int]]],
        ends: Union[list[int], list[list[int]]],
    ) -> list[Optional[list[float]]]:
        """Asynchronously embed spans of search docs.

        The default implementation hands the synchronous
        ``embed_document_spans`` to ``run_in_executor`` (no explicit executor).

        Args:
            texts: List of text to embed.
            starts: List of start indices or list of lists of start indices
                (multi-span).
            ends: List of end indices or list of lists of end indices
                (multi-span).

        Returns:
            List of embeddings.
        """
        return await run_in_executor(
            None, self.embed_document_spans, texts, starts, ends
        )

    async def aembed_query_span(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Asynchronously embed a span of a query text.

        The default implementation hands the synchronous ``embed_query_span``
        to ``run_in_executor`` (no explicit executor).

        Args:
            text: Text to embed.
            start: Start index or list of start indices (multi-span).
            end: End index or list of end indices (multi-span).

        Returns:
            Embedding.
        """
        return await run_in_executor(None, self.embed_query_span, text, start, end)

    async def aembed_query_spans(
        self, text: str, start: Union[int, list[int]], end: Union[int, list[int]]
    ) -> Optional[list[float]]:
        """Asynchronously embed a span of a query text.

        Alias of ``aembed_query_span``. The plural name does not match the
        synchronous ``embed_query_span``; it is kept so existing callers keep
        working.

        Args:
            text: Text to embed.
            start: Start index or list of start indices (multi-span).
            end: End index or list of end indices (multi-span).

        Returns:
            Embedding.
        """
        return await self.aembed_query_span(text, start, end)

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Get the embedding dimension."""
        ...
|
|