vilarin committed on
Commit
99a7a45
·
verified ·
1 Parent(s): d5d8ee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -8,14 +8,13 @@ from threading import Thread
8
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
- MODEL_ID = "THUDM/glm-4-9b-chat"
12
- MODEL_ID2 = "THUDM/glm-4-9b-chat-1m"
13
- MODELS = os.environ.get("MODELS")
14
- MODEL_NAME = MODELS.split("/")[-1]
15
 
16
- TITLE = "<h1><center>GLM-4-9B</center></h1>"
17
 
18
- DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODELS}">{MODEL_NAME}</a></center></h3>'
19
 
20
  CSS = """
21
  .duplicate-button {
@@ -26,18 +25,26 @@ CSS = """
26
  }
27
  """
28
 
29
- model = AutoModelForCausalLM.from_pretrained(
30
- MODELS,
31
  torch_dtype=torch.bfloat16,
32
  low_cpu_mem_usage=True,
33
  trust_remote_code=True,
34
  ).to(0).eval()
35
 
36
- tokenizer = AutoTokenizer.from_pretrained(MODELS,trust_remote_code=True)
37
 
 
 
 
 
 
 
 
 
38
 
39
  @spaces.GPU
40
- def stream_chat(message: str, history: list, temperature: float, max_length: int):
41
  print(f'message is - {message}')
42
  print(f'history is - {history}')
43
  conversation = []
@@ -46,7 +53,14 @@ def stream_chat(message: str, history: list, temperature: float, max_length: int
46
  conversation.append({"role": "user", "content": message})
47
 
48
  print(f"Conversation is -\n{conversation}")
49
-
 
 
 
 
 
 
 
50
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
51
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
52
 
@@ -68,14 +82,10 @@ def stream_chat(message: str, history: list, temperature: float, max_length: int
68
  buffer += new_text
69
  yield buffer
70
 
71
-
72
-
73
-
74
- chatbot = gr.Chatbot(height=450)
75
 
76
  with gr.Blocks(css=CSS) as demo:
77
  gr.HTML(TITLE)
78
- gr.HTML(DESCRIPTION)
79
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
80
  gr.ChatInterface(
81
  fn=stream_chat,
@@ -99,6 +109,10 @@ with gr.Blocks(css=CSS) as demo:
99
  label="Max Length",
100
  render=False,
101
  ),
 
 
 
 
102
  ],
103
  examples=[
104
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
@@ -108,7 +122,7 @@ with gr.Blocks(css=CSS) as demo:
108
  ],
109
  cache_examples=False,
110
  )
111
-
112
 
113
  if __name__ == "__main__":
114
  demo.launch()
 
8
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
+ MODEL_LIST = "THUDM/glm-4-9b-chat, THUDM/glm-4-9b-chat-1m, THUDM/codegeex4-all-9b"
12
+ #MODELS = os.environ.get("MODELS")
13
+ #MODEL_NAME = MODELS.split("/")[-1]
 
14
 
15
+ TITLE = "<h1><center>GLM SPACE</center></h1>"
16
 
17
+ PLACEHOLDER = f'<h3><center>Feel Free To Test GLM</center></h3>'
18
 
19
  CSS = """
20
  .duplicate-button {
 
25
  }
26
  """
27
 
28
+ model_chat = AutoModelForCausalLM.from_pretrained(
29
+ "THUDM/glm-4-9b-chat",
30
  torch_dtype=torch.bfloat16,
31
  low_cpu_mem_usage=True,
32
  trust_remote_code=True,
33
  ).to(0).eval()
34
 
35
+ tokenizer_chat = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat",trust_remote_code=True)
36
 
37
+ model_code = AutoModelForCausalLM.from_pretrained(
38
+ "THUDM/codegeex4-all-9b",
39
+ torch_dtype=torch.bfloat16,
40
+ low_cpu_mem_usage=True,
41
+ trust_remote_code=True
42
+ ).to(device).eval()
43
+
44
+ tokenizer_code = AutoTokenizer.from_pretrained("THUDM/codegeex4-all-9b", trust_remote_code=True)
45
 
46
  @spaces.GPU
47
+ def stream_chat(message: str, history: list, temperature: float, max_length: int, model: str):
48
  print(f'message is - {message}')
49
  print(f'history is - {history}')
50
  conversation = []
 
53
  conversation.append({"role": "user", "content": message})
54
 
55
  print(f"Conversation is -\n{conversation}")
56
+
57
+ if mode == "glm-4-9b-chat":
58
+ tokenizer = tokenizer_chat
59
+ model = model_chat
60
+ else:
61
+ model = model_code
62
+ tokenizer = tokenizer_code
63
+
64
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
65
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
66
 
 
82
  buffer += new_text
83
  yield buffer
84
 
85
+ chatbot = gr.Chatbot(height=600, placeholder = PLACEHOLDER)
 
 
 
86
 
87
  with gr.Blocks(css=CSS) as demo:
88
  gr.HTML(TITLE)
 
89
  gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
90
  gr.ChatInterface(
91
  fn=stream_chat,
 
109
  label="Max Length",
110
  render=False,
111
  ),
112
+ choice = gr.Radio(
113
+ ["glm-4-9b-chat", "codegeex4-all-9b"],
114
+ label="Load Model"
115
+ ),
116
  ],
117
  examples=[
118
  ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
 
122
  ],
123
  cache_examples=False,
124
  )
125
+
126
 
127
  if __name__ == "__main__":
128
  demo.launch()