|
import logging |
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import torch |
|
from pydantic import BaseModel, ConfigDict, Field |
|
from transformers import pipeline |
|
|
|
from ..hf_pipeline import FeatureExtractionPipelineWithStriding |
|
from .span_embeddings import SpanEmbeddings |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
DEFAULT_MODEL_NAME = "allenai/specter2_base" |
|
|
|
|
|
class HuggingFaceSpanEmbeddings(BaseModel, SpanEmbeddings): |
|
"""An implementation of SpanEmbeddings using a modified HuggingFace Transformers |
|
feature-extraction pipeline, adapted for long text inputs by chunking with optional stride |
|
(see src.hf_pipeline.FeatureExtractionPipelineWithStriding). |
|
|
|
    Note that calculating embeddings for multiple spans is efficient when all spans of a text are
    passed in a single call to embed_document_spans: the token embeddings are computed only once
    per unique text, and the span embeddings are simply pooled from them.
|
|
|
    It accepts any model that can be used with the HuggingFace feature-extraction pipeline, including
    models with adapters such as SPECTER2 (see https://huggingface.co/allenai/specter2). In that case,
    load the model beforehand and pass it as the 'model' parameter instead of a model identifier
    (see the second example below).
|
See https://huggingface.co/docs/transformers/main_classes/pipelines for further information. |
|
|
|
To use, you should have the ``transformers`` python package installed. |
|
|
|
Example: |
|
        .. code-block:: python

            model = "allenai/specter2_base"
            pipeline_kwargs = {'device': 'cpu', 'stride': 64, 'batch_size': 8}
            hf = HuggingFaceSpanEmbeddings(
                model=model,
                pipeline_kwargs=pipeline_kwargs,
            )

            text = "This is a test sentence."

            # calculate embeddings for text[0:4]="This" and text[15:23]="sentence"
            embeddings = hf.embed_document_spans(texts=[text, text], starts=[0, 15], ends=[4, 23])
|
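    For adapter models such as SPECTER2, load the model beforehand and pass it in. Below is a
    minimal sketch, assuming the ``adapters`` package is installed; the adapter name follows the
    SPECTER2 model card, and the tokenizer is passed explicitly because the pipeline cannot infer
    it from an already-loaded model instance:

        .. code-block:: python

            from adapters import AutoAdapterModel
            from transformers import AutoTokenizer

            # load the base model plus the SPECTER2 adapter (names taken from the model card)
            model = AutoAdapterModel.from_pretrained("allenai/specter2_base")
            model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)

            hf = HuggingFaceSpanEmbeddings(
                model=model,
                pipeline_kwargs={"tokenizer": AutoTokenizer.from_pretrained("allenai/specter2_base")},
            )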
""" |
|
|
|
    client: Any = None
    """The HuggingFace feature-extraction pipeline instance (created in __init__)."""
    model: Optional[Any] = DEFAULT_MODEL_NAME
    """Model name or pre-loaded model to use."""
    pooling_strategy: str = "mean"
    """Pooling strategy applied to the token embeddings of a span: either 'mean' or 'max'."""
    pipeline_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the HuggingFace pipeline constructor."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass when calling the pipeline."""
    model_max_length: Optional[int] = None
    """The maximum input length of the model. Required for some model checkpoints with outdated configs."""
|
|
|
def __init__(self, **kwargs: Any): |
|
"""Initialize the sentence_transformer.""" |
|
super().__init__(**kwargs) |
|
|
|
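        # Build the feature-extraction pipeline; FeatureExtractionPipelineWithStriding handles
        # long inputs by chunking them with an optional stride.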
self.client = pipeline( |
|
"feature-extraction", |
|
model=self.model, |
|
pipeline_class=FeatureExtractionPipelineWithStriding, |
|
trust_remote_code=True, |
|
**self.pipeline_kwargs, |
|
) |
|
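        # Some checkpoints ship tokenizer configs without a usable model_max_length; allow a manual override.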
if self.model_max_length is not None: |
|
self.client.tokenizer.model_max_length = self.model_max_length |
|
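        # Tokenizers without a configured limit report a huge sentinel value instead of a real maximum.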
max_input_size = self.client.tokenizer.model_max_length |
|
if max_input_size > 1e5: |
|
raise ValueError( |
|
"The tokenizer does not specify a valid `model_max_length` attribute. " |
|
"Consider setting it manually by passing `model_max_length` to the " |
|
"HuggingFaceSpanEmbeddings constructor." |
|
) |
|
|
|
model_config = ConfigDict( |
|
extra="forbid", |
|
protected_namespaces=(), |
|
) |
|
|
|
def get_span_embedding( |
|
self, |
|
last_hidden_state: torch.Tensor, |
|
offset_mapping: torch.Tensor, |
|
start: Union[int, List[int]], |
|
end: Union[int, List[int]], |
|
**unused_kwargs, |
|
) -> Optional[List[float]]: |
|
"""Pool the span embeddings.""" |
|
if isinstance(start, int): |
|
start = [start] |
|
if isinstance(end, int): |
|
end = [end] |
|
if len(start) != len(end): |
|
raise ValueError("start and end should have the same length.") |
|
if len(start) == 0: |
|
raise ValueError("start and end should not be empty.") |
|
if last_hidden_state.shape[0] != 1: |
|
raise ValueError("last_hidden_state should have a batch size of 1.") |
|
if last_hidden_state.shape[0] != offset_mapping.shape[0]: |
|
raise ValueError( |
|
"last_hidden_state and offset_mapping should have the same batch size." |
|
) |
|
offset_mapping = offset_mapping[0] |
|
last_hidden_state = last_hidden_state[0] |
|
|
|
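        # A token belongs to a span if its character offsets fall entirely inside the span boundaries.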
mask = (start[0] <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= end[0]) |
|
for s, e in zip(start[1:], end[1:]): |
|
mask = mask | ((s <= offset_mapping[:, 0]) & (offset_mapping[:, 1] <= e)) |
|
span_embeddings = last_hidden_state[mask] |
|
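        # No token falls inside any of the requested spans, so there is nothing to pool.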
if span_embeddings.shape[0] == 0: |
|
return None |
|
if self.pooling_strategy == "mean": |
|
return span_embeddings.mean(dim=0).tolist() |
|
elif self.pooling_strategy == "max": |
|
return span_embeddings.max(dim=0).values.tolist() |
|
else: |
|
raise ValueError(f"Unknown pool strategy: {self.pooling_strategy}") |
|
|
|
def embed_document_spans( |
|
self, |
|
texts: List[str], |
|
starts: Union[List[int], List[List[int]]], |
|
ends: Union[List[int], List[List[int]]], |
|
) -> List[Optional[List[float]]]: |
|
"""Compute doc embeddings using a HuggingFace transformer model. |
|
|
|
Args: |
|
texts: The list of texts to embed. |
|
starts: The list of start indices or list of lists of start indices (multi-span). |
|
ends: The list of end indices or list of lists of end indices (multi-span). |
|
|
|
Returns: |
|
            List of span embeddings, one per text; an entry is None if no token falls inside the given span(s).
|
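        Example (multi-span, pooling "This" and "sentence" of the text from the class docstring
        into a single embedding):

            text = "This is a test sentence."
            hf.embed_document_spans(texts=[text], starts=[[0, 15]], ends=[[4, 23]])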
""" |
|
pipeline_kwargs = self.encode_kwargs.copy() |
|
pipeline_kwargs["return_offset_mapping"] = True |
|
|
|
if pipeline_kwargs.get("stride", None) is None: |
|
pipeline_kwargs["stride"] = 0 |
|
|
|
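        # With a positive stride, chunks overlap, so request a single (unique) embedding per token.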
if pipeline_kwargs["stride"] > 0: |
|
pipeline_kwargs["create_unique_embeddings_per_token"] = True |
|
|
|
pipeline_kwargs["return_tensors"] = True |
|
|
|
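        # Embed each unique text only once and map every input index back to its pipeline result.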
unique_texts = sorted(set(texts)) |
|
idx2unique_idx = {i: unique_texts.index(text) for i, text in enumerate(texts)} |
|
pipeline_results = self.client(unique_texts, **pipeline_kwargs) |
|
embeddings = [ |
|
self.get_span_embedding( |
|
start=starts[idx], end=ends[idx], **pipeline_results[idx2unique_idx[idx]] |
|
) |
|
for idx in range(len(texts)) |
|
] |
|
return embeddings |
|
|
|
def embed_query_span( |
|
self, text: str, start: Union[int, List[int]], end: Union[int, List[int]] |
|
) -> Optional[List[float]]: |
|
"""Compute query embeddings using a HuggingFace transformer model. |
|
|
|
Args: |
|
text: The text to embed. |
|
start: The start index or list of start indices (multi-span). |
|
end: The end index or list of end indices (multi-span). |
|
|
|
Returns: |
|
            Embedding for the given span(s) of the text, or None if no token falls inside them.
|
""" |
|
starts: Union[List[int], List[List[int]]] = [start] |
|
ends: Union[List[int], List[List[int]]] = [end] |
|
return self.embed_document_spans([text], starts=starts, ends=ends)[0] |
|
|
|
@property |
|
def embedding_dim(self) -> int: |
|
"""Get the embedding dimension.""" |
|
return self.client.model.config.hidden_size |
|
|