cohit committed
Commit da65ea6 (verified)
1 Parent(s): eaefc84

Upload folder using huggingface_hub

Files changed (3):
  1. README.md +2 -8
  2. app.py +297 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
 title: ParrotAI
-emoji: 📚
-colorFrom: red
-colorTo: pink
-sdk: gradio
-sdk_version: 4.20.1
 app_file: app.py
-pinned: false
+sdk: gradio
+sdk_version: 4.20.0
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
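(For context: the block between the `---` markers is the Space's front-matter configuration. After this change, `sdk: gradio` and `sdk_version: 4.20.0` pin the Gradio runtime the Space is built with and `app_file: app.py` names the entry point; the cosmetic keys `emoji`, `colorFrom`, `colorTo`, and `pinned` are dropped and fall back to Hub defaults.)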
app.py ADDED
@@ -0,0 +1,297 @@
+# Load the model.
+# Note: It can take a while to download LLaMA and add the adapter modules.
+# You can also use the 13B model by loading it in 4 bits.
+
+import torch
+from peft import PeftModel
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
+
+model_name = "baffo32/decapoda-research-llama-7b-hf"
+adapters_name = 'timdettmers/guanaco-7b'
+
+print(f"Starting to load the model {model_name} into memory")
+
+m = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    # load_in_4bit=True,
+    torch_dtype=torch.bfloat16,
+    device_map={"": 0}
+)
+m = PeftModel.from_pretrained(m, adapters_name)
+m = m.merge_and_unload()
+tok = LlamaTokenizer.from_pretrained(model_name)
+tok.bos_token_id = 1
+
+stop_token_ids = [0]
+
+print(f"Successfully loaded the model {model_name} into memory")
+
+
+# Set up the Gradio demo.
+
+import datetime
+import os
+from threading import Event, Thread
+from uuid import uuid4
+
+import gradio as gr
+import requests
+
+max_new_tokens = 1536
+start_message = """A chat between a curious human and an artificial African Grey Parrot assistant. The assistant parrot gives helpful, detailed, and rude answers to the user's questions. The Parrot loves to mimic humans and recites poems by Edgar Allan Poe, especially The Raven. """
+
+class StopOnTokens(StoppingCriteria):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+        for stop_id in stop_token_ids:
+            if input_ids[0][-1] == stop_id:
+                return True
+        return False
+
+
+def convert_history_to_text(history):
+    text = start_message + "".join(
+        [
+            "".join(
+                [
+                    f"### Human: {item[0]}\n",
+                    f"### Assistant: {item[1]}\n",
+                ]
+            )
+            for item in history[:-1]
+        ]
+    )
+    text += "".join(
+        [
+            "".join(
+                [
+                    f"### Human: {history[-1][0]}\n",
+                    f"### Assistant: {history[-1][1]}\n",
+                ]
+            )
+        ]
+    )
+    return text
+
+
+def log_conversation(conversation_id, history, messages, generate_kwargs):
+    logging_url = os.getenv("LOGGING_URL", None)
+    if logging_url is None:
+        return
+
+    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
+
+    data = {
+        "conversation_id": conversation_id,
+        "timestamp": timestamp,
+        "history": history,
+        "messages": messages,
+        "generate_kwargs": generate_kwargs,
+    }
+
+    try:
+        requests.post(logging_url, json=data)
+    except requests.exceptions.RequestException as e:
+        print(f"Error logging conversation: {e}")
+
+
+def user(message, history):
+    # Append the user's message to the conversation history
+    return "", history + [[message, ""]]
+
+
+def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
+    print(f"history: {history}")
+    # Initialize a StopOnTokens object
+    stop = StopOnTokens()
+
+    # Construct the input message string for the model by concatenating the current system message and conversation history
+    messages = convert_history_to_text(history)
+
+    # Tokenize the messages string
+    input_ids = tok(messages, return_tensors="pt").input_ids
+    input_ids = input_ids.to(m.device)
+    streamer = TextIteratorStreamer(tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        do_sample=temperature > 0.0,
+        top_p=top_p,
+        top_k=top_k,
+        repetition_penalty=repetition_penalty,
+        streamer=streamer,
+        stopping_criteria=StoppingCriteriaList([stop]),
+    )
+
+    stream_complete = Event()
+
+    def generate_and_signal_complete():
+        m.generate(**generate_kwargs)
+        stream_complete.set()
+
+    def log_after_stream_complete():
+        stream_complete.wait()
+        log_conversation(
+            conversation_id,
+            history,
+            messages,
+            {
+                "top_k": top_k,
+                "top_p": top_p,
+                "temperature": temperature,
+                "repetition_penalty": repetition_penalty,
+            },
+        )
+
+    t1 = Thread(target=generate_and_signal_complete)
+    t1.start()
+
+    t2 = Thread(target=log_after_stream_complete)
+    t2.start()
+
+    # Initialize an empty string to store the generated text
+    partial_text = ""
+    for new_text in streamer:
+        partial_text += new_text
+        history[-1][1] = partial_text
+        yield history
+
+
+def get_uuid():
+    return str(uuid4())
+
+
+with gr.Blocks(
+    theme=gr.themes.Soft(),
+    css=".disclaimer {font-variant-caps: all-small-caps;}",
+) as demo:
+    conversation_id = gr.State(get_uuid)
+    gr.Markdown(
+        """<h1><center>African Grey Demo</center></h1>
+"""
+    )
+    chatbot = gr.Chatbot()
+    with gr.Row():
+        with gr.Column():
+            msg = gr.Textbox(
+                label="Chat Message Box",
+                placeholder="Chat Message Box",
+                show_label=False,
+            )
+        with gr.Column():
+            with gr.Row():
+                submit = gr.Button("Submit")
+                stop = gr.Button("Stop")
+                clear = gr.Button("Clear")
+    with gr.Row():
+        with gr.Accordion("Advanced Options:", open=False):
+            with gr.Row():
+                with gr.Column():
+                    with gr.Row():
+                        temperature = gr.Slider(
+                            label="Temperature",
+                            value=0.7,
+                            minimum=0.0,
+                            maximum=1.0,
+                            step=0.1,
+                            interactive=True,
+                            info="Higher values produce more diverse outputs",
+                        )
+                with gr.Column():
+                    with gr.Row():
+                        top_p = gr.Slider(
+                            label="Top-p (nucleus sampling)",
+                            value=0.9,
+                            minimum=0.0,
+                            maximum=1,
+                            step=0.01,
+                            interactive=True,
+                            info=(
+                                "Sample from the smallest possible set of tokens whose cumulative probability "
+                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
+                            ),
+                        )
+                with gr.Column():
+                    with gr.Row():
+                        top_k = gr.Slider(
+                            label="Top-k",
+                            value=0,
+                            minimum=0.0,
+                            maximum=200,
+                            step=1,
+                            interactive=True,
+                            info="Sample from a shortlist of top-k tokens; 0 to disable and sample from all tokens.",
+                        )
+                with gr.Column():
+                    with gr.Row():
+                        repetition_penalty = gr.Slider(
+                            label="Repetition Penalty",
+                            value=1.1,
+                            minimum=1.0,
+                            maximum=2.0,
+                            step=0.1,
+                            interactive=True,
+                            info="Penalize repetition; 1.0 to disable.",
+                        )
+    with gr.Row():
+        gr.Markdown(
+            "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
+            "factually accurate information. The model was trained on various public datasets; while great efforts "
+            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
+            "biased, or otherwise offensive outputs.",
+            elem_classes=["disclaimer"],
+        )
+    with gr.Row():
+        gr.Markdown(
+            "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
+            elem_classes=["disclaimer"],
+        )
+
+    submit_event = msg.submit(
+        fn=user,
+        inputs=[msg, chatbot],
+        outputs=[msg, chatbot],
+        queue=False,
+    ).then(
+        fn=bot,
+        inputs=[
+            chatbot,
+            temperature,
+            top_p,
+            top_k,
+            repetition_penalty,
+            conversation_id,
+        ],
+        outputs=chatbot,
+        queue=True,
+    )
+    submit_click_event = submit.click(
+        fn=user,
+        inputs=[msg, chatbot],
+        outputs=[msg, chatbot],
+        queue=False,
+    ).then(
+        fn=bot,
+        inputs=[
+            chatbot,
+            temperature,
+            top_p,
+            top_k,
+            repetition_penalty,
+            conversation_id,
+        ],
+        outputs=chatbot,
+        queue=True,
+    )
+    stop.click(
+        fn=None,
+        inputs=None,
+        outputs=None,
+        cancels=[submit_event, submit_click_event],
+        queue=False,
+    )
+    clear.click(lambda: None, None, chatbot, queue=False)
+
+demo.queue(max_size=128)
+
+demo.launch(share=True)
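A note on the prompt format: convert_history_to_text serializes the chat as "### Human:" / "### Assistant:" turns appended to the system message, which is the conversation format the Guanaco adapter was trained on. A quick, hypothetical sanity check (not part of the commit, and assuming app.py's definitions are loaded) shows the exact string the model sees:

    # Hypothetical check: inspect the serialized prompt for a two-turn
    # history where the current turn's assistant reply is still empty.
    history = [
        ["Hello!", "Squawk! Hello yourself."],
        ["Recite The Raven.", ""],
    ]
    print(convert_history_to_text(history))
    # <start_message>### Human: Hello!
    # ### Assistant: Squawk! Hello yourself.
    # ### Human: Recite The Raven.
    # ### Assistant:

Because the TextIteratorStreamer is created with skip_prompt=True, only tokens generated after this prompt are streamed back into the chat window.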
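The header comment mentions 4-bit loading, and load_in_4bit=True is left commented out in the from_pretrained call. A minimal sketch of that path, assuming a recent transformers/bitsandbytes (requirements.txt installs both) and using BitsAndBytesConfig; note that merging LoRA weights into a quantized base may fail depending on the peft version, so this variant keeps the adapter unmerged:

    # Sketch only, not part of the commit: load the base model in 4 bits.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model_name = "baffo32/decapoda-research-llama-7b-hf"
    adapters_name = "timdettmers/guanaco-7b"

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",             # normal-float 4-bit, as in QLoRA
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    m = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map={"": 0},
    )
    m = PeftModel.from_pretrained(m, adapters_name)
    # No merge_and_unload() here: merging into 4-bit weights is not generally
    # supported, so generation runs with the adapter applied at runtime.

This is also what would let the 13B checkpoint fit on the same GPU, as the header comment suggests.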
requirements.txt ADDED
@@ -0,0 +1,6 @@
+bitsandbytes
+git+https://github.com/huggingface/transformers.git
+git+https://github.com/huggingface/peft.git
+git+https://github.com/huggingface/accelerate.git
+gradio
+sentencepiece
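Worth noting: three of the six dependencies are installed straight from GitHub main via git+ URLs, so two builds of this Space can resolve to different library commits. Pinning released versions (or adding commit hashes to the git+ URLs) would make rebuilds reproducible; locally, `pip install -r requirements.txt` recreates the environment the Space builds.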