Create app.py
app.py
ADDED
@@ -0,0 +1,291 @@
import gradio as gr
from huggingface_hub import whoami
import datetime
from dataset_uploader import ParquetScheduler

##########
# Setup #
##########

contributor_username = whoami()["name"]

# only show the info message the first time data is pushed to the Hub
show_info = True

every = 1  # push once every minute (use 5 if many people share the same HF token)

choices = ["sharegpt", "standard"]

# one scheduler (and therefore one dataset repo) per task / chat-format combination
schedulers = {
    "sft-sharegpt": ParquetScheduler(repo_id=f"{contributor_username}/sft-sharegpt", every=every),
    "sft-standard": ParquetScheduler(repo_id=f"{contributor_username}/sft-standard", every=every),
    "dpo-sharegpt": ParquetScheduler(repo_id=f"{contributor_username}/dpo-sharegpt", every=every),
    "dpo-standard": ParquetScheduler(repo_id=f"{contributor_username}/dpo-standard", every=every),
}
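
# Illustrative usage (assumption: ParquetScheduler behaves like huggingface_hub's
# CommitScheduler and periodically commits the appended rows as parquet files):
#   schedulers["sft-sharegpt"].append({"contributor": contributor_username, ...})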


##########
# Utils #
##########


def chat_message(role, content, prompt_type=None):
    """
    A function that transforms the chat content into a chat message.
    Args:
        role: A string, e.g. "user", "assistant" or "system"
        content: A string, the content of the message
        prompt_type: A string, either "standard" or "sharegpt"
    Returns:
        A dictionary, the message in the requested chat format.
    """
    if prompt_type == "sharegpt":
        if role == "user":
            role = "human"
        elif role == "assistant":
            role = "gpt"
        # sharegpt chat format
        return {"from": role, "value": content}
    else:
        return {"role": role, "content": content}
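
# Illustrative outputs of the two formats (derived from the function above):
#   chat_message("user", "hi")                          -> {"role": "user", "content": "hi"}
#   chat_message("user", "hi", prompt_type="sharegpt")  -> {"from": "human", "value": "hi"}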


def chat(prompt: str, history=[]):
    """
    A function that appends the submitted text to the chat history, alternating roles.
    Args:
        prompt: A string, the text typed into the textbox.
        history: A list of dictionaries, each dictionary being a message from the user or the assistant.
    Returns:
        The updated list of message dictionaries.
    """
    # the first message (and any message following an assistant turn) is a user turn,
    # otherwise the submitted text is recorded as the assistant's reply
    if history == [] or (len(history) > 1 and history[-1]["role"] == "assistant"):
        history.append(chat_message("user", prompt))
    else:
        history.append(chat_message("assistant", prompt))
    return history
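
# Illustrative sequence (hypothetical inputs): submitting "hello" and then "hi there!" yields
#   [{"role": "user", "content": "hello"}, {"role": "assistant", "content": "hi there!"}]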


def clear_textbox_field():
    """
    A function that clears the textbox field.
    """
    return None


def clear_both_fields():
    """
    A function that clears both the textbox and the chatbot.
    """
    return None, None


def clear_3_fields():
    """
    A function that clears the two textboxes and the chatbot.
    """
    return None, None, None


def setup_submission(system_prompt="", history=[], chat_format="sharegpt"):
    # re-encode every message in the requested chat format
    # (this also drops any extra metadata fields added by the chatbot component)
    for i in range(len(history)):
        sample = history[i]
        history[i] = chat_message(
            sample["role"], sample["content"], prompt_type=chat_format
        )

    # add the system prompt as the first message if one was provided
    system_prompt = system_prompt.strip()
    if system_prompt != "":
        sys = chat_message("system", system_prompt, prompt_type=chat_format)
        history.insert(0, sys)

    return history
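
# Illustrative result (hypothetical values): with chat_format="sharegpt" and a system prompt,
# the returned history looks like
#   [{"from": "system", "value": "..."}, {"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}]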


def save_sft_data(system_prompt="", history=[], sft_chat_format="sharegpt"):
    """
    A function that queues an SFT conversation to be pushed to the Hub.
    """

    # the info message is only shown once
    global show_info
    scheduler = schedulers[f"sft-{sft_chat_format}"]

    # case: the user clicked submit without any chat history
    if history == []:
        raise gr.Error("you need to set up a chat first")

    # case: the history ends with a user message
    if history[-1]["role"] == "user":
        raise gr.Error("the chat needs to end with an assistant message")

    history = setup_submission(system_prompt, history, sft_chat_format)
    # preparing the submission
    data = {"contributor": contributor_username}
    data["timestamp"] = str(datetime.datetime.now(datetime.UTC))
    data["chat_format"] = sft_chat_format
    data["conversations"] = history

    # submitting the data
    scheduler.append(data)

    # show the info message only once
    if show_info:
        gr.Info("Data has been saved successfully (this message is only shown once)")
        gr.Info(
            "The scheduler may take up to 1 minute to push the data, please wait 🤗"
        )
        show_info = False
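
# Illustrative SFT record appended to the scheduler (hypothetical values):
#   {"contributor": "alice", "timestamp": "2024-01-01 12:00:00+00:00",
#    "chat_format": "sharegpt", "conversations": [{"from": "human", "value": "..."}, ...]}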


def save_dpo_data(
    system_prompt="", history=[], chosen="", rejected="", dpo_chat_format="sharegpt"
):
    """
    A function that queues a DPO preference pair to be pushed to the Hub.
    """

    # the info message is only shown once
    global show_info
    scheduler = schedulers[f"dpo-{dpo_chat_format}"]

    # case: the user clicked submit without any chat history
    if history == []:
        raise gr.Error("you need to set up a chat first")

    # case: the history ends with an assistant message
    if history[-1]["role"] == "assistant":
        raise gr.Error("the chat needs to end with a user message")

    # case: chosen and rejected are not both filled in
    chosen, rejected = chosen.strip(), rejected.strip()
    if chosen == "" or rejected == "":
        raise gr.Error(
            "both the chosen and rejected fields need to be filled in when you click the save button"
        )

    history = setup_submission(system_prompt, history, dpo_chat_format)
    # the chosen and rejected completions are assistant turns appended to the shared prompt
    chosen_chat, rejected_chat = history.copy(), history.copy()
    chosen_chat.append(chat_message("assistant", chosen, prompt_type=dpo_chat_format))
    rejected_chat.append(chat_message("assistant", rejected, prompt_type=dpo_chat_format))

    # preparing the submission
    data = {"contributor": contributor_username}
    data["timestamp"] = str(datetime.datetime.now(datetime.UTC))
    data["chat_format"] = dpo_chat_format
    data["prompt"] = history
    data["chosen"] = chosen_chat
    data["rejected"] = rejected_chat

    # submitting the data
    scheduler.append(data)

    # show the info message only once
    if show_info:
        gr.Info("Data has been saved successfully (this message is only shown once)")
        gr.Info(
            "The scheduler may take up to 1 minute to push the data, please wait 🤗"
        )
        show_info = False
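
# Illustrative DPO record appended to the scheduler (hypothetical values):
#   {"contributor": "alice", "timestamp": "...", "chat_format": "standard",
#    "prompt": [{"role": "user", "content": "..."}],
#    "chosen": [..., {"role": "assistant", "content": "a good answer"}],
#    "rejected": [..., {"role": "assistant", "content": "a bad answer"}]}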


def undo_chat(history):
    # drop the last user/assistant exchange
    return history[:-2]


##############
# Interface #
##############

with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'>ShareGPT-Builder</h1>")

    #### SFT ####
    with gr.Tab("SFT"):
        with gr.Accordion("system prompt", open=False):
            system_prompt = gr.TextArea(show_label=False, container=False)
        sft_chat_format = gr.Radio(choices=choices, value="sharegpt")

        chatbot = gr.Chatbot(
            type="messages", show_copy_button=True, show_copy_all_button=True
        )
        textbox = gr.Textbox(show_label=False, submit_btn=True)
        textbox.submit(
            fn=chat, inputs=[textbox, chatbot], outputs=[chatbot]
        ).then(  # empty the field for convenience
            clear_textbox_field, outputs=[textbox]
        )
        chatbot.undo(undo_chat, inputs=chatbot, outputs=chatbot)
        with gr.Row():
            clear_button = gr.Button("Clear")
            clear_button.click(clear_both_fields, outputs=[textbox, chatbot])
            submit = gr.Button("save chat", variant="primary")
            submit.click(
                save_sft_data, inputs=[system_prompt, chatbot, sft_chat_format]
            ).then(clear_both_fields, outputs=[textbox, chatbot])

    #### DPO ####
    with gr.Tab("DPO"):
        with gr.Accordion("system prompt", open=False):
            dpo_system_prompt = gr.TextArea(show_label=False, container=False)
        dpo_chat_format = gr.Radio(choices=choices, value="sharegpt")
        dpo_chatbot = gr.Chatbot(
            type="messages", show_copy_button=True, show_copy_all_button=True
        )
        gr.Markdown(
            "Type into either field and press Enter to add a turn to the chat. When you are ready for the final submission, fill in both fields and, without pressing Enter, click the save chat button."
        )
        with gr.Row():
            dpo_rejected_textbox = gr.Textbox(label="rejected (or add chat)", render=True)
            dpo_chosen_textbox = gr.Textbox(label="chosen (or add chat)")
        # either field can be used to extend the chat
        dpo_chosen_textbox.submit(
            fn=chat, inputs=[dpo_chosen_textbox, dpo_chatbot], outputs=[dpo_chatbot]
        ).then(  # empty the field for convenience
            clear_textbox_field, outputs=[dpo_chosen_textbox]
        )
        dpo_rejected_textbox.submit(
            fn=chat,
            inputs=[dpo_rejected_textbox, dpo_chatbot],
            outputs=[dpo_chatbot],
        ).then(  # empty the field for convenience
            clear_textbox_field, outputs=[dpo_rejected_textbox]
        )
        dpo_chatbot.undo(undo_chat, inputs=dpo_chatbot, outputs=dpo_chatbot)
        with gr.Row():
            dpo_clear_button = gr.Button("Clear")
            dpo_clear_button.click(
                clear_3_fields,
                outputs=[dpo_chosen_textbox, dpo_rejected_textbox, dpo_chatbot],
            )
            dpo_submit = gr.Button("save chat", variant="primary")
            dpo_submit.click(
                save_dpo_data,
                inputs=[
                    dpo_system_prompt,
                    dpo_chatbot,
                    dpo_chosen_textbox,
                    dpo_rejected_textbox,
                    dpo_chat_format,
                ],
            ).then(
                clear_3_fields,
                outputs=[dpo_chosen_textbox, dpo_rejected_textbox, dpo_chatbot],
            )

    #### Dataset viewer ####
    with gr.Tab("Inspect datasets"):
        dataset = gr.Dropdown(choices=list(schedulers.keys()))

        @gr.render(inputs=dataset)
        def show_dataset(dataset):
            # nothing selected yet
            if dataset is None:
                return
            gr.HTML(
                f"""<iframe
  src="https://huggingface.co/datasets/{contributor_username}/{dataset}/embed/viewer/default/train?row=0"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>"""
            )
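
# When run directly (e.g. locally with `python app.py`), the Gradio server is launched;
# debug=True and show_error=True make errors easier to see while testing.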
if __name__ == "__main__":
    demo.launch(debug=True, show_error=True)