File size: 4,625 Bytes
d8f6559
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import gradio as gr 
import os
from typing import List
import logging
import urllib.request
from utils import model_name_mapping, urial_template, openai_base_request
from constant import js_code_label, HEADER_MD
from openai import OpenAI
import datetime
# Emit log records at INFO level and above to the console.
logging.basicConfig(level=logging.INFO)

# Name of the URIAL prompt file to download from the Re-Align/URIAL repository.
URIAL_VERSION = "inst_1k_v4.help"
URIAL_URL = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
# The prompt is fetched once, at import time. The context manager closes the
# HTTP response as soon as the body has been read instead of leaving the
# connection open until garbage collection.
with urllib.request.urlopen(URIAL_URL) as _urial_response:
    urial_prompt = _urial_response.read().decode('utf-8')
urial_prompt = urial_prompt.replace("```", '"""') # new version of URIAL uses """ instead of ```
# Markers that end a generated answer; respond() passes them to the API as
# `stop` and also checks for them while streaming.
STOP_STRS = ['"""', '# Query:', '# Answer:']

# Per-client-address request counter used by respond() for rate limiting.
# respond() clears it once more than a day has passed since LAST_UPDATE_TIME.
addr_limit_counter = {}
LAST_UPDATE_TIME = datetime.datetime.now()


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    api_key,
    request:gr.Request
):
    """Stream a base-LLM answer for one chat turn using the URIAL prompt.

    This is a generator: every yielded value is the full answer text
    accumulated so far, so the caller can re-render the message as it grows.

    Args:
        message: The new user message.
        history: Previous (user, assistant) turns of the conversation.
        max_tokens: Forwarded to ``openai_base_request`` unchanged.
        temperature: Forwarded to ``openai_base_request`` unchanged.
        top_p: Forwarded to ``openai_base_request`` unchanged.
        rp: Repetition penalty from the UI. Currently overridden with 1.0.
        model_name: Display name of the model; translated with
            ``model_name_mapping`` before the request is made.
        api_key: Optional user-supplied key. Anything that is not exactly
            64 characters long is discarded and ``None`` is passed on.
        request: The Gradio request; only ``request.client.host`` is read,
            to rate-limit by client address.

    Yields:
        The accumulated response text after each streamed token, or a single
        rate-limit notice when the caller has used up the daily allowance.
    """
    global LAST_UPDATE_TIME, addr_limit_counter

    daily_limit = 100

    # NOTE(review): the repetition penalty coming from the UI is overridden
    # here, so every request uses 1.0 -- confirm this is intended.
    rp = 1.0
    prompt = urial_template(urial_prompt, history, message)
    _model_name = model_name_mapping(model_name)

    # Only a 64-character key is treated as a real user key.
    if not (api_key and len(api_key) == 64):
        api_key = None
    using_shared_key = api_key is None

    # Start a fresh counting window once more than 24 hours have passed.
    now = datetime.datetime.now()
    if now - LAST_UPDATE_TIME > datetime.timedelta(days=1):
        addr_limit_counter = {}
        LAST_UPDATE_TIME = now

    host_addr = request.client.host
    # The allowance only applies to requests made without a user key: the
    # notice itself tells the user to switch to their own key, so such
    # requests must not be blocked or counted.
    if using_shared_key and addr_limit_counter.get(host_addr, 0) >= daily_limit:
        # This function is a generator, so the notice has to be yielded;
        # a plain `return "<text>"` would never reach the chat window.
        yield "You have reached the limit of 100 requests for today. Please use your own API key."
        return

    infer_request = openai_base_request(prompt=prompt, model=_model_name,
                                        temperature=temperature,
                                        max_tokens=max_tokens,
                                        top_p=top_p,
                                        repetition_penalty=rp,
                                        stop=STOP_STRS, api_key=api_key)
    if using_shared_key:
        addr_limit_counter[host_addr] = addr_limit_counter.get(host_addr, 0) + 1
    logging.info(f"Requesting chat completion from OpenAI API with model {_model_name}")
    logging.info(f"addr_limit_counter: {addr_limit_counter}; Last update time: {LAST_UPDATE_TIME};")

    response = ""
    for msg in infer_request:
        # Skip chunks that carry no choice at all.
        if not msg.choices:
            continue
        choice = msg.choices[0]
        if hasattr(choice, "delta"):
            # Chat-style chunk: accept both dict deltas and attribute-style
            # delta objects, and tolerate a delta without content.
            delta = choice.delta
            if isinstance(delta, dict):
                token = delta.get("content")
            else:
                token = getattr(delta, "content", None)
        else:
            # Completion-style chunk.
            token = choice.text
        # A chunk may carry no text; treat it as an empty token.
        token = token or ""

        # STOP_STRS is also sent to the API as `stop`; this client-side check
        # ends the stream if a marker still shows up in the output.
        if any(_stop in response + token for _stop in STOP_STRS):
            break
        response += token
        # Drop one or two trailing double quotes that directly follow a
        # newline -- presumably the start of a `"""` stop marker that is
        # still being streamed.
        if response.endswith('\n"'):
            response = response[:-1]
        elif response.endswith('\n""'):
            response = response[:-2]
        yield response
 
# Page layout: header and model picker on the left, API key and sampling
# controls on the right, and the chat interface underneath.
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(HEADER_MD)
            model_name = gr.Radio(
                choices=[
                    "Llama-3.1-405B-FP8",
                    "Llama-3-70B",
                    "Llama-3-8B",
                    "Mistral-7B-v0.1",
                    "Mixtral-8x22B",
                    "Qwen1.5-72B",
                    "Yi-34B",
                    "Llama-2-7B",
                    "Llama-2-70B",
                    "OLMO",
                ],
                value="Llama-3.1-405B-FP8",
                label="Base LLM name",
            )
        with gr.Column():
            # Created with visible=False, but still wired into respond() below.
            api_key = gr.Textbox(
                label="🔑 APIKey",
                placeholder="Enter your Together/Hyperbolic API Key. Leave it blank to use our key with limited usage.",
                type="password",
                elem_id="api_key",
                visible=False,
            )
            with gr.Accordion("⚙️ Parameters for Base LLM", open=True):
                with gr.Row():
                    max_tokens = gr.Textbox(label="Max tokens", value=256)
                    temperature = gr.Textbox(label="Temperature", value=0.5)
                    top_p = gr.Textbox(label="Top-p", value=0.9)
                    rp = gr.Textbox(label="Repetition penalty", value=1.1)

    # The order of additional_inputs must match respond()'s parameters that
    # follow (message, history).
    chat = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            max_tokens,
            temperature,
            top_p,
            rp,
            model_name,
            api_key,
        ],
    )
    # Adjust the chatbot component that ChatInterface created.
    chat.chatbot.label = "Chat with Base LLMs via URIAL"
    chat.chatbot.height = 550
    chat.chatbot.show_copy_button = True

# Start the Gradio server only when this file is executed as a script, not
# when it is imported. show_api=False is forwarded to launch(); per Gradio's
# documentation it hides the API documentation page.
if __name__ == "__main__": 
    demo.launch(show_api=False)