Jae-Won Chung committed on
Commit
764dce6
·
1 Parent(s): 9f1c84b

Push `benchmark.py` from fix_stop_str

Browse files
Files changed (1) hide show
  1. scripts/benchmark.py +124 -138
scripts/benchmark.py CHANGED
@@ -7,8 +7,8 @@ import json
7
  import copy
8
  import atexit
9
  from typing import Generator, Literal, Iterable, Dict
 
10
 
11
- import gc
12
  import numpy as np
13
  import tyro
14
  import torch
@@ -16,6 +16,7 @@ import rich
16
  from rich.table import Table
17
  from fastchat.serve.inference import prepare_logits_processor
18
  from fastchat.model.model_adapter import load_model, get_conversation_template
 
19
  from zeus.monitor import ZeusMonitor
20
 
21
  SYSTEM_PROMPTS = {
@@ -39,21 +40,20 @@ SYSTEM_PROMPTS = {
39
  ),
40
  }
41
 
42
- def is_partial_stop(output: str, stop_str: str):
43
- """Check whether the output contains a partial stop str."""
44
- for i in range(0, min(len(output), len(stop_str))):
45
- if stop_str.startswith(output[-i:]):
46
- return True
47
- return False
48
 
49
  @torch.inference_mode()
50
- def generate_stream(
51
  model,
52
  tokenizer,
53
  params: Dict,
54
  device: str,
55
  context_len: int = 2048,
56
- ):
57
  # Read parameters
58
  prompts = params["prompt"]
59
  temperature = float(params.get("temperature", 1.0))
@@ -62,10 +62,16 @@ def generate_stream(
62
  top_k = int(params.get("top_k", -1)) # -1 means disable
63
  max_new_tokens = int(params.get("max_new_tokens", 256))
64
  stop_str = params.get("stop", None)
65
- stop_token_ids = params.get("stop_token_ids", None) or []
66
  stop_token_ids.append(tokenizer.eos_token_id)
67
  batch_size = len(prompts)
68
 
 
 
 
 
 
 
69
  # left-pad prompts with eos to make all input prompts the same length
70
  tokenizer.padding_side = "left"
71
  tokenizer.pad_token = tokenizer.eos_token
@@ -75,15 +81,14 @@ def generate_stream(
75
  )
76
 
77
  input_ids = tokenizer(prompts, padding=True).input_ids
78
- output_ids = list(input_ids)
79
 
80
  if model.config.is_encoder_decoder:
81
  max_src_len = context_len
82
  else: # truncate
83
- max_src_len = context_len - max_new_tokens - 8
84
 
85
  input_ids = [input_id[-max_src_len:] for input_id in input_ids]
86
- input_len = len(input_ids[0])
87
 
88
  if model.config.is_encoder_decoder:
89
  encoder_output = model.encoder(
@@ -141,10 +146,10 @@ def generate_stream(
141
  else:
142
  last_token_logits = logits[:, -1, :]
143
 
144
- if device == "mps":
145
- # Switch to CPU to avoid some bugs in the mps backend.
146
- last_token_logits = last_token_logits.float().to("cpu")
147
-
148
  if temperature < 1e-5 or top_p < 1e-8: # greedy
149
  _, indices = torch.topk(last_token_logits, 2)
150
  tokens = [[int(token) for token in query] for query in indices.tolist()]
@@ -152,81 +157,70 @@ def generate_stream(
152
  probs = torch.softmax(last_token_logits, dim=-1)
153
  indices = torch.multinomial(probs, num_samples=2)
154
  tokens = [[int(token) for token in query] for query in indices.tolist()]
 
 
155
 
 
156
  old_stopped = stopped
157
  stopped = np.logical_or(old_stopped, np.array([True if token[0] in stop_token_ids else False for token in tokens]))
158
- output_ids = [ids + [token[0]] for ids, token in zip(output_ids, tokens)]
159
 
160
- def slice(s, pos):
161
- return s[:pos]
162
- vec_slice = np.vectorize(slice, otypes=[str])
163
- vec_is_partial_stop = np.vectorize(is_partial_stop)
164
-
165
- # Yield the output tokens
166
- if any(stopped):
167
- tmp_output_ids = [ids[input_len:] for ids in output_ids]
168
- rfind_start = 0
169
- output = tokenizer.batch_decode(
170
- tmp_output_ids,
171
- skip_special_tokens=True,
172
- spaces_between_special_tokens=False,
173
- clean_up_tokenization_spaces=True,
174
- )
175
- output = np.array(output)
176
-
177
- partially_stopped = np.array(len(output_ids) * [False])
178
- different_indices = np.empty(shape=(0,))
179
- if stop_str:
180
- if isinstance(stop_str, str):
181
- pos_array = np.char.rfind(output, stop_str, rfind_start)
 
182
  find_stop = pos_array != -1
183
- output[find_stop] = vec_slice(output[find_stop], pos_array[find_stop])
184
- stopped = find_stop
185
- partially_stopped = vec_is_partial_stop(output, stop_str)
186
- elif isinstance(stop_str, Iterable):
187
- for each_stop in stop_str:
188
- pos_array = np.char.rfind(output, stop_str, rfind_start)
189
- find_stop = pos_array != -1
190
- output[find_stop] = vec_slice(output[find_stop], pos_array[find_stop])
191
- stopped = find_stop
192
- partially_stopped = partially_stopped | vec_is_partial_stop(output, each_stop)
193
- else:
194
- raise ValueError("Invalid stop field type.")
195
-
196
- # Prevent yielding partial stop sequence
197
- if not any(partially_stopped):
198
- # indicates which request in batch stopped
199
- different_indices = np.where(stopped != old_stopped)[0]
200
- stop_length = np.array([(j, i+1) for j in different_indices])
201
- yield {
202
- "text": output,
203
- "stop_length": stop_length,
204
- }
205
 
206
  if all(stopped):
207
  break
208
 
209
- false_indices = np.where(stopped == False)[0]
210
  if any(stopped) == False:
211
- tmp_output_ids = [ids[input_len:] for ids in output_ids]
212
  output = tokenizer.batch_decode(
213
- tmp_output_ids,
214
  skip_special_tokens=True,
215
  spaces_between_special_tokens=False,
216
  clean_up_tokenization_spaces=True,
217
  )
218
- stop_length = np.array([(i, max_new_tokens) for i in false_indices])
219
 
220
- yield {
221
- "text": output,
222
- "stop_length": stop_length,
223
- }
224
 
225
- # Clean
226
- del past_key_values, out
227
- gc.collect()
228
- torch.cuda.empty_cache()
229
 
 
 
 
230
 
231
  def main(
232
  model_path: str,
@@ -347,108 +341,100 @@ def main(
347
  "temperature": temperature,
348
  "repitition_penalty": repitition_penalty,
349
  "max_new_tokens": max_new_tokens,
 
350
  },
351
  config_json,
352
  indent=4,
353
  )
354
  config_json.write("\n")
355
 
356
- def dataloader(input_file: str) -> Generator[tuple[bool, str], None, None]:
 
 
 
 
 
 
 
 
 
 
 
 
357
  """Yields a tuple of whether this is a warmup run and the input prompt."""
358
- for _ in range(3*batch):
359
- yield True, "Say something long and random. I don't care about the content."
360
- for item in json.load(open(input_file, "r")):
361
- input_prompt = item["conversations"][0]["value"]
362
- yield False, input_prompt
 
 
363
 
364
  # Warm up the GPU with some random prompts.
365
  # Forward through all the prompts.
366
  is_first = True
367
  convs = []
368
  prompts = []
369
- data_iter = iter(dataloader(input_file))
370
-
371
- end_of_file = False # flag to track the end of the file
372
- while True:
373
- try:
374
- is_warmup, input_prompt = next(data_iter)
375
- except StopIteration:
376
- end_of_file = True # no more data
377
-
378
  # Construct the input prompt.
379
- if not end_of_file:
380
  conv = copy.deepcopy(conv_base)
381
- conv.append_message(conv.roles[0], input_prompt)
382
  conv.append_message(conv.roles[1], "")
383
  prompt = conv.get_prompt()
384
  prompts.append(prompt)
385
  convs.append(conv)
386
- if (len(convs) < batch): continue
387
  gen_params["prompt"] = prompts
388
- if end_of_file and len(prompts) == 0:
389
- break
390
 
391
  # Print input prompt.
392
  for i in range(len(convs)):
393
  console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Prompt[/u cyan](batch_{i}):")
394
  console.print(prompts[i].strip() + "\n", markup=False)
395
 
396
- # Generate the output from the model.
397
- output_stream = generate_stream(model, tokenizer, gen_params, device="cuda", context_len=2048)
398
- output = {}
399
- batch_token_len = {}
400
-
401
  #################################################
402
  # Inference and measurement zone!
403
  #################################################
404
  monitor.begin_window("inference")
405
- for output in output_stream:
406
- stop_length = output["stop_length"]
407
- for it in stop_length:
408
- batch_token_len[it[0]] = it[1]
409
  measurements = monitor.end_window("inference")
410
  #################################################
411
-
412
- # Record numbers.
413
- output_text = output["text"]
414
- if not is_warmup:
415
- total_length = int(sum(batch_token_len.values())) # number of valid tokens
416
- response_length = float(total_length) / len(convs)
417
- latency = measurements.time
418
- throughput = response_length / latency
419
- energy = measurements.total_energy
420
- output = {
421
- "model": model_name_cleaned,
422
- "batch": len(convs),
423
- "throughput": throughput,
424
- "response_length": response_length,
425
- "latency": latency,
426
- "energy": energy,
427
- "input": [prompt.strip() for prompt in prompts],
428
- "output": [output_text[i][:batch_token_len[i]].strip() for i in range(len(convs))],
429
- }
430
- output_str = json.dumps(output, indent=4)
431
  if not is_warmup:
432
- if not is_first:
433
- output_json.write(",\n" + output_str)
434
- else:
435
- is_first = False
436
- output_json.write(output_str)
437
- output_json.flush()
438
-
439
- # Print the response.
440
- for i in range(len(convs)):
441
- console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Response[/u cyan](batch_{i}):")
442
- console.print(output_text[i][:batch_token_len[i]].strip() + "\n", markup=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
  # Print measurement.
445
  console.print(measurements)
446
  convs = []
447
  prompts = []
448
 
449
- if end_of_file:
450
- break
451
-
452
-
453
  if __name__ == "__main__":
454
  tyro.cli(main)
 
7
  import copy
8
  import atexit
9
  from typing import Generator, Literal, Iterable, Dict
10
+ from dataclasses import dataclass
11
 
 
12
  import numpy as np
13
  import tyro
14
  import torch
 
16
  from rich.table import Table
17
  from fastchat.serve.inference import prepare_logits_processor
18
  from fastchat.model.model_adapter import load_model, get_conversation_template
19
+ from torch.utils.data import Dataset, DataLoader
20
  from zeus.monitor import ZeusMonitor
21
 
22
  SYSTEM_PROMPTS = {
 
40
  ),
41
  }
42
 
43
+ @dataclass
44
+ class Output:
45
+ response_length: int
46
+ input: str
47
+ output: str
 
48
 
49
  @torch.inference_mode()
50
+ def run_inference(
51
  model,
52
  tokenizer,
53
  params: Dict,
54
  device: str,
55
  context_len: int = 2048,
56
+ ) ->list[Output]:
57
  # Read parameters
58
  prompts = params["prompt"]
59
  temperature = float(params.get("temperature", 1.0))
 
62
  top_k = int(params.get("top_k", -1)) # -1 means disable
63
  max_new_tokens = int(params.get("max_new_tokens", 256))
64
  stop_str = params.get("stop", None)
65
+ stop_token_ids = list(params.get("stop_token_ids", None) or [])
66
  stop_token_ids.append(tokenizer.eos_token_id)
67
  batch_size = len(prompts)
68
 
69
+ empty_result = Output(response_length=-1, input="", output="")
70
+ result = []
71
+ for i, prompt in enumerate(prompts):
72
+ result.append(copy.deepcopy(empty_result))
73
+ result[i].input = prompt
74
+
75
  # left-pad prompts with eos to make all input prompts the same length
76
  tokenizer.padding_side = "left"
77
  tokenizer.pad_token = tokenizer.eos_token
 
81
  )
82
 
83
  input_ids = tokenizer(prompts, padding=True).input_ids
84
+ output_ids = [[] for _ in range(batch_size)]
85
 
86
  if model.config.is_encoder_decoder:
87
  max_src_len = context_len
88
  else: # truncate
89
+ max_src_len = context_len - max_new_tokens - 1
90
 
91
  input_ids = [input_id[-max_src_len:] for input_id in input_ids]
 
92
 
93
  if model.config.is_encoder_decoder:
94
  encoder_output = model.encoder(
 
146
  else:
147
  last_token_logits = logits[:, -1, :]
148
 
149
+ # handle unexpected NaN issue for llama 2 7b chat
150
+ if torch.any(torch.isnan(last_token_logits)) == True:
151
+ return []
152
+
153
  if temperature < 1e-5 or top_p < 1e-8: # greedy
154
  _, indices = torch.topk(last_token_logits, 2)
155
  tokens = [[int(token) for token in query] for query in indices.tolist()]
 
157
  probs = torch.softmax(last_token_logits, dim=-1)
158
  indices = torch.multinomial(probs, num_samples=2)
159
  tokens = [[int(token) for token in query] for query in indices.tolist()]
160
+
161
+ output_ids = [ids + [token[0]] for ids, token in zip(output_ids, tokens)]
162
 
163
+ # deal with stop_token_ids
164
  old_stopped = stopped
165
  stopped = np.logical_or(old_stopped, np.array([True if token[0] in stop_token_ids else False for token in tokens]))
166
+ different_indices = np.where(stopped != old_stopped)[0]
167
 
168
+ rfind_start = 0
169
+ output = tokenizer.batch_decode(
170
+ output_ids,
171
+ skip_special_tokens=True,
172
+ spaces_between_special_tokens=False,
173
+ clean_up_tokenization_spaces=True,
174
+ )
175
+ output_np = np.array(output)
176
+
177
+ if different_indices.size > 0:
178
+ # use i rather than i+1 here because the (i+1)-th generated token is in stop_token_ids
179
+ for j in different_indices:
180
+ result[j].response_length = i
181
+ result[j].output = output[j]
182
+
183
+ # deal with stop_str
184
+ if stop_str:
185
+ if isinstance(stop_str, str):
186
+ pos_array = np.char.rfind(output_np, stop_str, rfind_start)
187
+ find_stop = pos_array != -1
188
+ elif isinstance(stop_str, Iterable):
189
+ for each_stop in stop_str:
190
+ pos_array = np.char.rfind(output_np, each_stop, rfind_start)
191
  find_stop = pos_array != -1
192
+ else:
193
+ raise ValueError("Invalid stop field type.")
194
+
195
+ stop_str_indices = np.where(find_stop & ~stopped)[0]
196
+ if stop_str_indices.size > 0:
197
+ for j in stop_str_indices:
198
+ # TODO: find an elegant way to figure out the size of stop_str; for now, just suppose stop_str has one token
199
+ result[j].response_length = i
200
+ result[j].output = output[j][:pos_array[j]]
201
+ stopped[find_stop] = True
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  if all(stopped):
204
  break
205
 
206
+ not_finish_indices = np.where(stopped == False)[0]
207
  if any(stopped) == False:
 
208
  output = tokenizer.batch_decode(
209
+ output_ids,
210
  skip_special_tokens=True,
211
  spaces_between_special_tokens=False,
212
  clean_up_tokenization_spaces=True,
213
  )
 
214
 
215
+ for j in not_finish_indices:
216
+ result[j].response_length = max_new_tokens
217
+ result[j].output = output[j]
 
218
 
219
+ return result
 
 
 
220
 
221
+ def write_error_to_file(filename, error_message):
222
+ with open(filename, 'a') as file:
223
+ file.write(error_message + '\n')
224
 
225
  def main(
226
  model_path: str,
 
341
  "temperature": temperature,
342
  "repitition_penalty": repitition_penalty,
343
  "max_new_tokens": max_new_tokens,
344
+ "batch_size": batch,
345
  },
346
  config_json,
347
  indent=4,
348
  )
349
  config_json.write("\n")
350
 
351
+ class CustomDataset(Dataset):
352
+ def __init__(self, data):
353
+ self.data = data
354
+
355
+ def __len__(self):
356
+ return len(self.data)
357
+
358
+ def __getitem__(self, index):
359
+ sample = self.data[index]
360
+ return sample["conversations"][0]["value"]
361
+
362
+
363
+ def dataloader(input_file: str, batch_size: batch) -> Generator[tuple[bool, str], None, None]:
364
  """Yields a tuple of whether this is a warmup run and the input prompt."""
365
+ for _ in range(3):
366
+ yield True, ["Say something long and random. I don't care about the content." for _ in range (batch)]
367
+ data = json.load(open(input_file, "r"))
368
+ custom_dataset = CustomDataset(data)
369
+ data_loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=False)
370
+ for prompt in data_loader:
371
+ yield False, prompt
372
 
373
  # Warm up the GPU with some random prompts.
374
  # Forward through all the prompts.
375
  is_first = True
376
  convs = []
377
  prompts = []
378
+ data_iter = iter(dataloader(input_file, batch))
379
+
380
+ for is_warmup, input_prompts in data_iter:
 
 
 
 
 
 
381
  # Construct the input prompt.
382
+ for i in range(batch):
383
  conv = copy.deepcopy(conv_base)
384
+ conv.append_message(conv.roles[0], input_prompts[i])
385
  conv.append_message(conv.roles[1], "")
386
  prompt = conv.get_prompt()
387
  prompts.append(prompt)
388
  convs.append(conv)
389
+
390
  gen_params["prompt"] = prompts
 
 
391
 
392
  # Print input prompt.
393
  for i in range(len(convs)):
394
  console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Prompt[/u cyan](batch_{i}):")
395
  console.print(prompts[i].strip() + "\n", markup=False)
396
 
 
 
 
 
 
397
  #################################################
398
  # Inference and measurement zone!
399
  #################################################
400
  monitor.begin_window("inference")
401
+ results = run_inference(model, tokenizer, gen_params, device="cuda", context_len=2048)
 
 
 
402
  measurements = monitor.end_window("inference")
403
  #################################################
404
+ if results:
405
+ # Record numbers.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406
  if not is_warmup:
407
+ response_length = sum([result.response_length for result in results]) # number of valid tokens
408
+ latency = measurements.time
409
+ throughput = response_length / latency
410
+ energy = measurements.total_energy
411
+ output = {
412
+ "model": model_name_cleaned,
413
+ "throughput": throughput,
414
+ "response_length": response_length,
415
+ "latency": latency,
416
+ "energy": energy,
417
+ "input": [prompt.strip() for prompt in prompts],
418
+ "output": [(result.output).strip() for result in results],
419
+ }
420
+ output_str = json.dumps(output, indent=4)
421
+ if not is_warmup:
422
+ if not is_first:
423
+ output_json.write(",\n" + output_str)
424
+ else:
425
+ is_first = False
426
+ output_json.write(output_str)
427
+ output_json.flush()
428
+
429
+ # Print the response.
430
+ for i in range(len(convs)):
431
+ console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Response[/u cyan](batch_{i}):")
432
+ console.print(results[i].output.strip() + "\n", markup=False)
433
 
434
  # Print measurement.
435
  console.print(measurements)
436
  convs = []
437
  prompts = []
438
 
 
 
 
 
439
  if __name__ == "__main__":
440
  tyro.cli(main)