File size: 3,798 Bytes
2b4fa99
23e203e
2b4fa99
 
 
 
 
23e203e
2b4fa99
 
 
 
 
 
 
 
 
 
 
23e203e
 
 
 
2b4fa99
23e203e
 
 
 
2b4fa99
 
 
 
 
 
23e203e
2b4fa99
 
 
 
 
 
 
 
 
 
23e203e
2b4fa99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23e203e
 
2b4fa99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23e203e
2b4fa99
 
 
 
 
 
 
 
 
 
23e203e
 
 
 
2b4fa99
 
23e203e
2b4fa99
 
 
23e203e
 
 
 
2b4fa99
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import base64
import gradio as gr
import json
import mimetypes
import os
import requests
import time


# Configuration read from the environment at import time.
# Required: a missing variable raises KeyError here, before the UI is built.
MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']
# Optional: each of these is None when its variable is unset.
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT')
MULTIMODAL_FLAG = os.environ.get('MULTIMODAL')
# Required JSON document; the UI reads its 'tokens_to_generate',
# 'temperature' and 'top_p' keys for the slider defaults.
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])
# Optional per-role 'name' attached to outgoing messages (None = no name).
NAME_MAP = {
    role: os.environ.get(variable)
    for role, variable in (
        ('system', 'SYSTEM_NAME'),
        ('user', 'USER_NAME'),
    )
}


def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
):
    messages = []
    if SYSTEM_PROMPT is not None:
        messages.append({
            'role': 'system',
            'content': SYSTEM_PROMPT,
        })
    for val in history:
        messages.append({
            'role': val['role'],
            'content': convert_content(val['content']),
        })
    messages.append({
        'role': 'user',
        'content': convert_content(message),
    })
    for message in messages:
        add_name_for_message(message)

    data = {
        'model': MODEL_VERSION,
        'messages': messages,
        'stream': True,
        'max_tokens': max_tokens,
        'temperature': temperature,
        'top_p': top_p,
    }
    r = requests.post(
        API_URL,
        headers={
            'Content-Type': 'application/json',
            'Authorization': 'Bearer {}'.format(API_KEY),
        },
        data=json.dumps(data),
        stream=True,
    )
    reply = ''
    for row in r.iter_lines():
        if row.startswith(b'data:'):
            data = json.loads(row[5:])
            if 'choices' not in data:
                raise gr.Error('request failed')
            choice = data['choices'][0]
            if 'delta' in choice:
                reply += choice['delta']['content']
                yield reply
            elif 'message' in choice:
                yield choice['message']['content']


def add_name_for_message(message):
    """Set ``message['name']`` from NAME_MAP for the message's role, in place.

    Roles missing from NAME_MAP, or mapped to None, are left unchanged.
    """
    role_name = NAME_MAP.get(message['role'])
    if role_name is None:
        return
    message['name'] = role_name


def convert_content(content):
    """Translate a Gradio message body into chat-completions ``content``.

    - ``str``: returned unchanged.
    - ``tuple``: a file message; its first element is a path, sent as a
      single ``image_url`` part.
    - anything else is treated as a mapping: a ``'text'`` entry becomes a
      ``text`` part and each path in ``'files'`` becomes an ``image_url``
      part, following the mapping's key order; other keys are ignored.

    Images are inlined as base64 ``data:`` URLs via ``encode_base64``.
    """
    if isinstance(content, str):
        return content

    def image_part(path):
        return {'type': 'image_url', 'image_url': {'url': encode_base64(path)}}

    if isinstance(content, tuple):
        return [image_part(content[0])]

    parts = []
    for field, value in content.items():
        if field == 'files':
            parts.extend(image_part(path) for path in value)
        elif field == 'text':
            parts.append({'type': 'text', 'text': value})
    return parts


def encode_base64(path):
    """Return the image file at *path* as a base64 ``data:`` URL.

    The MIME type is guessed from the file name with ``mimetypes.guess_type``;
    the file's bytes are not inspected.

    Raises:
        gr.Error: If no MIME type can be guessed, or it is not ``image/*``.
        OSError: If the file cannot be opened or read.
    """
    mime_type = mimetypes.guess_type(path)[0]
    # guess_type returns None for unknown or missing extensions; test for it
    # so the user sees a gr.Error rather than an AttributeError.
    if mime_type is None or not mime_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(mime_type, path))
    with open(path, 'rb') as handle:
        data = handle.read()
    return 'data:{};base64,{}'.format(mime_type, base64.b64encode(data).decode())


# Chat UI wired to `respond`. The sliders are passed to `respond` after
# (message, history), in this order: max_tokens, temperature, top_p. Their
# initial values come from MODEL_CONTROL_DEFAULTS.
demo = gr.ChatInterface(
    respond,
    type='messages',
    multimodal=MULTIMODAL_FLAG == 'ON',
    additional_inputs=[
        gr.Slider(
            label='Tokens to generate',
            minimum=1,
            maximum=1000000,
            step=1,
            value=MODEL_CONTROL_DEFAULTS['tokens_to_generate'],
        ),
        gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODEL_CONTROL_DEFAULTS['temperature'],
        ),
        gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODEL_CONTROL_DEFAULTS['top_p'],
        ),
    ],
)


if __name__ == '__main__':
    # Enable the request queue with a default per-event concurrency limit of
    # 50, then start the server.
    demo.queue(default_concurrency_limit=50)
    demo.launch()