# claim-detect / app.py
# dpaul93's picture
# Update app.py
# 87a94a4 verified
import gradio as gr
import pandas as pd
import json
import os
from pprint import pprint
import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
import accelerate
from datasets import load_dataset, Dataset
from huggingface_hub import notebook_login
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from huggingface_hub import login
import os
# --- Hugging Face Hub authentication ---------------------------------------
# The token is read from the "HF_Token" environment variable. Fail with an
# explicit message instead of a bare KeyError when it is missing or empty.
access_token = os.environ.get("HF_Token")
if not access_token:
    raise RuntimeError(
        "The 'HF_Token' environment variable is not set; "
        "it is required to log in to the Hugging Face Hub."
    )
login(token=access_token)

# 4-bit NF4 quantization settings with double quantization and bfloat16
# compute. NOTE(review): the active model load below does not pass this
# config; only the commented-out quantized load path uses it.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
print("claim")

# Fine-tuned adapter repo on the Hub (formerly the local path
# "/content/trained-model").
PEFT_MODEL = "dpaul93/falcon-7b-qlora-chat-claim-finetune"
config = PeftConfig.from_pretrained(PEFT_MODEL)
# Pin the base model name. This override only affects the tokenizer load
# below; the model load passes PEFT_MODEL directly and does not read it.
config.base_model_name_or_path = "tiiuae/falcon-7b"

# Alternative load path (quantized base model + explicit adapter), kept for
# reference:
#     model = AutoModelForCausalLM.from_pretrained(
#         config.base_model_name_or_path,
#         return_dict=True,
#         quantization_config=bnb_config,
#         device_map="auto",
#         trust_remote_code=True,
#     )
#     model = PeftModel.from_pretrained(model, PEFT_MODEL)

# Load straight from the adapter repo. Presumably transformers' PEFT
# integration resolves the adapter's base model and applies the adapter --
# confirm against the repo's adapter_config.json. Weights that do not fit on
# the available devices may be offloaded to the "offload" folder.
model = AutoModelForCausalLM.from_pretrained(
    PEFT_MODEL, device_map="auto", offload_folder="offload"
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Reuse EOS as the padding token (the base tokenizer presumably defines none).
tokenizer.pad_token = tokenizer.eos_token

# Decoding settings used by generate_and_tokenize_prompt(): greedy decoding
# of a short continuation is enough to read off a single option label.
generation_config = model.generation_config
generation_config.max_new_tokens = 50
generation_config.do_sample = False
generation_config.pad_token_id = tokenizer.eos_token_id
generation_config.eos_token_id = tokenizer.eos_token_id
def generate_test_prompt(text):
    """Build the multiple-choice classification prompt for one claim.

    The wording is reproduced exactly because the fine-tuned model was
    presumably trained on this format (TODO confirm against training data).

    Args:
        text: The claim to classify; inserted verbatim on its own line.

    Returns:
        The prompt: a header line, the claim, an instruction line and the
        five lettered options, joined by newlines.
    """
    prompt_lines = [
        "Given the following claim:",
        f"{text}",
        "pick one of the following option",
        "(a) true",
        "(b) false",
        "(c) mixture",
        "(d) unknown",
        "(e) not_applicable?",
    ]
    return "\n".join(prompt_lines)
def _extract_answer(generated):
    """Return the option label from the model's generated continuation.

    Takes the text after the first "Answer:" marker (or the whole
    continuation when no marker is present), keeps only its first line up
    to the first period, and strips surrounding whitespace.
    """
    _, marker, tail = generated.partition("Answer:")
    answer = tail if marker else generated
    first_line = answer.strip().split("\n", 1)[0]
    return first_line.split(".", 1)[0].strip()


def generate_and_tokenize_prompt(text, max_new_tokens=50):
    """Classify one claim with the fine-tuned model and return its answer.

    Builds the prompt with generate_test_prompt(), generates without
    gradient tracking, decodes only the newly generated tokens and extracts
    the text following "Answer:".

    Args:
        text: The claim to classify.
        max_new_tokens: Upper bound on generated tokens. The expected answer
            is a short label, so a small budget keeps latency down.

    Returns:
        The extracted answer, presumably one of the option labels from the
        prompt. It may be an empty string if the model produced nothing
        usable.
    """
    prompt = generate_test_prompt(text)
    # Place inputs wherever the model lives instead of assuming CUDA; with
    # device_map="auto" the model is not guaranteed to be on a GPU.
    encoding = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the continuation so an "Answer:" that appears inside the
    # user's claim can never be mistaken for the model's answer.
    prompt_length = encoding.input_ids.shape[1]
    generated = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    return _extract_answer(generated)
def classifyUsingLLAMA(text):
    """Gradio callback: classify the submitted claim and return the answer.

    Despite the name, the model behind it is the Falcon-7B fine-tune loaded
    at module level. All work is delegated to generate_and_tokenize_prompt().
    """
    label = generate_and_tokenize_prompt(text)
    return label
# Text-in / text-out web UI wrapping the claim classifier; launch() starts
# serving it.
iface = gr.Interface(
    fn=classifyUsingLLAMA,
    inputs="text",
    outputs="text",
)
iface.launch()