Inspired by sentosa/ZNV-Embedding: a prompt-engineering approach that aggregates 'title' information into embeddings (with modifications). To do:
- Re-train the dense layers.
- Re-define a more effective concatenation.
- Adopt AnglE to fine-tune TinyLlama.
- Improve the loss function (see the sketch below).
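
As a starting point for the loss-function and AnglE items, here is a minimal sketch of the CoSENT-style cosine ranking objective that AnglE builds on. The angle-difference term and in-batch negatives are omitted, and the names (`cosent_loss`, `scale`) are illustrative, not part of this repository.

```python
import torch
import torch.nn.functional as F

def cosent_loss(emb_a, emb_b, labels, scale=20.0):
    """CoSENT-style ranking loss: log(1 + sum exp(scale * (cos_j - cos_i)))
    over all pairs (i, j) where pair i has a higher gold similarity than pair j.

    emb_a, emb_b: (N, D) embeddings for N text pairs.
    labels:       (N,) gold similarity scores; only their ordering matters.
    """
    cos = F.cosine_similarity(emb_a, emb_b, dim=-1) * scale  # (N,)
    diff = cos[None, :] - cos[:, None]                       # diff[i, j] = cos_j - cos_i
    diff = diff[labels[:, None] > labels[None, :]]           # keep violating directions only
    zero = torch.zeros(1, device=diff.device)                # prepended 0 gives the "1 +" term
    return torch.logsumexp(torch.cat([zero, diff]), dim=0)
```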
To run the TE_Embedding model:

```python
import os

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

class TEmbeddingModel(torch.nn.Module):
    def __init__(self, model_name_or_path):
        super(TEmbeddingModel, self).__init__()
        self.prompt_prefix = "Reading the below text and answer questions:\n"
        self.prompt_suffixes = ["\n1.One word to summarize the above text:",
                                "\n2.The deeper meaning of the above text:"]
        self.hidden_size = 2048  # depends on the base model
        self.model_name_or_path = model_name_or_path
        # One projection head per suffix; the heads' outputs are concatenated
        # back to hidden_size, so each maps to hidden_size // num_suffixes.
        self.linear_suffixes = torch.nn.ModuleList(
            [torch.nn.Linear(self.hidden_size, self.hidden_size // len(self.prompt_suffixes))
             for _ in range(len(self.prompt_suffixes))])
        self.tokenizer, self.llama = self.load_llama()
        self.tanh = torch.nn.Tanh()
        # Pre-tokenize the suffixes once; they are appended to every input.
        self.suffixes_ids = []
        self.suffixes_ids_len = []
        self.suffixes_len = 0
        for suffix in self.prompt_suffixes:
            ids = self.tokenizer(suffix, return_tensors="pt")["input_ids"].tolist()[0]
            self.suffixes_ids += ids
            self.suffixes_ids_len.append(len(ids))
            self.suffixes_len += len(ids)
        self.suffixes_ones = torch.ones(self.suffixes_len)
        self.suffixes_ids = torch.tensor(self.suffixes_ids)
        # Load the trained projection-head weights on top of the base model's state.
        linear_file = ".//TE//linears"
        load_layers = torch.load(linear_file)
        model_state = self.state_dict()
        model_state.update(load_layers)
        self.load_state_dict(model_state, strict=False)
    def load_llama(self):
        llm_path = os.path.join(self.model_name_or_path)
        config = AutoConfig.from_pretrained(llm_path)
        tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path)
        tokenizer.padding_side = "left"  # left padding keeps the suffixes adjacent to the text
        model = AutoModelForCausalLM.from_pretrained(
            llm_path,
            config=config,
            low_cpu_mem_usage=True,
            device_map="auto",
        )
        model.config.use_cache = False
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            model.resize_token_embeddings(len(tokenizer))
        return tokenizer, model
    def forward(self, sentences):
        prompts_embeddings = []
        sentences = [self.prompt_prefix + s for s in sentences]  # prepend the prompt prefix
        inputs = self.tokenizer(sentences, max_length=256, padding=True, truncation=True,
                                return_tensors='pt')
        attention_mask = inputs["attention_mask"]
        input_ids = inputs["input_ids"]
        batch_size = len(sentences)
        device = next(self.parameters()).device
        # Append the pre-tokenized suffixes (and their attention mask) to each sample.
        suffixes_ones = self.suffixes_ones.unsqueeze(0).repeat(batch_size, 1)
        attention_mask = torch.cat([attention_mask, suffixes_ones], dim=-1).to(device)
        suffixes_ids = self.suffixes_ids.unsqueeze(0).repeat(batch_size, 1)
        input_ids = torch.cat([input_ids, suffixes_ids], dim=-1).to(device)
        last_hidden_state = self.llama.base_model(attention_mask=attention_mask,
                                                  input_ids=input_ids).last_hidden_state
        # Walk backwards from the final token: take the hidden state at the end
        # of each suffix and project it with the corresponding linear head.
        index = -1
        for i in range(len(self.suffixes_ids_len)):
            embedding = last_hidden_state[:, index, :]
            embedding = self.linear_suffixes[i](embedding)
            prompts_embeddings.append(embedding)
            index -= self.suffixes_ids_len[-i - 1]
        output_embedding = torch.cat(prompts_embeddings, dim=-1)
        output_embedding = self.tanh(output_embedding)
        output_embedding = F.normalize(output_embedding, p=2, dim=1)  # unit-length output
        return output_embedding
    def encode(self, sentences, batch_size=10, **kwargs):
        """Embed sentences in mini-batches; returns an (N, hidden_size) numpy array."""
        size = len(sentences)
        embeddings = None
        handled = 0
        while handled < size:
            batch = sentences[handled:handled + batch_size]
            output_embeddings = self.forward(batch)
            result = output_embeddings.detach().cpu().numpy()
            handled += result.shape[0]
            if embeddings is not None:
                embeddings = np.concatenate((embeddings, result), axis=0)
            else:
                embeddings = result
        return embeddings

if __name__ == "__main__":
    # TE_model = TEmbeddingModel("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    TE_model = TEmbeddingModel("technicolor/TE_Tinyllama")
    TE_model.eval()
    with torch.no_grad():
        output = TE_model(["Hello", "Nice to meet you"])
        cos_sim = F.cosine_similarity(output[0], output[1], dim=0)
        print(cos_sim)
```
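
Since `encode` returns L2-normalized numpy embeddings, a plain dot product already gives cosine similarity. A small illustrative sketch (the corpus strings are made up, and `TE_model` is the instance constructed above):

```python
corpus = ["The cat sat on the mat.",
          "A feline rested on a rug.",
          "Stocks fell sharply today."]
emb = TE_model.encode(corpus, batch_size=2)  # (3, 2048) array of unit-length rows
sims = emb @ emb.T                           # cosine-similarity matrix
print(np.round(sims, 3))
```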