Svngoku commited on
Commit
b5b97af
·
verified ·
1 Parent(s): 40deac7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +172 -0
README.md CHANGED
@@ -216,7 +216,179 @@ with open("studio/unsloth_studio/chat.py", "r") as chat_module:
216
  )
217
  exec(code)
218
  ```
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.
221
 
222
  [<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
 
216
  )
217
  exec(code)
218
  ```
219
+ - Change the `chat.py`
220
 
221
+ ```py
222
+ # Unsloth Studio
223
+ # Copyright (C) 2024-present the Unsloth AI team. All rights reserved.
224
+
225
+ # This program is free software: you can redistribute it and/or modify
226
+ # it under the terms of the GNU Affero General Public License as published
227
+ # by the Free Software Foundation, either version 3 of the License, or
228
+ # (at your option) any later version.
229
+
230
+ # This program is distributed in the hope that it will be useful,
231
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
232
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
233
+ # GNU Affero General Public License for more details.
234
+
235
+ # You should have received a copy of the GNU Affero General Public License
236
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
237
+
238
+ from IPython.display import clear_output
239
+ import subprocess
240
+ import os
241
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
242
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
243
+ MODEL_NAME = "vutuka/Llama-3.1-8B-Instruct-African-Ultrachat"
244
+
245
+ print("Installing packages for 🦥 Unsloth Studio ... Please wait 1 minute ...")
246
+
247
+ install_first = [
248
+ "pip", "install",
249
+ "huggingface_hub[hf_transfer]",
250
+ ]
251
+ install_first = subprocess.Popen(install_first)
252
+ install_first.wait()
253
+
254
+ install_second = [
255
+ "pip", "install",
256
+ "gradio",
257
+ "unsloth[colab-new]@git+https://github.com/unslothai/unsloth.git",
258
+ ]
259
+ install_second = subprocess.Popen(install_second)
260
+
261
+ from huggingface_hub import snapshot_download
262
+ import warnings
263
+ warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
264
+ warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
265
+ warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
266
+ warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
267
+ warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
268
+ warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
269
+ warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
270
+ warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
271
+
272
+ from huggingface_hub.utils import disable_progress_bars
273
+ disable_progress_bars()
274
+ snapshot_download(repo_id = MODEL_NAME, repo_type = "model")
275
+
276
+ install_second.wait()
277
+
278
+ install_dependencies = [
279
+ "pip", "install", "--no-deps",
280
+ "xformers<0.0.27", "trl<0.9.0", "peft", "accelerate", "bitsandbytes",
281
+ ]
282
+ install_dependencies = subprocess.Popen(install_dependencies)
283
+ install_dependencies.wait()
284
+ clear_output()
285
+
286
+
287
+ from contextlib import redirect_stdout
288
+ import io
289
+ import logging
290
+ logging.getLogger("transformers.utils.hub").setLevel(logging.CRITICAL+1)
291
+
292
+ print("Loading model ... Please wait 1 more minute! ...")
293
+
294
+ with redirect_stdout(io.StringIO()):
295
+ from unsloth import FastLanguageModel
296
+ import torch
297
+ model, tokenizer = FastLanguageModel.from_pretrained(
298
+ model_name = MODEL_NAME,
299
+ max_seq_length = None,
300
+ dtype = None,
301
+ load_in_4bit = True,
302
+ )
303
+ FastLanguageModel.for_inference(model)
304
+ pass
305
+ clear_output()
306
+
307
+ import gradio
308
+ gradio.strings.en["SHARE_LINK_DISPLAY"] = ""
309
+ from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
310
+ from threading import Thread
311
+
312
+ class StopOnTokens(StoppingCriteria):
313
+ def __init__(self, stop_token_ids):
314
+ self.stop_token_ids = tuple(set(stop_token_ids))
315
+ pass
316
+
317
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
318
+ return input_ids[0][-1].item() in self.stop_token_ids
319
+ pass
320
+ pass
321
+
322
+ def async_process_chatbot(message, history):
323
+ eos_token = tokenizer.eos_token
324
+ stop_on_tokens = StopOnTokens([eos_token,])
325
+ text_streamer = TextIteratorStreamer(tokenizer, skip_prompt = True)
326
+
327
+ # From https://www.gradio.app/guides/creating-a-chatbot-fast
328
+ history_transformer_format = history + [[message, ""]]
329
+ messages = []
330
+ for item in history_transformer_format:
331
+ messages.append({"role": "user", "content": item[0]})
332
+ messages.append({"role": "assistant", "content": item[1]})
333
+ pass
334
+ # Remove last assistant and instead use add_generation_prompt
335
+ messages.pop(-1)
336
+
337
+ input_ids = tokenizer.apply_chat_template(
338
+ messages,
339
+ add_generation_prompt = True,
340
+ return_tensors = "pt",
341
+ ).to("cuda", non_blocking = True)
342
+
343
+ # Add stopping criteria - will not output EOS / EOT
344
+ generation_kwargs = dict(
345
+ input_ids = input_ids,
346
+ streamer = text_streamer,
347
+ max_new_tokens = 1024,
348
+ stopping_criteria = StoppingCriteriaList([stop_on_tokens,]),
349
+ temperature = 0.7,
350
+ do_sample = True,
351
+ )
352
+ thread = Thread(target = model.generate, kwargs = generation_kwargs)
353
+ thread.start()
354
+
355
+ # Yield will save the output to history!
356
+ generated_text = ""
357
+ for new_text in text_streamer:
358
+ if new_text.endswith(eos_token):
359
+ new_text = new_text[:len(new_text) - len(eos_token)]
360
+ generated_text += new_text
361
+ yield generated_text
362
+ pass
363
+ pass
364
+
365
+ studio_theme = gradio.themes.Soft(
366
+ primary_hue = "teal",
367
+ )
368
+
369
+ scene = gradio.ChatInterface(
370
+ async_process_chatbot,
371
+ chatbot = gradio.Chatbot(
372
+ height = 325,
373
+ label = "Unsloth Studio Chat",
374
+ ),
375
+ textbox = gradio.Textbox(
376
+ placeholder = "Message Unsloth Chat",
377
+ container = False,
378
+ ),
379
+ title = None,
380
+ theme = studio_theme,
381
+ examples = None,
382
+ cache_examples = False,
383
+ retry_btn = None,
384
+ undo_btn = "Remove Previous Message",
385
+ clear_btn = "Restart Entire Chat",
386
+ )
387
+
388
+ scene.launch(quiet = True)
389
+ ```
390
+
391
+
392
  This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.
393
 
394
  [<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)