# Spaces: Sleeping  (page-header residue captured with the source; not part of the program)
from huggingface_hub import InferenceClient, HfApi, upload_file | |
import datetime | |
import gradio as gr | |
import random | |
import prompts | |
import json | |
import uuid | |
import os | |
# Hugging Face access token read from the environment (None when HF_TOKEN is unset).
token = os.environ.get("HF_TOKEN")

# The JSON logs are uploaded to the dataset repo "<username>/<dataset_name>".
username = "johann22"
dataset_name = "chat-roulette-1"

# Use the token read above; the previous empty-string token ignored HF_TOKEN.
api = HfApi(token=token)
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

# Module-level lists. generate() and question_generate() take their own
# `history` parameter and generate() builds its own local `hist_out`, so these
# two are not the lists those functions mutate.
history = []
hist_out = []

# Single-slot lists used as shared mutable cells:
#   summary[0]    - latest compressed summary (written by generate()).
#   main_point[0] - the user's opening prompt (written by generate(), read by
#                   question_generate() and compress_history()).
summary = [""]
main_point = [""]
def format_prompt(message, history):
    """Build a single ``[INST] ... [/INST]`` prompt string from a chat history.

    Args:
        message: The new user message to place in the final ``[INST]`` turn.
        history: Iterable of ``(user_prompt, bot_response)`` pairs from earlier
            turns. ``None`` or an empty sequence means "no earlier turns".

    Returns:
        ``"<s>"`` followed by one ``[INST] user [/INST] bot</s> `` segment per
        history pair and a closing ``[INST] message [/INST]`` segment.
    """
    parts = ["<s>"]
    # `or ()` lets a missing history (None) behave like an empty one.
    for user_prompt, bot_response in history or ():
        parts.append(f"[INST] {user_prompt} [/INST]")
        parts.append(f" {bot_response}</s> ")
    parts.append(f"[INST] {message} [/INST]")
    return "".join(parts)
# Agent names; agents[0] is the default `agent_name` of question_generate()
# and generate().
agents = [
    "QUESTION_GENERATOR",
    "AI_REPORT_WRITER",
]

# Module-level sampling defaults.
temperature = 0.9
max_new_tokens = 256
max_new_tokens2 = 1048
top_p = 0.95
# No trailing comma here: `1.0,` would make this the tuple (1.0,) instead of a float.
repetition_penalty = 1.0
def compress_history(formatted_prompt):
    """Send an already-formatted prompt to the model and return its full reply.

    The caller is expected to have embedded the compression instructions in
    ``formatted_prompt``; this function passes it through unchanged.

    Args:
        formatted_prompt: Complete prompt string to send to ``client``.

    Returns:
        The concatenated text of every streamed token.

    Side effects:
        Prints the generated text and the current ``main_point[0]``.
    """
    generate_kwargs = dict(
        temperature=0.9,
        # NOTE(review): this budget is far above the 256/1048 used elsewhere
        # in this file - confirm the endpoint accepts it.
        max_new_tokens=30480,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        # Fresh random seed per call.
        seed=random.randint(1, 1111111111111111),
    )
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = "".join(response.token.text for response in stream)
    print(output)
    print(main_point[0])
    return output
def question_generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,):
    """Ask the model for a follow-up question and return the generated text.

    The instruction text is ``prompts.QUESTION_GENERATOR`` formatted with the
    current ``main_point[0]`` as ``focus``; it is joined to ``prompt`` with
    ``", "`` and wrapped by ``format_prompt`` together with ``history``.

    Args:
        prompt: Text the follow-up question should be based on.
        history: ``(user_prompt, bot_response)`` pairs given to ``format_prompt``.
        agent_name: Accepted but not used by this function.
        sys_prompt: Accepted but not used by this function.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Generation budget forwarded to the client.
        top_p: Nucleus-sampling value, coerced to ``float``.
        repetition_penalty: Forwarded to the client unchanged.

    Returns:
        The concatenated text of every streamed token.
    """
    rng_seed = random.randint(1, 1111111111111111)
    instructions = prompts.QUESTION_GENERATOR.format(focus=main_point[0])

    temp = float(temperature)
    if temp < 1e-2:
        temp = 1e-2

    sampling = {
        "temperature": temp,
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        "seed": rng_seed,
    }

    full_prompt = format_prompt(f"{instructions}, {prompt}", history)
    token_stream = client.text_generation(
        full_prompt,
        **sampling,
        stream=True,
        details=True,
        return_full_text=False,
    )
    return "".join(chunk.token.text for chunk in token_stream)
# Characters allowed to survive in a generated filename: ASCII letters,
# digits, underscore and hyphen.
_FILENAME_SAFE_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789_-"
)


def create_valid_filename(invalid_filename: str) -> str:
    """Return *invalid_filename* reduced to filename-safe characters.

    Runs of whitespace are collapsed to a single ``-`` (leading and trailing
    whitespace is dropped), then every character that is not an ASCII letter,
    digit, ``_`` or ``-`` is removed.

    Args:
        invalid_filename: Arbitrary text to turn into a filename stem.

    Returns:
        The sanitised string; it may be empty if no character qualifies.
    """
    # The old `invalid_filename.replace(" ", "-")` discarded its result;
    # split()/join() already performs that substitution.
    hyphenated = "-".join(invalid_filename.split())
    return "".join(char for char in hyphenated if char in _FILENAME_SAFE_CHARS)
def _dump_and_upload(records, local_path, path_in_repo):
    """Write *records* as indented JSON to *local_path*, upload that file to the dataset repo, and return the JSON text."""
    payload = json.dumps(records, indent=4)
    with open(local_path, 'w') as f:
        f.write(payload)
    upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=f"{username}/{dataset_name}",
        repo_type="dataset",
        token=token,
    )
    return payload


def generate(prompt, history, agent_name=agents[0], sys_prompt="", temperature=0.9, max_new_tokens=1048, top_p=0.95, repetition_penalty=1.0,):
    """Run the self-questioning chat loop, streaming UI updates as a generator.

    Each iteration either:

    * answers the current ``prompt`` (when the formatted prompt is shorter
      than 50000 characters), streams the answer, asks ``question_generate``
      for the next prompt, and uploads the prompt/output log; or
    * compresses the conversation through ``compress_history``, stores the
      result in ``summary[0]``, resets the local history, uploads the summary
      log, and asks for the next prompt.

    The loop has no exit condition; it runs until the consumer stops
    iterating the generator.

    Args:
        prompt: The opening user message; also stored in ``main_point[0]``.
        history: ``(user_prompt, bot_response)`` pairs; ``None`` is treated
            as an empty history.
        agent_name: Accepted but not used by this function.
        sys_prompt: Accepted but not used by this function.
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: Generation budget for each answer.
        top_p: Nucleus-sampling value, coerced to ``float``.
        repetition_penalty: Forwarded to the client unchanged.

    Yields:
        ``('', [(prompt, partial_output)], summary[0], json_obj, json_hist)``
        after every streamed token. ``json_obj`` / ``json_hist`` are the most
        recently uploaded summary / history JSON strings, or ``{}`` before
        the first upload.
    """
    if history is None:
        history = []
    main_point[0] = prompt

    uid = uuid.uuid4()
    # ':' and '.' are replaced so the timestamp can be part of a filename.
    current_time = str(datetime.datetime.now()).replace(":", "-").replace(".", "-")
    print (current_time)

    system_prompt = prompts.AI_REPORT_WRITER
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    hist_out = []
    sum_out = []
    json_hist = {}
    json_obj = {}
    filename = create_valid_filename(f'{prompt}---{current_time}')

    while True:
        seed = random.randint(1, 1111111111111111)
        generate_kwargs = dict(
            temperature=temperature,
            # Honour the caller's budget (the default equals the old
            # hard-wired global value of 1048).
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=seed,
        )
        if prompt.startswith(' \"'):
            prompt = prompt.strip(' \"')
        formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)

        if len(formatted_prompt) < (50000):
            print(len(formatted_prompt))
            stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
            output = ""
            for response in stream:
                output += response.token.text
                yield '', [(prompt, output)], summary[0], json_obj, json_hist

            out_json = {"prompt": prompt, "output": output}
            next_prompt = question_generate(output, history)
            # Record the answer against the prompt that produced it, not
            # against the freshly generated follow-up question.
            history.append((prompt, output))
            prompt = next_prompt
            print ( f'Prompt:: {len(prompt)}')
            print ( f'history:: {len(formatted_prompt)}')
            hist_out.append(out_json)
            json_hist = _dump_and_upload(
                hist_out,
                f'{uid}.json',
                f"test/{filename}.json",
            )
        else:
            # Prompt too long: compress what we have and start a new history.
            formatted_prompt = format_prompt(f"{prompts.COMPRESS_HISTORY_PROMPT.format(history=summary[0],focus=main_point[0])}, {summary[0]}", history)
            history = []
            output = compress_history(formatted_prompt)
            summary[0] = output
            sum_out.append({"summary": summary[0]})
            json_obj = _dump_and_upload(
                sum_out,
                f'{uid}-sum.json',
                f"summary/{filename}-summary.json",
            )
            prompt = question_generate(output, history)
# Gradio UI: chat window, prompt box, control buttons, and read-outs for the
# running summary and the uploaded JSON logs.
with gr.Blocks() as iface:
    gr.HTML("""<center><h1>Chat Roulette</h1><br><h3>This chatbot will respond to itself with additional questions</h3></center>""")
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    with gr.Row():
        submit_b = gr.Button()
        stop_b = gr.Button("Stop")
        clear = gr.ClearButton([msg, chatbot])
    # "Summary" is the label; passed positionally it would be the initial value.
    sumbox = gr.Textbox(label="Summary", max_lines=100)
    with gr.Column():
        sum_out_box = gr.JSON(label="Summaries")
        hist_out_box = gr.JSON(label="History")

    # Output order matches the 5-tuples yielded by generate():
    # textbox value, chat pairs, summary text, summaries JSON, history JSON.
    sub_b = submit_b.click(generate, [msg, chatbot], [msg, chatbot, sumbox, sum_out_box, hist_out_box])
    sub_e = msg.submit(generate, [msg, chatbot], [msg, chatbot, sumbox, sum_out_box, hist_out_box])
    # generate() loops until cancelled, so "Stop" cancels both event chains.
    stop_b.click(None, None, None, cancels=[sub_b, sub_e])

iface.queue(default_concurrency_limit=10).launch()