jeffeux committed
Commit c9d3334 · 1 Parent(s): eca00c9
Files changed (1)
  1. app.py +18 -10
app.py CHANGED
@@ -2,7 +2,6 @@
  import os, logging, torch, streamlit as st
  from transformers import (
      AutoTokenizer, AutoModelForCausalLM)
- st.balloons()

  # --------------------- HELPER --------------------- #
  def C(text, color="yellow"):
@@ -18,18 +17,31 @@ def C(text, color="yellow"):
      return (
          f"{color_dict.get(color, None)}"
          f"{text}{color_dict[None]}")
- st.balloons()

  # ------------------ ENVIORNMENT ------------------- #
  os.environ["HF_ENDPOINT"] = "https://huggingface.co"
  device = ("cuda"
      if torch.cuda.is_available() else "cpu")
  logging.info(C("[INFO] "f"device = {device}"))
- st.balloons()

  # ------------------ INITITALIZE ------------------- #
  @st.cache
  def model_init():
+
+
+
+     from transformers import GenerationConfig
+
+     # generation_config, unused_kwargs = GenerationConfig.from_pretrained(
+     #     "ckip-joint/bloom-1b1-zh",
+     #     max_new_tokens=200,
+     #     return_unused_kwargs=True)
+
+
+
+
+
+
      tokenizer = AutoTokenizer.from_pretrained(
          "ckip-joint/bloom-1b1-zh")
      model = AutoModelForCausalLM.from_pretrained(
@@ -44,14 +56,10 @@ def model_init():
      return tokenizer, model

  tokenizer, model = model_init()
- st.balloons()

  try:
      # ===================== INPUT ====================== #
-     # prompt = "問：台灣最高的建築物是？答：" #@param {type:"string"}
      prompt = st.text_input("Prompt: ")
-     st.balloons()
-

      # =================== INFERENCE ==================== #
      if prompt:
@@ -59,13 +67,13 @@ try:
          with torch.no_grad():
              [texts_out] = model.generate(
                  **tokenizer(
-                     prompt, return_tensors="pt"
+                     prompt, return_tensors="pt",
+                     max_new_tokens=200,
                  ).to(device))
-         st.balloons()
          output_text = tokenizer.decode(texts_out)
          st.balloons()
          st.markdown(output_text)
-         st.balloons()
+
  except Exception as err:
      st.write(str(err))
      st.snow()
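A note on the inference change: in the transformers API, max_new_tokens is a generation-time setting consumed by model.generate() (or a GenerationConfig), not a tokenizer argument, so the tokenizer call above will likely ignore or reject it depending on the tokenizer class. Below is a minimal sketch of how the commented-out GenerationConfig intent is usually written; moving the model to the same device as the inputs and the placeholder prompt string are assumptions added for a self-contained example (the diff only shows .to(device) on the tokenized prompt).

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Same checkpoint as app.py; moving the model to `device` is an assumption here.
tokenizer = AutoTokenizer.from_pretrained("ckip-joint/bloom-1b1-zh")
model = AutoModelForCausalLM.from_pretrained("ckip-joint/bloom-1b1-zh").to(device)

# Placeholder for the st.text_input value
# ("Q: What is the tallest building in Taiwan? A:")
prompt = "問：台灣最高的建築物是？答："
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    # Cap generation at 200 new tokens, mirroring the commented-out
    # GenerationConfig(..., max_new_tokens=200) in model_init().
    [texts_out] = model.generate(**inputs, max_new_tokens=200)

output_text = tokenizer.decode(texts_out)
print(output_text)

Either way, the length cap stays out of the tokenizer call, which only prepares input_ids and attention_mask.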
 
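On the caching side, model_init() stays wrapped in @st.cache so the BLOOM checkpoint is loaded once per session rather than on every Streamlit rerun. With the st.cache API used here, the returned tokenizer and model are hashed to detect mutation, which can be slow or fail for large torch objects; a common variant (an assumption, not what this commit does) disables that check with allow_output_mutation=True:

import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM

# allow_output_mutation=True skips hashing the returned objects on each rerun;
# the commit itself uses a bare @st.cache decorator.
@st.cache(allow_output_mutation=True)
def model_init():
    tokenizer = AutoTokenizer.from_pretrained("ckip-joint/bloom-1b1-zh")
    model = AutoModelForCausalLM.from_pretrained("ckip-joint/bloom-1b1-zh")
    return tokenizer, model

tokenizer, model = model_init()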