b1ade
Collection
SLMs for RAG
•
5 items
•
Updated
Instruction fine-tuned 1B-parameter model. Pass in:
context: <...>
question: <...>
and expect an answer: <...>
See the implementation example below (also see https://huggingface.co/spaces/w601sxs/b1ade-1b):
import torch
import transformers
import os, time
import tempfile
from transformers import AutoTokenizer, AutoModelForCausalLM
# Identifier of the checkpoint passed to both from_pretrained() calls below.
BASE_MODEL = "w601sxs/b1ade-1b-bf16"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Weights are requested in bfloat16; placement is delegated to the loader via
# device_map="auto", with "offload" given as the offload folder.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    offload_folder="offload",
)

# Inference only: put the module in evaluation mode.
model.eval()
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop id.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, keywords_ids: list):
        """Remember the token ids that should end generation.

        Args:
            keywords_ids: Token ids treated as stop tokens.
        """
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id.

        Args:
            input_ids: Token ids generated so far; indexed as
                ``input_ids[0][-1]``.
            scores: Unused; accepted to match the stopping-criteria call
                signature.
            **kwargs: Unused.

        Returns:
            True if generation should stop, False otherwise.
        """
        # Nothing generated yet: indexing [0][-1] would raise IndexError.
        if input_ids.numel() == 0:
            return False
        # Compare as a plain int so the membership test does not depend on
        # tensor-vs-int equality and tensor truthiness.
        last_token = int(input_ids[0][-1])
        return last_token in self.keywords
# Generation is cut off at a closing '>' (with or without adjacent spaces).
stop_words = [">", " >", "> "]

# Keep only the first token id produced for each stop string.
# NOTE(review): if the tokenizer prepends a special token or splits a stop
# string into several tokens, [0] may not be the '>' token itself - confirm.
stop_ids = [tokenizer.encode(word)[0] for word in stop_words]

stop_criteria = StoppingCriteriaList(
    [KeywordsStoppingCriteria(keywords_ids=stop_ids)]
)
def predict(text):
    """Generate a completion for ``text`` and print the answer portion.

    The prompt is tokenized, up to 128 new tokens are generated (stopping
    early when ``stop_criteria`` fires), and the decoded output is trimmed to
    whatever follows the last ``"answer:"`` marker and the original prompt.

    Args:
        text: Prompt string passed to the tokenizer.

    Returns:
        None. The extracted answer is written to stdout.
    """
    # Follow the device the model was actually placed on instead of assuming
    # CUDA; the model is loaded with device_map="auto", so a hard-coded
    # 'cuda' fails on machines without a GPU.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            # Forward the tokenizer's mask when it provides one, so generate()
            # does not have to infer it.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=128,
            stopping_criteria=stop_criteria,
        )

    decoded = tokenizer.batch_decode(outputs.cpu(), skip_special_tokens=True)[0]
    # Keep the text after the last "answer:" marker, then drop the echoed
    # prompt if it is still present.
    answer = decoded.split("answer:")[-1]
    print(answer.split(text)[-1])
# Example invocation: the prompt uses the "context: <...>\n question: <...>\n"
# layout, and predict() prints the text it extracts from the model output.
predict("context: <The center contact of the bulb typically connects to the medium-power filament, and the ring connects to the low-power filament. Thus, if a 3-way bulb is screwed into a standard light socket that has only a center contact, only the medium-power filament operates. In the case of the 50 W / 100 W / 150 W bulb, putting this bulb in a regular lamp socket will result in it behaving like a normal 100W bulb.>\n question: <Question: Do 3 way light bulbs work in any lamp?>\n")