import logging
from typing import Any, Dict, List, Optional, Union  # type: ignore[import-not-found]

import torch
from pydantic import BaseModel, ConfigDict, Field
from transformers import pipeline

from ..hf_pipeline import FeatureExtractionPipelineWithStriding
from .span_embeddings import SpanEmbeddings

logger = logging.getLogger(__name__)

DEFAULT_MODEL_NAME = "allenai/specter2_base"


class HuggingFaceSpanEmbeddings(BaseModel, SpanEmbeddings):
    """An implementation of SpanEmbeddings that uses a modified HuggingFace Transformers
    feature-extraction pipeline, adapted for long text inputs by chunking with an optional
    stride (see src.hf_pipeline.FeatureExtractionPipelineWithStriding).

    Calculating embeddings for multiple spans is efficient when all spans for a text are
    passed in a single call to embed_document_spans: the text embedding is computed only
    once per unique text, and the span embeddings are simply pooled from these text
    embeddings.

    It accepts any model that can be used with the HuggingFace feature-extraction pipeline,
    including models with adapters such as SPECTER2 (see
    https://huggingface.co/allenai/specter2). In that case, the model should be loaded
    beforehand and passed as the parameter 'model' instead of a model identifier. See
    https://huggingface.co/docs/transformers/main_classes/pipelines for further
    information.

    To use, you should have the ``transformers`` python package installed.

    Example:
        .. code-block:: python

            model = "allenai/specter2_base"
            pipeline_kwargs = {'device': 'cpu', 'stride': 64, 'batch_size': 8}
            hf = HuggingFaceSpanEmbeddings(
                model=model,
                pipeline_kwargs=pipeline_kwargs,
            )
            text = "This is a test sentence."
            # calculate embeddings for text[0:4]="This" and text[15:23]="sentence"
            embeddings = hf.embed_document_spans(
                texts=[text, text], starts=[0, 15], ends=[4, 23]
            )
    """

    client: Any = None  #: :meta private:
    model: Optional[Any] = DEFAULT_MODEL_NAME
    """Model name or preloaded model to use."""
    pooling_strategy: str = "mean"
    """Pooling strategy to aggregate the token embeddings of a span ("mean" or "max")."""
    pipeline_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the HuggingFace pipeline constructor."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the pipeline."""
    # show_progress: bool = False
    # """Whether to show a progress bar."""
    model_max_length: Optional[int] = None
    """The maximum input length of the model. Required for some model checkpoints with
    outdated configs."""

    def __init__(self, **kwargs: Any):
        """Initialize the feature-extraction pipeline client."""
        super().__init__(**kwargs)
        self.client = pipeline(
            "feature-extraction",
            model=self.model,
            pipeline_class=FeatureExtractionPipelineWithStriding,
            trust_remote_code=True,
            **self.pipeline_kwargs,
        )
        # The Transformers library is buggy since 4.40.0,
        # see https://github.com/huggingface/transformers/issues/30643,
        # so we need to set the max_length to e.g. 512 manually.
        if self.model_max_length is not None:
            self.client.tokenizer.model_max_length = self.model_max_length
        # Check if the model has a valid max length.
        max_input_size = self.client.tokenizer.model_max_length
        if max_input_size > 1e5:  # a high threshold to catch "unlimited" values
            raise ValueError(
                "The tokenizer does not specify a valid `model_max_length` attribute. "
                "Consider setting it manually by passing `model_max_length` to the "
                "HuggingFaceSpanEmbeddings constructor."
            )

    model_config = ConfigDict(
        extra="forbid",
        protected_namespaces=(),
    )

    def get_span_embedding(
        self,
        last_hidden_state: torch.Tensor,
        offset_mapping: torch.Tensor,
        start: Union[int, List[int]],
        end: Union[int, List[int]],
        **unused_kwargs,
    ) -> Optional[List[float]]:
        """Pool the embeddings of all tokens that fall within the given span(s)."""
        if isinstance(start, int):
            start = [start]
        if isinstance(end, int):
            end = [end]
        if len(start) != len(end):
            raise ValueError("start and end should have the same length.")
        if len(start) == 0:
            raise ValueError("start and end should not be empty.")
        if last_hidden_state.shape[0] != 1:
            raise ValueError("last_hidden_state should have a batch size of 1.")
        if last_hidden_state.shape[0] != offset_mapping.shape[0]:
            raise ValueError(
                "last_hidden_state and offset_mapping should have the same batch size."
            )
        offset_mapping = offset_mapping[0]
        last_hidden_state = last_hidden_state[0]
        # Select all tokens whose character offsets lie completely within any of the spans.
        mask = (start[0] <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= end[0])
        for s, e in zip(start[1:], end[1:]):
            mask = mask | ((s <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= e))
        span_embeddings = last_hidden_state[mask]
        if span_embeddings.shape[0] == 0:
            return None
        if self.pooling_strategy == "mean":
            return span_embeddings.mean(dim=0).tolist()
        elif self.pooling_strategy == "max":
            return span_embeddings.max(dim=0).values.tolist()
        else:
            raise ValueError(f"Unknown pooling strategy: {self.pooling_strategy}")

    def embed_document_spans(
        self,
        texts: List[str],
        starts: Union[List[int], List[List[int]]],
        ends: Union[List[int], List[List[int]]],
    ) -> List[Optional[List[float]]]:
        """Compute document span embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.
            starts: The list of start indices or list of lists of start indices (multi-span).
            ends: The list of end indices or list of lists of end indices (multi-span).

        Returns:
            List of embeddings, one for each text.
        """
        pipeline_kwargs = self.encode_kwargs.copy()
        pipeline_kwargs["return_offset_mapping"] = True
        # We enable long text handling by providing the stride parameter.
        if pipeline_kwargs.get("stride", None) is None:
            pipeline_kwargs["stride"] = 0
        # When stride is positive, we need to create unique embeddings per token.
        if pipeline_kwargs["stride"] > 0:
            pipeline_kwargs["create_unique_embeddings_per_token"] = True
        # We ask for tensors to efficiently compute the span embeddings.
        pipeline_kwargs["return_tensors"] = True
        # Embed each unique text only once and map input indices to the unique texts.
        unique_texts = sorted(set(texts))
        text2unique_idx = {text: idx for idx, text in enumerate(unique_texts)}
        idx2unique_idx = {i: text2unique_idx[text] for i, text in enumerate(texts)}
        pipeline_results = self.client(unique_texts, **pipeline_kwargs)
        embeddings = [
            self.get_span_embedding(
                start=starts[idx], end=ends[idx], **pipeline_results[idx2unique_idx[idx]]
            )
            for idx in range(len(texts))
        ]
        return embeddings

    def embed_query_span(
        self, text: str, start: Union[int, List[int]], end: Union[int, List[int]]
    ) -> Optional[List[float]]:
        """Compute query span embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.
            start: The start index or list of start indices (multi-span).
            end: The end index or list of end indices (multi-span).

        Returns:
            Embeddings for the text.
        """
        starts: Union[List[int], List[List[int]]] = [start]  # type: ignore[assignment]
        ends: Union[List[int], List[List[int]]] = [end]  # type: ignore[assignment]
        return self.embed_document_spans([text], starts=starts, ends=ends)[0]

    @property
    def embedding_dim(self) -> int:
        """Get the embedding dimension."""
        return self.client.model.config.hidden_size
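

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the class): mirrors the docstring example
# above and embeds two character spans of the same text in a single call, so
# the text itself is encoded only once. It assumes the default SPECTER2 base
# checkpoint can be downloaded and that model_max_length=512 is appropriate
# for it; adjust both for other models.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    embedder = HuggingFaceSpanEmbeddings(
        model=DEFAULT_MODEL_NAME,
        pipeline_kwargs={"device": "cpu", "stride": 64, "batch_size": 8},
        model_max_length=512,  # assumed value; only needed if the checkpoint config lacks it
    )
    text = "This is a test sentence."
    # Embed text[0:4]="This" and text[15:23]="sentence".
    spans = embedder.embed_document_spans(
        texts=[text, text], starts=[0, 15], ends=[4, 23]
    )
    logger.info("embedding_dim=%d, num_spans=%d", embedder.embedding_dim, len(spans))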