zhang4096 committed
Commit dee839b · verified · 1 Parent(s): 8687776

Create app.py

Files changed (1):
app.py (+277, -0)
app.py ADDED
@@ -0,0 +1,277 @@
# isort: skip_file
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import (LogitsProcessorList,
                                           StoppingCriteriaList)
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip

logger = logging.get_logger(__name__)
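# Hugging Face Hub repo ID of the chat checkpoint served by this demo
# (assumed to be an InternLM2.5-style chat model, matching the prompt format
# and stop-token id used below).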
model_name_or_path = "zhang4096/intern_study_L0_4"


@dataclass
class GenerationConfig:
    # this config is used for chat to provide more diversity
    max_length: int = 32768
    top_p: float = 0.8
    temperature: float = 0.8
    do_sample: bool = True
    repetition_penalty: float = 1.005


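# Streaming generation helper: decodes one token at a time and yields the
# partial response after every step so the Streamlit UI can render it live.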
@torch.inference_mode()
def generate_interactive(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
):
    inputs = tokenizer([prompt], padding=True, return_tensors='pt')
    input_length = len(inputs['input_ids'][0])
    for k, v in inputs.items():
        inputs[k] = v.cuda()
    input_ids = inputs['input_ids']
    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
    if generation_config is None:
        generation_config = model.generation_config
    generation_config = copy.deepcopy(generation_config)
    model_kwargs = generation_config.update(**kwargs)
    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
        generation_config.bos_token_id,
        generation_config.eos_token_id,
    )
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    if additional_eos_token_id is not None:
        eos_token_id.append(additional_eos_token_id)
    has_default_max_length = kwargs.get(
        'max_length') is None and generation_config.max_length is not None
    if has_default_max_length and generation_config.max_new_tokens is None:
        warnings.warn(
            "Using 'max_length''s default "
            f'({repr(generation_config.max_length)}) '
            'to control the generation length. '
            'This behaviour is deprecated and will be removed from the '
            'config in v5 of Transformers -- we recommend using '
            '`max_new_tokens` to control the maximum length of the '
            'generation.',
            UserWarning,
        )
    elif generation_config.max_new_tokens is not None:
        generation_config.max_length = generation_config.max_new_tokens + \
            input_ids_seq_length
        if not has_default_max_length:
            logger.warning(
                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
                f"and 'max_length'(={generation_config.max_length}) seem to "
                "have been set. 'max_new_tokens' will take precedence. "
                'Please refer to the documentation for more information. '
                '(https://huggingface.co/docs/transformers/main/'
                'en/main_classes/text_generation)')

    if input_ids_seq_length >= generation_config.max_length:
        input_ids_string = 'input_ids'
        logger.warning(
            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
            f"but 'max_length' is set to {generation_config.max_length}. "
            'This can lead to unexpected behavior. You should consider'
            " increasing 'max_new_tokens'.")

    # 2. Set generation parameters if not already defined
    logits_processor = logits_processor if logits_processor is not None \
        else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None \
        else StoppingCriteriaList()

    logits_processor = model._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_seq_length,
        encoder_input_ids=input_ids,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
    )

    stopping_criteria = model._get_stopping_criteria(
        generation_config=generation_config,
        stopping_criteria=stopping_criteria)
    logits_warper = model._get_logits_warper(generation_config)

    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
    scores = None
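    # Token-by-token decode loop: process and warp the logits, sample (or take
    # the argmax), append the new token, then yield the decoded partial
    # response until an EOS token or a stopping criterion is hit.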
    while True:
        model_inputs = model.prepare_inputs_for_generation(
            input_ids, **model_kwargs)
        # forward pass to get next token
        outputs = model(
            **model_inputs,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
        )

        next_token_logits = outputs.logits[:, -1, :]

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)
        next_token_scores = logits_warper(input_ids, next_token_scores)

        # sample
        probs = nn.functional.softmax(next_token_scores, dim=-1)
        if generation_config.do_sample:
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(probs, dim=-1)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=False)
        unfinished_sequences = unfinished_sequences.mul(
            (min(next_tokens != i for i in eos_token_id)).long())

        output_token_ids = input_ids[0].cpu().tolist()
        output_token_ids = output_token_ids[input_length:]
        for each_eos_token_id in eos_token_id:
            if output_token_ids[-1] == each_eos_token_id:
                output_token_ids = output_token_ids[:-1]
        response = tokenizer.decode(output_token_ids)

        yield response
        # stop when each sentence is finished
        # or if we exceed the maximum length
        if unfinished_sequences.max() == 0 or stopping_criteria(
                input_ids, scores):
            break


def on_btn_click():
    del st.session_state.messages


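# Load the model once per process and cache it across Streamlit reruns;
# weights are cast to bfloat16 and moved to the GPU.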
@st.cache_resource
def load_model():
    model = (AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True).to(torch.bfloat16).cuda())
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                              trust_remote_code=True)
    return model, tokenizer


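# Sidebar sliders for the sampling parameters; do_sample and
# repetition_penalty keep the GenerationConfig defaults defined above.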
def prepare_generation_config():
    with st.sidebar:
        max_length = st.slider('Max Length',
                               min_value=8,
                               max_value=32768,
                               value=32768)
        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
        st.button('Clear Chat History', on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length,
                                         top_p=top_p,
                                         temperature=temperature)

    return generation_config


user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
    <|im_start|>assistant\n'
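# Illustrative example of the prompt assembled by combine_history() below for
# a first user turn (whitespace simplified):
#   <s><|im_start|>system
#   You are a helpful, honest, and harmless AI assistant.<|im_end|>
#   <|im_start|>user
#   Hello<|im_end|>
#   <|im_start|>assistant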


def combine_history(prompt):
    messages = st.session_state.messages
    meta_instruction = ('You are a helpful, honest, '
                        'and harmless AI assistant.')
    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
    for message in messages:
        cur_content = message['content']
        if message['role'] == 'user':
            cur_prompt = user_prompt.format(user=cur_content)
        elif message['role'] == 'robot':
            cur_prompt = robot_prompt.format(robot=cur_content)
        else:
            raise RuntimeError
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
    return total_prompt


def main():
    st.title('internlm2_5-7b-chat-assistant')

    # torch.cuda.empty_cache()
    print('load model begin.')
    model, tokenizer = load_model()
    print('load model end.')

    generation_config = prepare_generation_config()

    # Initialize chat history
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message['role'], avatar=message.get('avatar')):
            st.markdown(message['content'])

    # Accept user input
    if prompt := st.chat_input('What is up?'):
        # Display user message in chat message container
        with st.chat_message('user', avatar='user'):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({
            'role': 'user',
            'content': prompt,
            'avatar': 'user'
        })

        with st.chat_message('robot', avatar='assistant'):
            message_placeholder = st.empty()
            for cur_response in generate_interactive(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=real_prompt,
                    additional_eos_token_id=92542,
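                    # 92542 is presumably the <|im_end|> token id of the
                    # InternLM2 tokenizer, used here as an extra stop token.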
                    device='cuda:0',
                    **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + '▌')
            message_placeholder.markdown(cur_response)
            # Add robot response to chat history
            st.session_state.messages.append({
                'role': 'robot',
                'content': cur_response,  # pylint: disable=undefined-loop-variable
                'avatar': 'assistant',
            })
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
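
Note: the demo assumes a CUDA GPU (inputs are moved with .cuda() and the model is cast to bfloat16) and the streamlit, torch, and transformers packages; it is started the usual Streamlit way (streamlit run app.py).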