# BinGSE
#
# Part of the BinGSE collection (3 items).
#
# TODO: add a two-line summary and a link to the paper.
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft.peft_model import PeftModel
# Hugging Face Hub repository IDs used as the defaults of ``load_bingse``.
# The MNTP repository holds both the base model (with its custom modeling
# code) and the MNTP LoRA adapter.
MNTP_MODEL_ID = "McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp"
# NOTE(review): the collection is titled "BinGSE" but this adapter repo is
# named "BinGE" -- confirm the repository ID is the intended one.
BINGE_ADAPTER_ID = "tsirif/BinGE-Sheared-LLaMA"


def load_bingse(
    base_model_id=MNTP_MODEL_ID,
    adapter_id=BINGE_ADAPTER_ID,
    *,
    torch_dtype=torch.bfloat16,
    device_map=None,
):
    """Load the tokenizer and the BinGE model: base + MNTP (LoRA) + BinGE (LoRA).

    The base Sheared-LLaMA model is loaded with the repository's custom code
    (``trust_remote_code=True``), which enables bidirectional connections in
    the decoder-only LLM. The MNTP LoRA weights from the same repository are
    then merged into it. Finally, the BinGE LoRA adapter is attached on top.

    Args:
        base_model_id: Hub ID of the base model that also hosts the MNTP
            LoRA adapter.
        adapter_id: Hub ID of the BinGE LoRA adapter loaded on top of the
            merged model.
        torch_dtype: dtype passed to ``AutoModel.from_pretrained``.
        device_map: device map passed to ``AutoModel.from_pretrained``.
            ``None`` selects ``"cuda"`` when available, otherwise ``"cpu"``.

    Returns:
        ``(tokenizer, model)``, where ``model`` is a ``PeftModel`` wrapping
        the MNTP-merged base model with the BinGE adapter loaded.
    """
    if device_map is None:
        device_map = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    config = AutoConfig.from_pretrained(base_model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        base_model_id,
        trust_remote_code=True,
        config=config,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )

    # Load the MNTP LoRA weights and merge them into the base model so the
    # BinGE adapter below is applied on top of the merged weights.
    model = PeftModel.from_pretrained(model, base_model_id)
    model = model.merge_and_unload()  # This can take several minutes on cpu

    # Attach the trained BinGE LoRA weights on top of the MNTP model.
    model = PeftModel.from_pretrained(model, adapter_id)
    return tokenizer, model


if __name__ == "__main__":
    tokenizer, model = load_bingse()
# TODO: initialize the wrapper and add an example that checks the weights loaded correctly; see https://huggingface.co/McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp-unsup-simcse