Spaces:
Running
Running
"""Perform inference of one model on one input prompt and measure time and energy.""" | |
from __future__ import annotations | |
import os | |
import json | |
import copy | |
import atexit | |
from typing import Generator, Literal | |
import tyro | |
import torch | |
import rich | |
from rich.table import Table | |
from fastchat.serve.inference import generate_stream | |
from fastchat.model.model_adapter import load_model, get_conversation_template | |
from zeus.monitor import ZeusMonitor | |
SYSTEM_PROMPTS = { | |
"chat": ( | |
"A chat between a human user (prompter) and an artificial intelligence (AI) assistant. " | |
"The assistant gives helpful, detailed, and polite answers to the user's questions." | |
), | |
"chat-concise": ( | |
"A chat between a human user (prompter) and an artificial intelligence (AI) assistant. " | |
"The assistant gives helpful, detailed, and polite answers to the user's questions. " | |
"The assistnat's answers are concise but high-quality." | |
), | |
"instruct": ( | |
"Below is an instruction that describes a task. " | |
"Write a response that appropriately completes the request." | |
), | |
"instruct-concise": ( | |
"Below is an instruction that describes a task. " | |
"Write a response that appropriately completes the request." | |
"The response should be concise but high-quality." | |
), | |
} | |
def main( | |
model_path: str, | |
input_file: str, | |
device_index: int = 0, | |
task: Literal[tuple(SYSTEM_PROMPTS)] = "chat", # type: ignore | |
load_8bit: bool = False, | |
temperature: float = 0.7, | |
repitition_penalty: float = 1.0, | |
max_new_tokens: int = 512, | |
) -> None: | |
"""Run the main routine. | |
Code structure is based on | |
https://github.com/lm-sys/FastChat/blob/57dea54055/fastchat/serve/inference.py#L249 | |
Args: | |
model_path: Path to or Huggingface Hub Id of the model. | |
input_file: Path to the input JSON file. Assumed to be our cleaned ShareGPT data. | |
device_index: Index of the GPU to use for inference. | |
task: Type of task to perform inference on. | |
load_8bit: Whether to load the model in 8-bit mode. | |
temperature: Temperature to use for sampling. | |
repitition_penalty: Repitition penalty to use for the model. | |
max_new_tokens: Maximum numbers of tokens to generate, ignoring the prompt. | |
""" | |
# NOTE(JW): ChatGLM is implemented as a special case in FastChat inference. | |
# Also, it's primarily a model that's fine-tuned for Chinese, so it doesn't | |
# make sense to prompt it in English and talk about its verbosity. | |
if "chatglm" in model_path.lower(): | |
raise ValueError("ChatGLM is not supported.") | |
# Print out what we're about to do. | |
model_name_cleaned = "--".join(model_path.split("/")[-2:]) | |
output_dir = f"data/{task}/{model_name_cleaned}" | |
output_csv_path = f"{output_dir}/benchmark.json" | |
config_json_path = f"{output_dir}/config.json" | |
table = Table(title="Benchmark") | |
table.add_column("Configuration") | |
table.add_column("Value") | |
table.add_row("Model", f"{model_name_cleaned} (path: {model_path})") | |
table.add_row("Input", input_file) | |
table.add_row("Device", f"cuda:{device_index}") | |
table.add_row("Task", task) | |
table.add_row("8-bit", str(load_8bit)) | |
table.add_row("Temperature", f"{temperature:.2f}") | |
table.add_row("Repitition Penalty", f"{repitition_penalty:.2f}") | |
table.add_row("Max New Tokens", str(max_new_tokens)) | |
table.add_row("Output CSV", output_csv_path) | |
table.add_row("Config JSON", config_json_path) | |
rich.get_console().print(table) | |
# Set the device. | |
torch.cuda.set_device(f"cuda:{device_index}") | |
# Load the model (Huggingface PyTorch) and tokenizer (Huggingface). | |
model, tokenizer = load_model( | |
model_path=model_path, | |
device="cuda", | |
num_gpus=1, | |
max_gpu_memory=None, | |
load_8bit=load_8bit, | |
cpu_offloading=False, | |
gptq_config=None, | |
debug=False, | |
) | |
# Chats are accumulated in a conversation helper object. | |
conv_base = get_conversation_template(model_path) | |
# Standardize the system prompt for every model. | |
conv_base.system = SYSTEM_PROMPTS[task] | |
conv_base.messages = [] | |
conv_base.offset = 0 | |
gen_params = { | |
"model": model_path, | |
"prompt": "EMPTY", | |
"temperature": temperature, | |
"repitition_penalty": repitition_penalty, | |
"max_new_tokens": max_new_tokens, | |
"stop": conv_base.stop_str, | |
"stop_token_ids": conv_base.stop_token_ids, | |
"echo": False, | |
} | |
monitor = ZeusMonitor(gpu_indices=[torch.cuda.current_device()]) | |
# Output files. | |
# Leave only the last two path components and replace slashes with double dashes. | |
os.makedirs(output_dir, exist_ok=True) | |
output_json = open(output_csv_path, "w") | |
output_json.write("[\n") | |
output_json.flush() | |
# Conclude the JSON file format with a closing bracket. Using `atexit` will | |
# handle all cases of the program exiting, including Ctrl-C and errors. | |
atexit.register(lambda: output_json.write("\n]\n")) | |
# Dump the configuration to a JSON file. | |
with open(config_json_path, "w") as config_json: | |
json.dump( | |
{ | |
"model_path": model_path, | |
"input_file": input_file, | |
"device_index": device_index, | |
"task": task, | |
"load_8bit": load_8bit, | |
"temperature": temperature, | |
"repitition_penalty": repitition_penalty, | |
"max_new_tokens": max_new_tokens, | |
}, | |
config_json, | |
indent=4, | |
) | |
config_json.write("\n") | |
def dataloader(input_file: str) -> Generator[tuple[bool, str], None, None]: | |
"""Yields a tuple of whether this is a warmup run and the input prompt.""" | |
for _ in range(3): | |
yield True, "Say something long and random. I don't care about the content." | |
for item in json.load(open(input_file, "r")): | |
input_prompt = item["conversations"][0]["value"] | |
yield False, input_prompt | |
# Warm up the GPU with some random prompts. | |
# Forward through all the prompts. | |
is_first = True | |
for is_warmup, input_prompt in dataloader(input_file): | |
# Construct the input prompt. | |
conv = copy.deepcopy(conv_base) | |
conv.append_message(conv.roles[0], input_prompt) | |
conv.append_message(conv.roles[1], "") | |
prompt = conv.get_prompt() | |
gen_params["prompt"] = prompt | |
# Print input prompt. | |
rich.print(f"\n[u]{'Warmup ' if is_warmup else ''}Prompt[/u]:\n{prompt.strip()}\n") | |
# Generate the ouptut from the model. | |
output_stream = generate_stream(model, tokenizer, gen_params, device="cuda") | |
output = {} | |
################################################# | |
# Inference and measurement zone! | |
################################################# | |
monitor.begin_window("inference") | |
for output in output_stream: | |
pass | |
measurements = monitor.end_window("inference") | |
################################################# | |
# Record numbers. | |
output_text = output["text"] | |
if not is_warmup: | |
response_length = len(tokenizer.encode(output_text)) # number of tokens | |
latency = measurements.time | |
throughput = response_length / latency | |
energy = measurements.total_energy | |
output = { | |
"model": model_name_cleaned, | |
"throughput": throughput, | |
"response_length": response_length, | |
"latency": latency, | |
"energy": energy, | |
"input": prompt.strip(), | |
"output": output_text.strip(), | |
} | |
output_str = json.dumps(output, indent=4) | |
if not is_warmup: | |
if not is_first: | |
output_json.write(",\n" + output_str) | |
else: | |
is_first = False | |
output_json.write(output_str) | |
output_json.flush() | |
# Print the response. | |
rich.print(f"\n[u]{'Warmup ' if is_warmup else ''}Response[/u]:\n{output_text.strip()}\n") | |
# Print measurement. | |
rich.print(measurements) | |
if __name__ == "__main__": | |
tyro.cli(main) | |