|
|
|
|
|
|
|
|
|
|
|
"""Module to generate OpenELM output given a model and an input prompt.""" |
|
import os
import logging
import time
import argparse
from typing import Optional, Tuple, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
def generate(
    prompt: str,
    model: Union[str, AutoModelForCausalLM],
    hf_access_token: Optional[str] = None,
    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
    device: Optional[str] = None,
    max_length: int = 1024,
    assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
    generate_kwargs: Optional[dict] = None,
) -> Tuple[str, float]:
    """Generate text for a prompt and report how long generation took.

    Args:
        prompt: The string prompt.
        model: The LLM model. If a string is passed, it should be the path
            to the hf converted checkpoint; it is loaded with
            ``AutoModelForCausalLM.from_pretrained``.
        hf_access_token: Hugging Face access token. Always required; it is
            forwarded to ``AutoTokenizer.from_pretrained`` when
            ``tokenizer`` is given as a string.
        tokenizer: Tokenizer instance, or a string that is passed to
            ``AutoTokenizer.from_pretrained`` to load one.
        device: String representation of the device to run the model on.
            If not set, ``cuda:0`` is used when CUDA is available,
            otherwise ``cpu``.
        max_length: Maximum length of tokens, input prompt + generated
            tokens.
        assistant_model: If set, this model is used for speculative
            generation. If a string is passed, it should be the path to
            the hf converted checkpoint.
        generate_kwargs: Extra kwargs passed to the hf generate function.
            Entries here take precedence over the values this function
            sets itself (``max_length``, ``pad_token_id`` and
            ``assistant_model``).

    Returns:
        A tuple ``(output_text, generation_time)``: the decoded output of
        ``model.generate`` as a string, and the time spent inside that
        call in seconds.

    Raises:
        ValueError: If device is set to CUDA but no CUDA device is detected.
        ValueError: If tokenizer is not set.
        ValueError: If hf_access_token is not specified.
    """
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                (
                    'No CUDA device detected, using cpu, '
                    'expect slower speeds.'
                )
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    if not tokenizer:
        raise ValueError('Tokenizer is not set in the generate function.')

    if not hf_access_token:
        raise ValueError((
            'Hugging face access token needs to be specified. '
            'Please refer to https://huggingface.co/docs/hub/security-tokens'
            ' to obtain one.'
            )
        )

    # NOTE: trust_remote_code=True lets the checkpoint supply its own model
    # code, so only load checkpoints from a trusted source.
    if isinstance(model, str):
        checkpoint_path = model
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            trust_remote_code=True
        )
    # Applies to both freshly loaded models and caller-supplied instances.
    model.to(device).eval()
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer,
            token=hf_access_token,
        )

    # Speculative mode
    draft_model = None
    if assistant_model:
        draft_model = assistant_model
        if isinstance(assistant_model, str):
            draft_model = AutoModelForCausalLM.from_pretrained(
                assistant_model,
                trust_remote_code=True
            )
        draft_model.to(device).eval()

    # Prepare the prompt as a batch of one sequence of token ids.
    tokenized_prompt = tokenizer(prompt)
    tokenized_prompt = torch.tensor(
        tokenized_prompt['input_ids'],
        device=device
    )
    tokenized_prompt = tokenized_prompt.unsqueeze(0)

    # Merge into a single dict so that a key supplied through
    # generate_kwargs overrides the default instead of raising
    # "got multiple values for keyword argument".
    call_kwargs = {
        'max_length': max_length,
        'pad_token_id': 0,
        'assistant_model': draft_model,
    }
    call_kwargs.update(generate_kwargs or {})

    # Generate; perf_counter is monotonic, so the measured interval is not
    # affected by system clock adjustments.
    stime = time.perf_counter()
    output_ids = model.generate(tokenized_prompt, **call_kwargs)
    generation_time = time.perf_counter() - stime

    output_text = tokenizer.decode(
        output_ids[0].tolist(),
        skip_special_tokens=True
    )

    return output_text, generation_time
|
|
|
|
|
def openelm_generate_parser(argv: Optional[list] = None) -> argparse.Namespace:
    """Build the OpenELM command line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with the parsed arguments. The
        ``generate_kwargs`` attribute is a dict when the option is given
        and None otherwise.

    Raises:
        ValueError: If an entry of ``--generate_kwargs`` is not of the form
            key=value.
    """

    def convert_value(text):
        """Coerce text to bool, int or float when possible, else keep str."""
        lowered = text.lower()
        # Without this, 'False' would be forwarded as a truthy string.
        if lowered in ('true', 'false'):
            return lowered == 'true'
        for cast in (int, float):
            try:
                return cast(text)
            except ValueError:
                continue
        return text

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of form key=value"""

        def __call__(self, parser, namespace, values, option_string=None):
            parsed = {}
            for val in values:
                # Split on the first '=' only, so values may contain '='.
                kwarg_k, separator, kwarg_v = val.partition('=')
                if not separator:
                    raise ValueError(
                        (
                            'Argument parsing error, kwargs are expected in'
                            ' the form of key=value.'
                        )
                    )
                parsed[kwarg_k] = convert_value(kwarg_v)
            setattr(namespace, self.dest, parsed)

    parser = argparse.ArgumentParser('OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the hf converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--hf_access_token',
        dest='hf_access_token',
        help='Hugging face access token, starting with "hf_".',
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument(
        '--max_length',
        dest='max_length',
        help='Maximum length of tokens.',
        default=256,
        type=int,
    )
    parser.add_argument(
        '--assistant_model',
        dest='assistant_model',
        help=(
            (
                'If set, this is used as a draft model '
                'for assisted speculative generation.'
            )
        ),
        type=str,
    )
    parser.add_argument(
        '--generate_kwargs',
        dest='generate_kwargs',
        help='Additional kwargs passed to the HF generate function.',
        type=str,
        nargs='*',
        action=KwargsParser,
    )
    return parser.parse_args(argv)
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI arguments, run generation and print
    # the prompt + generated output together with the generation time.
    # Imported here because only this entry point needs it.
    import shutil

    args = openelm_generate_parser()
    prompt = args.prompt

    output_text, generation_time = generate(
        prompt=prompt,
        model=args.model,
        device=args.device,
        max_length=args.max_length,
        assistant_model=args.assistant_model,
        generate_kwargs=args.generate_kwargs,
        hf_access_token=args.hf_access_token,
    )

    # os.get_terminal_size raises OSError when stdout is not a terminal
    # (e.g. output piped or redirected); shutil's variant falls back to the
    # given size instead. Query it once and reuse the width.
    columns = shutil.get_terminal_size(fallback=(80, 24)).columns

    print_txt = (
        f'\r\n{"=" * columns}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * columns}\r\n'
        f'{output_text}\r\n'
        f'{"-" * columns}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
|
|