alexkueck committed on
Commit d130b3f · 1 Parent(s): bbb647c

Update utils.py

Files changed (1): utils.py +27 -129
utils.py CHANGED
@@ -1,4 +1,29 @@
- from app_modules.presets import *
+ from __future__ import annotations
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
+ import logging
+ import json
+ import os
+ import datetime
+ import hashlib
+ import csv
+ import requests
+ import re
+ import html
+ import markdown2
+ import torch
+ import sys
+ import gc
+ from pygments.lexers import guess_lexer, ClassNotFound
+
+ import gradio as gr
+ from pypinyin import lazy_pinyin
+ import tiktoken
+ import mdtex2html
+ from markdown import markdown
+ from pygments import highlight
+ from pygments.lexers import guess_lexer,get_lexer_by_name
+ from pygments.formatters import HtmlFormatter
+ from beschreibungen import *

  logging.basicConfig(
      level=logging.INFO,
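
Note on the new imports above: markdown, markdown2, mdtex2html, pygments and tiktoken point to the usual chat-UI rendering helpers (markdown conversion plus syntax highlighting). A minimal sketch of how that pygments/markdown stack is typically combined, assuming that is their role here; highlight_code is an illustrative name, not a function from utils.py:

from markdown import markdown
from pygments import highlight
from pygments.lexers import guess_lexer, get_lexer_by_name
from pygments.formatters import HtmlFormatter
from pygments.util import ClassNotFound

def highlight_code(code: str, language: str = "") -> str:
    """Return HTML for a code snippet, guessing the lexer when no language is given."""
    try:
        lexer = get_lexer_by_name(language) if language else guess_lexer(code)
    except ClassNotFound:
        lexer = get_lexer_by_name("text")  # fall back to plain text
    return highlight(code, lexer, HtmlFormatter(noclasses=True))

# Usage: code blocks go through pygments, prose goes through markdown.
code_html = highlight_code("print('hello')", "python")
prose_html = markdown("Some **bold** text")
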
@@ -180,7 +205,6 @@ def cancel_outputing():
      return "Stop Done"

  def transfer_input(inputs):
-     # Return everything at once to reduce latency
      textbox = reset_textbox()
      return (
          inputs,
@@ -203,79 +227,7 @@ shared_state = State()



- # Greedy Search
- def greedy_search(input_ids: torch.Tensor,
-                   model: torch.nn.Module,
-                   tokenizer: transformers.PreTrainedTokenizer,
-                   stop_words: list,
-                   max_length: int,
-                   temperature: float = 1.0,
-                   top_p: float = 1.0,
-                   top_k: int = 25) -> Iterator[str]:
-     generated_tokens = []
-     past_key_values = None
-     current_length = 1
-     for i in range(max_length):
-         with torch.no_grad():
-             if past_key_values is None:
-                 outputs = model(input_ids)
-             else:
-                 outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
-             logits = outputs.logits[:, -1, :]
-             past_key_values = outputs.past_key_values
-
-         # apply temperature
-         logits /= temperature
-
-         probs = torch.softmax(logits, dim=-1)
-         # apply top_p
-         probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
-         probs_sum = torch.cumsum(probs_sort, dim=-1)
-         mask = probs_sum - probs_sort > top_p
-         probs_sort[mask] = 0.0
-
-         # apply top_k
-         #if top_k is not None:
-         #    probs_sort1, _ = torch.topk(probs_sort, top_k)
-         #    min_top_probs_sort = torch.min(probs_sort1, dim=-1, keepdim=True).values
-         #    probs_sort = torch.where(probs_sort < min_top_probs_sort, torch.full_like(probs_sort, float(0.0)), probs_sort)
-
-         probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-         next_token = torch.multinomial(probs_sort, num_samples=1)
-         next_token = torch.gather(probs_idx, -1, next_token)
-
-         input_ids = torch.cat((input_ids, next_token), dim=-1)
-
-         generated_tokens.append(next_token[0].item())
-         text = tokenizer.decode(generated_tokens)
-
-         yield text
-         if any([x in text for x in stop_words]):
-             del past_key_values
-             del logits
-             del probs
-             del probs_sort
-             del probs_idx
-             del probs_sum
-             gc.collect()
-             return
-
- def generate_prompt_with_history(text,history,tokenizer,max_length=2048):
-     prompt = "The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n[|Human|]Hello!\n[|AI|]Hi!"
-     history = ["\n[|Human|]{}\n[|AI|]{}".format(x[0],x[1]) for x in history]
-     history.append("\n[|Human|]{}\n[|AI|]".format(text))
-     history_text = ""
-     flag = False
-     for x in history[::-1]:
-         if tokenizer(prompt+history_text+x, return_tensors="pt")['input_ids'].size(-1) <= max_length:
-             history_text = x + history_text
-             flag = True
-         else:
-             break
-     if flag:
-         return prompt+history_text,tokenizer(prompt+history_text, return_tensors="pt")
-     else:
-         return None
+


  def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
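
The greedy_search generator removed in this hunk streams tokens with temperature scaling and nucleus (top-p) filtering. A self-contained sketch of just that sampling step, assuming a [batch, vocab] logits tensor; sample_top_p is an illustrative name:

import torch

def sample_top_p(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 0.95) -> torch.Tensor:
    """Sample the next token id after temperature scaling and top-p filtering."""
    probs = torch.softmax(logits / temperature, dim=-1)
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens once the cumulative mass before them already exceeds top_p.
    probs_sort[cumulative - probs_sort > top_p] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)  # map back to vocabulary ids

# Example: next_id = sample_top_p(torch.randn(1, 32000), temperature=0.7, top_p=0.9)

The removed generate_prompt_with_history walks the chat history backwards and keeps only the most recent [|Human|]/[|AI|] turns that still fit the token budget. A hedged sketch of that truncation loop; trim_history_to_budget and its signature are illustrative, not taken from the codebase:

def trim_history_to_budget(prompt, turns, text, tokenizer, max_length=2048):
    """Append the new user turn, then keep the newest [|Human|]/[|AI|] turns that fit."""
    rendered = ["\n[|Human|]{}\n[|AI|]{}".format(h, a) for h, a in turns]
    rendered.append("\n[|Human|]{}\n[|AI|]".format(text))
    history_text = ""
    for chunk in reversed(rendered):
        candidate = chunk + history_text
        if tokenizer(prompt + candidate, return_tensors="pt")["input_ids"].size(-1) <= max_length:
            history_text = candidate  # this turn still fits, keep it
        else:
            break
    return prompt + history_text if history_text else None  # None: nothing fits
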
@@ -289,57 +241,3 @@ def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:



- def load_tokenizer_and_model(base_model,adapter_model=None,load_8bit=False):
-     if torch.cuda.is_available():
-         device = "cuda"
-     else:
-         device = "cpu"
-
-     try:
-         if torch.backends.mps.is_available():
-             device = "mps"
-     except: # noqa: E722
-         pass
-     tokenizer = LlamaTokenizer.from_pretrained(base_model)
-     if device == "cuda":
-         model = LlamaForCausalLM.from_pretrained(
-             base_model,
-             load_in_8bit=load_8bit,
-             torch_dtype=torch.float16,
-             device_map="auto",
-         )
-         if adapter_model is not None:
-             model = PeftModel.from_pretrained(
-                 model,
-                 adapter_model,
-                 torch_dtype=torch.float16,
-             )
-     elif device == "mps":
-         model = LlamaForCausalLM.from_pretrained(
-             base_model,
-             device_map={"": device},
-             torch_dtype=torch.float16,
-         )
-         if adapter_model is not None:
-             model = PeftModel.from_pretrained(
-                 model,
-                 adapter_model,
-                 device_map={"": device},
-                 torch_dtype=torch.float16,
-             )
-     else:
-         model = LlamaForCausalLM.from_pretrained(
-             base_model, device_map={"": device}, low_cpu_mem_usage=True
-         )
-         if adapter_model is not None:
-             model = PeftModel.from_pretrained(
-                 model,
-                 adapter_model,
-                 device_map={"": device},
-             )
-
-     if not load_8bit:
-         model.half() # seems to fix bugs for some users.
-
-     model.eval()
-     return tokenizer,model,device
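
The deleted load_tokenizer_and_model follows a common pattern: pick a device, load the LLaMA base checkpoint with settings appropriate to that device, then optionally stack a PEFT/LoRA adapter on top. A condensed sketch of that pattern under the same assumptions (MPS branch omitted, load_base_and_adapter is an illustrative name), not the loader this commit replaces it with:

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

def load_base_and_adapter(base_model, adapter_model=None, load_8bit=False):
    """Load tokenizer and base LLaMA model for the available device, then add a PEFT adapter."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            base_model, device_map={"": device}, low_cpu_mem_usage=True
        )
    if adapter_model is not None:
        # Attach LoRA/PEFT weights on top of the frozen base model.
        model = PeftModel.from_pretrained(model, adapter_model)
    if device == "cuda" and not load_8bit:
        model.half()
    model.eval()
    return tokenizer, model, device
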
 