import json
import logging
import os
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from peft import PeftModel
from torch import Tensor, device, nn
from tqdm.autonotebook import tqdm, trange
from transformers import (
    AutoModel,
    AutoConfig,
    PretrainedConfig,
    AutoTokenizer,
    LlamaConfig,
    MistralConfig,
    GemmaConfig,
    Qwen2Config,
)

logger = logging.getLogger(__name__)

def batch_to_device(batch, target_device: device):
    """
    Send a PyTorch batch to a device (CPU/GPU).
    """
    for key in batch:
        if isinstance(batch[key], Tensor):
            batch[key] = batch[key].to(target_device)
    return batch

class LLMEncoder(nn.Module):
    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        pooling_mode: str = "weighted_mean",
        max_length: int = 512,
        doc_max_length: int = 400,
        skip_instruction: bool = True,
    ):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.pooling_mode = pooling_mode
        self.skip_instruction = skip_instruction
        self.max_length = max_length
        self.doc_max_length = doc_max_length
        self.config = model.config

    @classmethod
    def from_pretrained(
        cls,
        base_model_name_or_path,
        peft_model_name_or_path=None,
        cache_dir=None,
        **kwargs,
    ):
        """
        Load a pretrained model from a model identifier or path.

        Args:
            base_model_name_or_path: Model identifier or path to the pretrained model.
            peft_model_name_or_path: Path to an optional PEFT adapter to apply.
            cache_dir: Optional directory in which to cache downloaded weights.

        Returns: the L3Prune (LLMEncoder) model.
        """
        # Pop out encoder-specific args so they are not forwarded to the base model.
        keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
        encoder_args = {
            key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None
        }

        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, cache_dir=cache_dir)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        config = AutoConfig.from_pretrained(base_model_name_or_path)
        model = AutoModel.from_pretrained(base_model_name_or_path, cache_dir=cache_dir, **kwargs)

        if os.path.isdir(base_model_name_or_path) and os.path.exists(
            f"{base_model_name_or_path}/config.json"
        ):
            with open(f"{base_model_name_or_path}/config.json", "r") as fIn:
                config_dict = json.load(fIn)
            config = PretrainedConfig.from_dict(config_dict)
            model.config._name_or_path = config._name_or_path

        if peft_model_name_or_path is not None:
            model = PeftModel.from_pretrained(
                model,
                peft_model_name_or_path,
            )
            model = model.merge_and_unload()

        config = {}
        if os.path.exists(f"{base_model_name_or_path}/l3prune_config.json"):
            with open(f"{base_model_name_or_path}/l3prune_config.json", "r") as fIn:
                l3prune_config = json.load(fIn)
            config.update(l3prune_config)

        for key, value in encoder_args.items():
            config[key] = value

        return cls(model=model, tokenizer=tokenizer, **config)

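    # Illustrative call (the identifiers below are placeholders, not shipped defaults):
    # encoder = LLMEncoder.from_pretrained(
    #     "meta-llama/Meta-Llama-3-8B-Instruct",
    #     peft_model_name_or_path="path/to/adapter",
    #     pooling_mode="mean",
    # )
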
    def prune(self, percent_prune=0):
        """
        Prune the model to a fraction of the base model's layers. If percent_prune is
        greater than or equal to 1, it is interpreted as the exact number of layers to
        keep. For example, if percent_prune=0.3, 30% of the layers are pruned away; if
        percent_prune=3, the model is pruned down to 3 layers.
        """
        # Values >= 1 are taken as the target layer count.
        if percent_prune >= 1:
            new_num_layers = int(percent_prune)
        else:
            new_num_layers = int(self.model.config.num_hidden_layers * (1 - percent_prune))
        print(f"Pruning to {new_num_layers} layers.")
        self.model.layers = self.model.layers[:new_num_layers]
        self.model.config.num_hidden_layers = new_num_layers

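    # Illustrative arithmetic (assuming a 32-layer base model, not a tested
    # configuration): prune(0.25) keeps int(32 * 0.75) = 24 layers, while
    # prune(4) keeps just the first 4 layers.
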
    def prepare_for_tokenization(self, text):
        if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
            text = (
                "<|start_header_id|>user<|end_header_id|>\n\n"
                + text.strip()
                + "<|eot_id|>"
            )
            return text
        if self.model.config._name_or_path in [
            "mistralai/Mistral-7B-Instruct-v0.2",
            "meta-llama/Llama-2-7b-chat-hf",
        ]:
            text = "[INST] " + text.strip() + " [/INST]"
        if self.model.config._name_or_path in [
            "google/gemma-2-9b-it",
        ]:
            text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
        if self.model.config._name_or_path in [
            "Qwen/Qwen2-1.5B-Instruct",
            "Qwen/Qwen2-7B-Instruct",
        ]:
            text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
        if self.pooling_mode == "eos_token":
            if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
                text = text.strip() + "<|end_of_text|>"
            elif isinstance(self.model.config, LlamaConfig) or isinstance(
                self.model.config, MistralConfig
            ):
                text = text.strip() + " </s>"
            elif isinstance(self.model.config, GemmaConfig):
                text = text.strip() + "<eos>"
            elif isinstance(self.model.config, Qwen2Config):
                text = text.strip() + "<|endoftext|>"
        return text

    def tokenize(self, texts):
        texts_2 = []
        original_texts = []
        for text in texts:
            t = text.split("!@#$%^&*()")
            texts_2.append(t[1] if len(t) > 1 else "")
            original_texts.append("".join(t))

        original = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )

        embed_mask = None
        for t_i, t in enumerate(texts_2):
            ids = self.tokenizer(
                [t],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=False,
            )
            if embed_mask is None:
                e_m = torch.zeros_like(original["attention_mask"][t_i])
                if len(ids["input_ids"][0]) > 0:
                    e_m[-len(ids["input_ids"][0]) :] = torch.ones(
                        len(ids["input_ids"][0])
                    )
                embed_mask = e_m.unsqueeze(0)
            else:
                e_m = torch.zeros_like(original["attention_mask"][t_i])
                if len(ids["input_ids"][0]) > 0:
                    e_m[-len(ids["input_ids"][0]) :] = torch.ones(
                        len(ids["input_ids"][0])
                    )
                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)

        original["embed_mask"] = embed_mask
        return original

    def _skip_instruction(self, sentence_feature):
        assert (
            sentence_feature["attention_mask"].shape
            == sentence_feature["embed_mask"].shape
        )
        sentence_feature["attention_mask"] = sentence_feature["embed_mask"]

    def forward(self, sentence_feature: Dict[str, Tensor]):
        embed_mask = None
        if "embed_mask" in sentence_feature:
            embed_mask = sentence_feature.pop("embed_mask")
        reps = self.model(**sentence_feature)
        sentence_feature["embed_mask"] = embed_mask

        return self.get_pooling(sentence_feature, reps.last_hidden_state)

    def get_pooling(self, features, last_hidden_states):  # All models padded from left
        assert (
            self.tokenizer.padding_side == "left"
        ), "Pooling modes are implemented for padding from left."
        if self.skip_instruction:
            self._skip_instruction(features)
        seq_lengths = features["attention_mask"].sum(dim=-1)
        if self.pooling_mode == "mean":
            return torch.stack(
                [
                    last_hidden_states[i, -length:, :].mean(dim=0)
                    for i, length in enumerate(seq_lengths)
                ],
                dim=0,
            )
        elif self.pooling_mode == "weighted_mean":
            bs, l, _ = last_hidden_states.shape
            complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
            for i, seq_l in enumerate(seq_lengths):
                if seq_l > 0:
                    complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1
                    complete_weights[i] /= torch.clamp(
                        complete_weights[i].sum(), min=1e-9
                    )
            return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
        elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
            return last_hidden_states[:, -1]
        elif self.pooling_mode == "bos_token":
            return last_hidden_states[
                features["input_ids"] == self.tokenizer.bos_token_id
            ]
        else:
            raise ValueError(f"{self.pooling_mode} is not implemented yet.")

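    # Illustrative weighted_mean weights: for a left-padded sequence of length 3,
    # the last three positions receive weights [1, 2, 3] / 6, so later tokens
    # contribute more to the pooled embedding.
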
    def _convert_to_str(self, instruction, text):
        tokenized_q = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
            add_special_tokens=False,
        )
        tokenized_q_length = len(tokenized_q["input_ids"][0])

        while tokenized_q_length > self.doc_max_length:
            reduction_ratio = self.doc_max_length / tokenized_q_length
            reduced_length = int(len(text.split()) * reduction_ratio)
            text = " ".join(text.split()[:reduced_length])
            tokenized_q = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
                add_special_tokens=False,
            )
            tokenized_q_length = len(tokenized_q["input_ids"][0])

        return (
            f"{instruction.strip()} !@#$%^&*(){text}"
            if instruction
            else f"!@#$%^&*(){text}"
        )

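    # Illustrative output (the example strings are hypothetical):
    # _convert_to_str("Retrieve relevant passages", "transformer pruning")
    # -> "Retrieve relevant passages !@#$%^&*()transformer pruning"
    # tokenize() later splits on the "!@#$%^&*()" separator so that embed_mask
    # marks only the text portion, which _skip_instruction then uses to drop
    # instruction tokens from pooling.
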
    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        show_progress_bar: bool = True,
        convert_to_numpy: bool = False,
        convert_to_tensor: bool = False,
        device: Optional[str] = None,
    ):
        """
        Encode a list of sentences to their respective embeddings. The sentences can be a list of strings or a single string.

        Args:
            sentences: sentence or sentences to encode.
            batch_size: batch size for turning sentence tokens into embeddings.
            show_progress_bar: whether to show progress bars during encoding steps.
            convert_to_numpy: If true, return numpy arrays instead of torch tensors.
            convert_to_tensor: If true, return torch tensors (default).
            device: torch backend device identifier (e.g., 'cuda', 'cpu', 'mps', etc.). If not specified,
                the default is to use cuda when available, otherwise cpu. Note that only 'cuda' supports
                multiprocessing as currently implemented.

        Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation).
        """
        if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
            sentences = [sentences]
        # required for MEDI version of MTEB
        if isinstance(sentences[0], str):
            sentences = [[""] + [sentence] for sentence in sentences]

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        concatenated_input_texts = []
        for sentence in sentences:
            assert isinstance(sentence[0], str)
            assert isinstance(sentence[1], str)
            concatenated_input_texts.append(
                self._convert_to_str(sentence[0], sentence[1])
            )
        sentences = concatenated_input_texts

        self.eval()

        if convert_to_tensor:
            convert_to_numpy = False

        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
        all_embeddings = []

        if torch.cuda.device_count() <= 1:
            # This branch also supports MPS devices
            self.to(device)
            for start_index in trange(
                0,
                len(sentences),
                batch_size,
                desc="Batches",
                disable=not show_progress_bar,
            ):
                sentences_batch = sentences_sorted[
                    start_index : start_index + batch_size
                ]
                embeddings = self._encode(
                    sentences_batch, device=device, convert_to_numpy=convert_to_numpy
                )
                all_embeddings.append(embeddings)
        else:
            num_proc = torch.cuda.device_count()
            cuda_compatible_multiprocess = mp.get_context("spawn")
            with cuda_compatible_multiprocess.Pool(num_proc) as p:
                sentences_batches = [
                    sentences_sorted[start_index : start_index + batch_size]
                    for start_index in range(0, len(sentences), batch_size)
                ]

                progress_bar = tqdm(
                    total=len(sentences_batches),
                    desc="Batches",
                    disable=not show_progress_bar,
                )
                results = []

                def update(*args):
                    progress_bar.update()

                for batch in sentences_batches:
                    results.append(
                        p.apply_async(
                            self._encode,
                            args=(batch, None, convert_to_numpy, True),
                            callback=update,
                        )
                    )

                all_embeddings = [result.get() for result in results]
                progress_bar.close()

        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
        all_embeddings = all_embeddings.to(torch.float32)
        if convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        return all_embeddings

    def save(self, output_path, merge_before_save=False, save_config=True):
        if merge_before_save and isinstance(self.model, PeftModel):
            self.model = self.model.merge_and_unload()
            if hasattr(self.model, "_hf_peft_config_loaded"):
                self.model._hf_peft_config_loaded = False

        self.model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        l3prune_config = {
            "pooling_mode": self.pooling_mode,
            "max_length": self.max_length,
            "doc_max_length": self.doc_max_length,
            "skip_instruction": self.skip_instruction,
        }
        if save_config:
            os.makedirs(output_path, exist_ok=True)
            with open(f"{output_path}/l3prune_config.json", "w") as fOut:
                json.dump(l3prune_config, fOut, indent=4)

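    # Illustrative call (the output path is a placeholder): encoder.save("./out")
    # writes the model weights, tokenizer files, and l3prune_config.json into ./out.
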
    def _encode(
        self,
        sentences_batch,
        device: Optional[str] = None,
        convert_to_numpy: bool = False,
        multiprocessing=False,
    ):
        if multiprocessing:
            # multiprocessing only supports CUDA devices at this time, so we ignore
            # the value of device and use cuda:rank for the device
            rank = mp.current_process()._identity[0]
            if device is None and torch.cuda.is_available():
                device = f"cuda:{rank % torch.cuda.device_count()}"

        self.to(device)
        features = self.tokenize(
            [self.prepare_for_tokenization(sentence) for sentence in sentences_batch]
        )
        features = batch_to_device(features, device)

        with torch.no_grad():
            embeddings = self.forward(features)
            embeddings = embeddings.detach()
            embeddings = embeddings.cpu()

        return embeddings

    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Helper function to get the length of the input text. Text can be either a string (a single text),
        a list of ints (a single tokenized text), or a tuple of lists of ints
        (representing several text inputs to the model).
        """
        if (
            isinstance(text, str)
            or (isinstance(text, list) and isinstance(text[0], int))
            or len(text) == 0
        ):  # Single text, list of ints, or empty
            return len(text)
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, "__len__"):  # Object has no len() method
            return 1
        else:
            return sum([len(t) for t in text])

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        return self.model.resize_token_embeddings(
            new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
        )

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
        )
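

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The checkpoint id below is a placeholder; any causal LM loadable through
# AutoModel should work, and prune(8) simply keeps the first 8 layers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    encoder = LLMEncoder.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model id
        pooling_mode="weighted_mean",
        max_length=512,
    )
    encoder.prune(8)  # keep only the first 8 transformer layers
    embeddings = encoder.encode(
        ["A first example sentence.", "A second, slightly longer example sentence."],
        batch_size=2,
        show_progress_bar=False,
    )
    print(embeddings.shape)  # -> (2, hidden_size)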