praeclarumjj3 commited on
Commit
9c4915e
·
1 Parent(s): f598a68

Update chat.py

Browse files
Files changed (1) hide show
  1. chat.py +25 -25
chat.py CHANGED
@@ -131,7 +131,7 @@ class Chat:
131
  keywords = [stop_str]
132
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
133
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
134
-
135
  max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens - num_seg_tokens - num_depth_tokens)
136
 
137
  if max_new_tokens < 1:
@@ -159,30 +159,30 @@ class Chat:
159
  yield json.dumps({"text": generated_text, "error_code": 0}).encode()
160
 
161
  def generate_stream_gate(self, params):
162
- # try:
163
- for x in self.generate_stream(params):
164
- yield x
165
- # except ValueError as e:
166
- # print("Caught ValueError:", e)
167
- # ret = {
168
- # "text": server_error_msg,
169
- # "error_code": 1,
170
- # }
171
- # yield json.dumps(ret).encode()
172
- # except torch.cuda.CudaError as e:
173
- # print("Caught torch.cuda.CudaError:", e)
174
- # ret = {
175
- # "text": server_error_msg,
176
- # "error_code": 1,
177
- # }
178
- # yield json.dumps(ret).encode()
179
- # except Exception as e:
180
- # print("Caught Unknown Error", e)
181
- # ret = {
182
- # "text": server_error_msg,
183
- # "error_code": 1,
184
- # }
185
- # yield json.dumps(ret).encode()
186
 
187
 
188
  if __name__ == "__main__":
 
131
  keywords = [stop_str]
132
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
133
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
134
+
135
  max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens - num_seg_tokens - num_depth_tokens)
136
 
137
  if max_new_tokens < 1:
 
159
  yield json.dumps({"text": generated_text, "error_code": 0}).encode()
160
 
161
  def generate_stream_gate(self, params):
162
+ try:
163
+ for x in self.generate_stream(params):
164
+ yield x
165
+ except ValueError as e:
166
+ print("Caught ValueError:", e)
167
+ ret = {
168
+ "text": server_error_msg,
169
+ "error_code": 1,
170
+ }
171
+ yield json.dumps(ret).encode()
172
+ except torch.cuda.CudaError as e:
173
+ print("Caught torch.cuda.CudaError:", e)
174
+ ret = {
175
+ "text": server_error_msg,
176
+ "error_code": 1,
177
+ }
178
+ yield json.dumps(ret).encode()
179
+ except Exception as e:
180
+ print("Caught Unknown Error", e)
181
+ ret = {
182
+ "text": server_error_msg,
183
+ "error_code": 1,
184
+ }
185
+ yield json.dumps(ret).encode()
186
 
187
 
188
  if __name__ == "__main__":