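"""Gradio demo for a 15M-parameter LLaMA model fine-tuned on a+b arithmetic.

Loads the nickypro/tinyllama-15M base model, attaches a LoRA adapter from a
local checkpoint, and serves single-expression evaluation through a Gradio UI.
"""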
import sys
from typing import Union

import torch
import gradio as gr
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

base_model = "nickypro/tinyllama-15M"
class Prompter:
    """Builds prompts for the arithmetic task and extracts model answers."""

    def generate_prompt(
        self,
        instruction: str,
        label: Union[None, str] = None,
    ) -> str:
        # During training the label is appended; at inference it is omitted.
        res = f"{instruction}\nAnswer: "
        if label:
            res = f"{res}{label}"
        return res

    def get_response(self, output: str) -> str:
        # Keep only the text after "Answer:", mapping ASCII operators to the
        # division and multiplication signs for display. Not called by the
        # demo below, which returns the full decoded output.
        return output.split("Answer:")[1].strip().replace("/", "\u00F7").replace("*", "\u00D7")
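# Illustrative usage (values are examples, not from the original file):
#   Prompter().generate_prompt("2 + 3")          -> "2 + 3\nAnswer: "
#   Prompter().get_response("2 + 3\nAnswer: 5")  -> "5"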
model = LlamaForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.float32,
    device_map="auto",
)
# Attach the LoRA adapter weights from the fine-tuning checkpoint.
model = PeftModel.from_pretrained(
    model,
    "checkpoint-16000",
    torch_dtype=torch.float32,
)
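# Optional (an assumption, not part of the original app): recent peft releases
# can fold the LoRA deltas into the base weights for slightly faster inference.
# model = model.merge_and_unload()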
model.eval()
# torch.compile (PyTorch 2.x) is skipped on Windows, where it is not supported.
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

# Reuse the reference LLaMA tokenizer for the tinyllama checkpoint.
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.pad_token_id = 0  # pad with <unk> (id 0), distinct from eos
tokenizer.padding_side = "left"  # decoder-only models are left-padded for batching
def generate_answers(instructions, model, tokenizer):
    prompter = Prompter()
    raw_answers = []
    for instruction in instructions:
        prompt = prompter.generate_prompt(instruction)
        inputs = tokenizer(prompt, return_tensors="pt")
        # Move inputs to the model's device (device_map="auto" may place it on GPU).
        input_ids = inputs["input_ids"].to(model.device)
        generation_output = model.generate(
            input_ids=input_ids,
            return_dict_in_generate=True,
            output_scores=True,
            pad_token_id=0,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=16,
        )
        s = generation_output.sequences[0]
        raw_answers.append(tokenizer.decode(s, skip_special_tokens=True).strip())
    return raw_answers
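# Illustrative call (hypothetical inputs):
#   generate_answers(["1 + 2", "3 + 4"], model, tokenizer)
# returns one decoded string per instruction. Prompts are generated one at a
# time, so the left-padding configured above is never actually exercised here.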
def evaluate(instruction):
    return generate_answers([instruction], model, tokenizer)[0]
if __name__ == "__main__":
    gr.Interface(
        fn=evaluate,
        inputs=[
            gr.Textbox(
                lines=1,
                label="Arithmetic",
                placeholder="63303235 + 20239503",
            )
        ],
        outputs=[
            gr.Textbox(
                lines=5,
                label="Output",
            )
        ],
        title="Arithmetic LLaMA",
        description="A 15M-parameter LLaMA model fine-tuned on a+b arithmetic tasks.",
    ).queue().launch(server_name="0.0.0.0", share=True)