import logging
from typing import Any, Dict, List, Optional, Union
import torch
from pydantic import BaseModel, ConfigDict, Field
from transformers import pipeline
from ..hf_pipeline import FeatureExtractionPipelineWithStriding
from .span_embeddings import SpanEmbeddings
logger = logging.getLogger(__name__)
DEFAULT_MODEL_NAME = "allenai/specter2_base"
class HuggingFaceSpanEmbeddings(BaseModel, SpanEmbeddings):
"""An implementation of SpanEmbeddings using a modified HuggingFace Transformers
feature-extraction pipeline, adapted for long text inputs by chunking with optional stride
(see src.hf_pipeline.FeatureExtractionPipelineWithStriding).
Note that calculating embeddings for multiple spans is efficient when all spans for a
text are passed in a single call to embed_document_spans, as the text embedding is computed
only once per unique text, and the span embeddings are simply pooled from these text embeddings.
It accepts any model that can be used with the HuggingFace feature-extraction pipeline, also
models with adapters such as SPECTER2 (see https://huggingface.co/allenai/specter2). In this case,
the model should be loaded beforehand and passed as parameter 'model' instead of the model identifier.
See https://huggingface.co/docs/transformers/main_classes/pipelines for further information.
To use, you should have the ``transformers`` python package installed.
Example:
.. code-block:: python
model = "allenai/specter2_base"
pipeline_kwargs = {'device': 'cpu', 'batch_size': 8}
# stride > 0 enables chunked processing of long texts; it must be passed
# via encode_kwargs, as embed_document_spans sets it at call time
encode_kwargs = {'stride': 64}
hf = HuggingFaceSpanEmbeddings(
model=model,
pipeline_kwargs=pipeline_kwargs,
encode_kwargs=encode_kwargs,
)
text = "This is a test sentence."
# calculate embeddings for text[0:4]="This" and text[15:23]="sentence"
embeddings = hf.embed_document_spans(texts=[text, text], starts=[0, 15], ends=[4, 23])
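# a single query span can also be embedded via the convenience wrapper
# (equivalent to embed_document_spans with a one-element batch):
query_embedding = hf.embed_query_span(text=text, start=0, end=4)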
"""
client: Any = None #: :meta private:
model: Optional[Any] = DEFAULT_MODEL_NAME
"""The model name, or a preloaded model, to use."""
pooling_strategy: str = "mean"
"""How to pool the token embeddings within a span, either "mean" or "max"."""
pipeline_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the Huggingface pipeline constructor."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass when calling the pipeline."""
model_max_length: Optional[int] = None
"""The maximum input length of the model. Required for some model checkpoints with outdated configs."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
super().__init__(**kwargs)
self.client = pipeline(
"feature-extraction",
model=self.model,
pipeline_class=FeatureExtractionPipelineWithStriding,
trust_remote_code=True,
**self.pipeline_kwargs,
)
# The Transformers library has a bug since 4.40.0
# (see https://github.com/huggingface/transformers/issues/30643),
# so the max length may need to be set manually, e.g. to 512
if self.model_max_length is not None:
self.client.tokenizer.model_max_length = self.model_max_length
# Check if the model has a valid max length
max_input_size = self.client.tokenizer.model_max_length
if max_input_size > 1e5: # A high threshold to catch "unlimited" values
raise ValueError(
"The tokenizer does not specify a valid `model_max_length` attribute. "
"Consider setting it manually by passing `model_max_length` to the "
"HuggingFaceSpanEmbeddings constructor."
)
model_config = ConfigDict(
extra="forbid",
protected_namespaces=(),
)
def get_span_embedding(
self,
last_hidden_state: torch.Tensor,
offset_mapping: torch.Tensor,
start: Union[int, List[int]],
end: Union[int, List[int]],
**unused_kwargs,
) -> Optional[List[float]]:
"""Pool the span embeddings."""
if isinstance(start, int):
start = [start]
if isinstance(end, int):
end = [end]
if len(start) != len(end):
raise ValueError("start and end should have the same length.")
if len(start) == 0:
raise ValueError("start and end should not be empty.")
if last_hidden_state.shape[0] != 1:
raise ValueError("last_hidden_state should have a batch size of 1.")
if last_hidden_state.shape[0] != offset_mapping.shape[0]:
raise ValueError(
"last_hidden_state and offset_mapping should have the same batch size."
)
offset_mapping = offset_mapping[0]
last_hidden_state = last_hidden_state[0]
mask = (start[0] <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= end[0])
for s, e in zip(start[1:], end[1:]):
mask = mask | ((s <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= e))
span_embeddings = last_hidden_state[mask]
if span_embeddings.shape[0] == 0:
return None
if self.pooling_strategy == "mean":
return span_embeddings.mean(dim=0).tolist()
elif self.pooling_strategy == "max":
return span_embeddings.max(dim=0).values.tolist()
else:
raise ValueError(f"Unknown pool strategy: {self.pooling_strategy}")
def embed_document_spans(
self,
texts: List[str],
starts: Union[List[int], List[List[int]]],
ends: Union[List[int], List[List[int]]],
) -> List[Optional[List[float]]]:
"""Compute doc embeddings using a HuggingFace transformer model.
Args:
texts: The list of texts to embed.
starts: The list of start indices or list of lists of start indices (multi-span).
ends: The list of end indices or list of lists of end indices (multi-span).
Returns:
List of embeddings, one for each text.
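Example (a sketch; assumes ``hf`` is a constructed instance):
.. code-block:: python
text = "This is a test sentence."
# multi-span: pool one embedding over text[0:4]="This" and text[15:23]="sentence"
emb = hf.embed_document_spans(texts=[text], starts=[[0, 15]], ends=[[4, 23]])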
"""
pipeline_kwargs = self.encode_kwargs.copy()
pipeline_kwargs["return_offset_mapping"] = True
# we enable long text handling by providing the stride parameter
if pipeline_kwargs.get("stride", None) is None:
pipeline_kwargs["stride"] = 0
# when stride is positive, we need to create unique embeddings per token
if pipeline_kwargs["stride"] > 0:
pipeline_kwargs["create_unique_embeddings_per_token"] = True
# we ask for tensors to efficiently compute the span embeddings
pipeline_kwargs["return_tensors"] = True
# map each input text to the index of its unique text, so that each
# unique text is embedded only once
unique_texts = sorted(set(texts))
text_to_unique_idx = {text: idx for idx, text in enumerate(unique_texts)}
idx2unique_idx = {i: text_to_unique_idx[text] for i, text in enumerate(texts)}
pipeline_results = self.client(unique_texts, **pipeline_kwargs)
embeddings = [
self.get_span_embedding(
start=starts[idx], end=ends[idx], **pipeline_results[idx2unique_idx[idx]]
)
for idx in range(len(texts))
]
return embeddings
def embed_query_span(
self, text: str, start: Union[int, List[int]], end: Union[int, List[int]]
) -> Optional[List[float]]:
"""Compute query embeddings using a HuggingFace transformer model.
Args:
text: The text to embed.
start: The start index or list of start indices (multi-span).
end: The end index or list of end indices (multi-span).
Returns:
The span embedding for the text, or None if no token falls completely within the span(s).
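Example (a sketch; assumes ``hf`` is a constructed instance):
.. code-block:: python
# multi-span query: one embedding pooled over "This" and "sentence"
emb = hf.embed_query_span("This is a test sentence.", start=[0, 15], end=[4, 23])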
"""
starts: Union[List[int], List[List[int]]] = [start] # type: ignore[assignment]
ends: Union[List[int], List[List[int]]] = [end] # type: ignore[assignment]
return self.embed_document_spans([text], starts=starts, ends=ends)[0]
@property
def embedding_dim(self) -> int:
"""Get the embedding dimension."""
return self.client.model.config.hidden_size
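# A minimal end-to-end sketch (an assumption: the default SPECTER2 base
# checkpoint is used; model_max_length=512 is set manually, as recommended
# above for checkpoints with outdated configs):
if __name__ == "__main__":
    hf = HuggingFaceSpanEmbeddings(
        pipeline_kwargs={"device": "cpu"},
        model_max_length=512,
    )
    text = "This is a test sentence."
    # embed the span text[0:4] == "This"
    emb = hf.embed_query_span(text, start=0, end=4)
    if emb is not None:
        # the span embedding size equals the model's hidden size
        assert len(emb) == hf.embedding_dim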