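"""Gradio demo for the giulio98/codegen-2B-mono-xlcost code generation model."""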
import os
import random
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, CodeGenForCausalLM
# Check for a GPU: use device 0 if available, otherwise run on CPU.
device = 0 if torch.cuda.is_available() else "cpu"

seed = 16
max_length = 2048
top_p = 0.95
num_return_sequences = 1
pad_token_id = 50256  # <|endoftext|> in the GPT-2/CodeGen vocabulary
prefix = "# Import libraries.\n\nimport numpy as np\n"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Seed everything for reproducible generation.
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cudnn.benchmark must stay off for deterministic behaviour.
    torch.backends.cudnn.benchmark = False
def truncate(completion):
    """Cut a raw model completion down to a single self-contained snippet."""

    def find_re(string, pattern, start_pos):
        m = pattern.search(string, start_pos)
        return m.start() if m else -1

    # Patterns that mark the end of a useful completion: a new top-level
    # comment, the end-of-text token, a docstring, or a long blank run.
    terminals = [
        re.compile(r, re.MULTILINE)
        for r in ["^#", re.escape("<|endoftext|>"), "^'''", '^"""', "\n\n\n"]
    ]

    # Keep at most one top-level print statement ...
    prints = list(re.finditer("^print", completion, re.MULTILINE))
    if len(prints) > 1:
        completion = completion[: prints[1].start()]

    # ... and at most one top-level function definition.
    defs = list(re.finditer("^def", completion, re.MULTILINE))
    if len(defs) > 1:
        completion = completion[: defs[1].start()]

    start_pos = 0
    terminals_pos = [
        pos for pos in [find_re(completion, t, start_pos) for t in terminals] if pos != -1
    ]
    if len(terminals_pos) > 0:
        return completion[: min(terminals_pos)]
    else:
        return completion
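# For example, truncate("print(a)\nprint(b)\nprint(c)") cuts before the
# second top-level print and returns "print(a)\n".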
tokenizer = AutoTokenizer.from_pretrained("giulio98/codegen-2B-mono-xlcost")
if torch.cuda.is_available():
    # Load in fp16 on GPU to halve the memory footprint.
    model = CodeGenForCausalLM.from_pretrained(
        "giulio98/codegen-2B-mono-xlcost", torch_dtype=torch.float16
    )
else:
    model = CodeGenForCausalLM.from_pretrained("giulio98/codegen-2B-mono-xlcost")

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model.to(device)
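# Left padding matters for decoder-only models: generation must continue from
# the true end of the prompt, not from padding tokens. With batch size 1, as
# here, no padding is actually applied, so the setting is defensive.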
title = "Code Generator "
def sample_multi(prompt, max_gen_length, temp, force_word):
    # force_words_ids expects a list of token-id lists, one per forced word.
    force_words = []
    if force_word != "":
        force_words.append(tokenizer([force_word], add_special_tokens=False).input_ids)
    force_words = [item for sublist in force_words for item in sublist]
    input_ids = tokenizer(
        prompt,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).input_ids
    input_ids_len = input_ids.shape[1]
    assert input_ids_len < max_length
    with torch.no_grad():
        input_ids = input_ids.to(device)
        if len(force_words) != 0:
            # Forced words require beam search, so sampling is disabled.
            tokens = model.generate(
                input_ids,
                do_sample=False,
                num_return_sequences=num_return_sequences,
                temperature=temp,
                max_length=input_ids_len + max_gen_length,
                top_p=top_p,
                pad_token_id=pad_token_id,
                use_cache=True,
                num_beams=5,
                force_words_ids=force_words,
            )
        else:
            tokens = model.generate(
                input_ids,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                temperature=temp,
                max_length=input_ids_len + max_gen_length,
                top_p=top_p,
                pad_token_id=pad_token_id,
                use_cache=True,
            )
    # Decode only the newly generated tokens, dropping the prompt.
    completion_txt = tokenizer.batch_decode(tokens[:, input_ids_len:, ...], skip_special_tokens=True)
    return completion_txt[0]
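# Note: sample_multi expects a prompt already wrapped by make_prompt below
# (an English instruction inside a triple-quoted docstring followed by "###").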
def sample_completion(prompt, function, max_gen_length, temp, force_word, print_=False):
    print("prompt is: ", prompt)
    print("function is: ", function)
    prompt = prompt + "\n"
    original_text = prompt
    prompt = prefix + prompt
    print("prompt after is: ", prompt)
    bad_words = []
    force_words = []
    # Ban the "###" separator and stray quote runs, which tend to end or
    # derail an autocompletion.
    bad_words.append(tokenizer(["###"], add_special_tokens=False).input_ids)
    bad_words.append(tokenizer(["'"], add_special_tokens=False).input_ids)
    bad_words.append(tokenizer(["''"], add_special_tokens=False).input_ids)
    bad_words.append(tokenizer(["'''"], add_special_tokens=False).input_ids)
    if force_word != "":
        force_words.append(tokenizer([force_word], add_special_tokens=False).input_ids)
    if function:
        pass
        # force_words.append(tokenizer(['def'], add_special_tokens=False).input_ids)
    else:
        # Not completing a function implementation: forbid new definitions.
        bad_words.append(tokenizer(["def"], add_special_tokens=False).input_ids)
    print("last prompt: ", prompt.split("#")[-1].lower())
    # If the last comment asks for printing, force a print statement.
    if "print" in prompt.split("#")[-1].lower():
        force_words.append(tokenizer(["print"], add_special_tokens=False).input_ids)
        print_ = True
    # Flatten into the list-of-token-id-lists shape that generate() expects
    # for force_words_ids / bad_words_ids.
    force_words = [item for sublist in force_words for item in sublist]
    bad_words = [item for sublist in bad_words for item in sublist]
    input_ids = tokenizer(
        prompt,
        truncation=True,
        padding=True,
        return_tensors="pt",
    ).input_ids
    print("force words", force_words)
    print("bad words", bad_words)
    input_ids_len = input_ids.shape[1]
    assert input_ids_len < max_length
    with torch.no_grad():
        input_ids = input_ids.to(device)
        if len(force_words) == 0 and len(bad_words) == 0:
            tokens = model.generate(
                input_ids,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                temperature=temp,
                max_length=input_ids_len + max_gen_length,
                top_p=top_p,
                pad_token_id=pad_token_id,
                use_cache=True,
            )
        elif len(force_words) == 0 and len(bad_words) != 0:
            tokens = model.generate(
                input_ids,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                temperature=temp,
                max_length=input_ids_len + max_gen_length,
                top_p=top_p,
                pad_token_id=pad_token_id,
                use_cache=True,
                bad_words_ids=bad_words,
            )
        elif len(force_words) != 0 and len(bad_words) == 0:
            # Forced words require beam search, so sampling is disabled.
            tokens = model.generate(
                input_ids,
                do_sample=False,
                num_return_sequences=num_return_sequences,
                temperature=temp,
                max_length=input_ids_len + max_gen_length,
                top_p=top_p,
                pad_token_id=pad_token_id,
                use_cache=True,
                num_beams=5,
                force_words_ids=force_words,
            )
        else:
            tokens = model.generate(
                input_ids,
                do_sample=False,
                num_return_sequences=num_return_sequences,
                temperature=temp,
                max_length=input_ids_len + max_gen_length,
                top_p=top_p,
                pad_token_id=pad_token_id,
                use_cache=True,
                num_beams=5,
                force_words_ids=force_words,
                bad_words_ids=bad_words,
            )
    # Keep special tokens here so truncate() can stop at <|endoftext|>.
    completion_txt = tokenizer.batch_decode(tokens[:, input_ids_len:, ...])
    print("before truncate:", completion_txt[0])
    print("after truncate:", truncate(completion_txt[0]))
    return truncate(completion_txt[0]), original_text
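# The prefix prepended above never reaches the caller: only tokens generated
# after the full prompt are decoded, and original_text is returned so the UI
# can show the user's input followed by the completion.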
def complete_with_gpt(text, function, tokens_auto, temp_auto, force_word):
    # Use at most the last 1024 characters of the text as context.
    text = text.strip()
    completion, original_text = sample_completion(text[-1024:], function, tokens_auto, temp_auto, force_word)
    return original_text + completion
def make_prompt(gen_prompt):
    # Wrap an English instruction as a docstring followed by the "###" separator.
    return '"""\n' + gen_prompt + '\n"""\n###\n'
def complete_multi(text, tokens_multi, temp_multi, force_word):
    # Wrap the English instruction in the prompt template before sampling.
    text = make_prompt(text.strip())
    completion = sample_multi(text, tokens_multi, temp_multi, force_word)
    return completion
def predict(input, history=[], temp=0.2, max_gen_length=256):
    # Tokenize the new user instruction.
    new_user_input_ids = tokenizer.encode(make_prompt(input), return_tensors="pt")
    # Append the new user input tokens to the chat history.
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    bot_input_ids = bot_input_ids.to(device)
    # Generate a response.
    history = model.generate(
        bot_input_ids,
        do_sample=True,
        num_return_sequences=num_return_sequences,
        temperature=temp,
        max_new_tokens=max_gen_length,
        top_p=top_p,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
    ).tolist()
    print(history)
    # Convert the tokens to text, split the transcript on the prompt
    # delimiter, and pair up (instruction, generated code) turns.
    response = tokenizer.decode(history[0]).split('\n"""\n###\n')
    print("decoded_response-->>" + str(response))
    response = [
        (response[i].replace('"""\n', ""), response[i + 1].replace("<|endoftext|>", ""))
        for i in range(0, len(response) - 1, 2)
    ]
    return response, history
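# predict implements a chat-style loop (a history of token ids threaded
# through calls) but is currently unused: the btn_chat.click(...) wiring
# below is commented out.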
with gr.Blocks(title=title) as demo:
    with gr.Tabs():
        with gr.TabItem("Autocomplete"):
            with gr.Row():
                textbox = gr.Textbox(placeholder="Type here...", lines=16)
                checkbox = gr.Checkbox(label="Function implementation?")
                tokens_auto = gr.Slider(
                    minimum=8,
                    maximum=256,
                    step=1,
                    value=128,
                    label="Number of tokens to generate",
                )
                temp_auto = gr.Slider(
                    minimum=0,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                textbox_force = gr.Textbox(label="Insert force word...")
            btn_autocomplete = gr.Button("Generate")
        with gr.TabItem("Multisteps"):
            with gr.Row():
                textbox_input = gr.Textbox(lines=10, label="English instructions")
                textbox_output = gr.Textbox(label="Predicted Python code", lines=10)
                tokens_multi = gr.Slider(
                    minimum=8,
                    maximum=256,
                    step=1,
                    value=128,
                    label="Number of tokens to generate",
                )
                temp_multi = gr.Slider(
                    minimum=0,
                    maximum=2.5,
                    step=0.1,
                    value=0.2,
                    label="Temperature",
                )
                textbox_force_multi = gr.Textbox(label="Insert force word...")
            btn_multi = gr.Button("Generate")

    btn_autocomplete.click(
        complete_with_gpt,
        inputs=[textbox, checkbox, tokens_auto, temp_auto, textbox_force],
        outputs=[textbox],
    )
    btn_multi.click(
        complete_multi,
        inputs=[textbox_input, tokens_multi, temp_multi, textbox_force_multi],
        outputs=textbox_output,
    )
    # btn_chat.click(predict, inputs=[text_chat, state], outputs=[chatbot, state])

demo.launch()