# claim-detect / app.py
# Hugging Face Space file by dpaul93 — commit c390e36 ("Update app.py"), 2.24 kB.
import gradio as gr
import pandas as pd
import json
import os
from pprint import pprint
import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
import accelerate
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub only when a token is configured (e.g. a
# Space secret named "HF_Token"). A missing secret used to abort startup with a
# bare KeyError; if the repositories below are public no login is needed, and if
# they are private the download itself will report the authorization failure.
access_token = os.environ.get("HF_Token")
if access_token:
    login(token=access_token)

# Load the base model in 4-bit NF4 with nested (double) quantization; compute
# is carried out in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
print("claim")

# LoRA adapter fine-tuned for claim classification.
PEFT_MODEL = "dpaul93/falcon-7b-qlora-chat-claim-finetune"
config = PeftConfig.from_pretrained(PEFT_MODEL)
# Pin the base checkpoint explicitly, overriding whatever the adapter config stores.
config.base_model_name_or_path = "tiiuae/falcon-7b"
# NOTE: trust_remote_code=True executes Python code shipped with the model repo.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    return_dict=True,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Reuse EOS as the padding token (presumably the tokenizer defines no pad token
# of its own — TODO confirm) so attention masks and generate() padding work.
tokenizer.pad_token = tokenizer.eos_token
# Wrap the quantized base model with the fine-tuned LoRA adapter weights.
model = PeftModel.from_pretrained(model, PEFT_MODEL)
def generate_test_prompt(text):
return f"""Given the following claim:
{text}
pick one of the following option
(a) true
(b) false
(c) mixture
(d) unknown
(e) not_applicable?""".strip()
def generate_and_tokenize_prompt(text, max_new_tokens=50):
    """Classify *text* with the fine-tuned model and return the predicted label.

    The claim is wrapped in the multiple-choice prompt and fed to ``model``.
    The label is parsed from the generated continuation: the text after the
    first "Answer:" marker if the model emits one, otherwise the whole
    continuation. It is trimmed to its first line, then to its first sentence.

    Args:
        text: The claim to classify.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The label text with surrounding whitespace stripped; may be an empty
        string if the model produced nothing usable.
    """
    prompt = generate_test_prompt(text)
    # Follow the model's actual placement (loaded with device_map="auto")
    # instead of hard-coding "cuda".
    encoding = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        # The original referenced an undefined ``generation_config`` (NameError
        # on every call). Use explicit settings instead; greedy decoding keeps
        # the classification deterministic.
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation for a causal LM. Decode only
    # the continuation so a claim that itself contains "Answer:" cannot
    # hijack the parsing.
    prompt_length = encoding.input_ids.shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # The prompt has no "Answer:" marker, so indexing [1] unconditionally
    # raised IndexError whenever the model did not emit one.
    pieces = generated.split("Answer:")
    answer = pieces[1] if len(pieces) > 1 else pieces[0]
    return answer.strip().split("\n")[0].split(".")[0].strip()
def classifyUsingLLAMA(text):
    """Gradio callback: return the model's predicted label for the claim *text*.

    Despite the name, the underlying model is the Falcon-7B QLoRA adapter; the
    name is kept because the Interface below registers this function as ``fn``.
    """
    label = generate_and_tokenize_prompt(text)
    return label
# Minimal web UI: one free-text box in, one text box (the predicted label) out.
iface = gr.Interface(
    fn=classifyUsingLLAMA,
    inputs="text",
    outputs="text",
)
iface.launch()