BinGSE — Hugging Face collection (3 items)
TODO: add a two-line summary of BinGSE and a link to the paper.
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from peft import PeftModel
if __name__ == "__main__":
# Loading base Meta-Llama-3 model, along with custom code that enables bidirectional connections in decoder-only LLMs.
tokenizer = AutoTokenizer.from_pretrained(
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp"
)
config = AutoConfig.from_pretrained(
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp", trust_remote_code=True
)
model = AutoModel.from_pretrained(
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
trust_remote_code=True,
config=config,
torch_dtype=torch.bfloat16,
device_map="cuda" if torch.cuda.is_available() else "cpu",
)
# Loading MNTP (Masked Next Token Prediction) model.
model = PeftModel.from_pretrained(
model,
"McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp",
)
model = model.merge_and_unload() # This can take several minutes on cpu
# Loading BinGSE model. This loads the trained LoRA weights on top of MNTP model. Hence the final weights are -- Base model + MNTP (LoRA) + BinGSE (LoRA).
model = PeftModel.from_pretrained(
model, model_path
)
TODO: initialize the LLM2Vec wrapper and add an example that checks the weights loaded correctly; see https://huggingface.co/McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse