|
from copy import copy |
|
from typing import Dict, List, Optional, Type, Union
|
|
|
|
|
class LMTemplateParser: |
|
"""Intermidate prompt template parser, specifically for language models. |
|
|
|
Args: |
|
meta_template (list of dict, optional): The meta template for the |
|
model. |
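
    Example:
        A minimal sketch of a meta template; the role names and the
        begin/end markers below are purely illustrative and not tied to
        any particular model::

            meta_template = [
                dict(role='system', begin='<|System|>:', end='<eosys>'),
                dict(role='user', begin='<|User|>:', end='<eoh>'),
                dict(role='assistant', begin='<|Assistant|>:', end='<eoa>',
                     generate=True),
            ]
            parser = LMTemplateParser(meta_template)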
|
""" |
|
|
|
def __init__(self, meta_template: Optional[List[Dict]] = None): |
|
self.meta_template = meta_template |
|
if meta_template: |
|
assert isinstance(meta_template, list) |
|
self.roles: Dict[str, dict] = dict() |
|
for item in meta_template: |
|
assert isinstance(item, dict) |
|
assert item['role'] not in self.roles, \ |
|
'role in meta prompt must be unique!' |
|
self.roles[item['role']] = item.copy() |
|
|
|
    def __call__(self, dialog: Union[str, List[Union[str, dict]]]) -> str:
|
"""Parse a prompt template, and wrap it with meta template if |
|
applicable. |
|
|
|
Args: |
|
            dialog (str | List[str or dict]): A prompt template (potentially
                before being wrapped by the meta template).
|
|
|
Returns: |
|
str: The final string. |
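
        Example:
            A rough illustration, assuming ``parser`` was built with a meta
            template defining ``user`` and ``assistant`` roles::

                dialog = [
                    dict(role='user', content='Hi there'),
                    dict(role='assistant', content='Hello!'),
                ]
                text = parser(dialog)  # a single formatted prompt string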
|
""" |
|
assert isinstance(dialog, (str, list)) |
|
if isinstance(dialog, str): |
|
return dialog |
|
if self.meta_template: |
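            # Wrap each message with the begin/end markers of its role.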
|
|
|
prompt = '' |
|
for index, item in enumerate(dialog): |
|
if isinstance(item, str): |
|
prompt += item |
|
else: |
|
new_str = self._prompt2str(item, index == len(dialog) - 1) |
|
prompt += new_str |
|
else: |
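            # No meta template: join plain strings/contents with newlines.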
|
|
|
prompt = '' |
|
last_sep = '' |
|
for item in dialog: |
|
if isinstance(item, str): |
|
if item: |
|
prompt += last_sep + item |
|
elif item.get('content', ''): |
|
                    prompt += last_sep + item.get('content', '')
|
last_sep = '\n' |
|
return prompt |
|
|
|
def _format_begin(self, role_cfg, message): |
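        """Build the ``begin`` marker of a role, filling in the speaker name if any."""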
|
name = message.get('name', None) |
|
if name is not None: |
|
begin = role_cfg['begin'].get('with_name', '') |
|
if name in role_cfg['begin'].get('name', {}): |
|
begin = begin.format(name=role_cfg['begin']['name'][name]) |
|
else: |
|
begin = begin.format(name=name) |
|
else: |
|
if isinstance(role_cfg.get('begin', ''), str): |
|
begin = role_cfg.get('begin', '') |
|
elif isinstance(role_cfg['begin'], dict): |
|
begin = role_cfg['begin'].get('without_name', '') |
|
return begin |
|
|
|
def _prompt2str(self, |
|
prompt: Union[str, Dict], |
|
                    last: bool = False) -> str:
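        """Convert a single message (dict or plain string) into its prompt string."""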
|
if isinstance(prompt, str): |
|
return prompt |
|
merged_prompt = self.roles.get(prompt['role']) |
|
|
|
if merged_prompt.get('fallback_role'): |
|
merged_prompt = self.roles.get(merged_prompt['fallback_role']) |
|
begin = self._format_begin(merged_prompt, prompt) |
|
res = begin |
|
if last and merged_prompt.get('generate', False): |
|
res += prompt.get('content', '') |
|
return res |
|
res += prompt.get('content', '') + merged_prompt.get('end', '') |
|
if last and merged_prompt['role'] != 'assistant': |
|
res += self._format_begin(self.roles['assistant'], {}) |
|
return res |
|
return res |
|
|
|
|
|
class BaseLLM: |
|
"""Base class for model wrapper. |
|
|
|
Args: |
|
path (str): The path to the model. |
|
max_new_tokens (int): Maximum length of output expected to be generated by the model. Defaults |
|
to 512. |
|
tokenizer_only (bool): If True, only the tokenizer will be initialized. |
|
Defaults to False. |
|
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case the model requires injecting or
            wrapping any meta instructions.
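
    Example:
        A rough sketch of a concrete wrapper; ``EchoLLM`` and its trivial
        ``generate`` are made up purely for illustration::

            class EchoLLM(BaseLLM):

                def generate(self, inputs, **gen_params):
                    # A real wrapper would run the underlying model here.
                    if isinstance(inputs, str):
                        return inputs
                    return list(inputs)

            llm = EchoLLM(path='dummy/model', max_new_tokens=256)
            llm.chat([dict(role='user', content='Hi')])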
|
""" |
|
|
|
def __init__(self, |
|
path: str, |
|
tokenizer_only: bool = False, |
|
                 template_parser: Type[LMTemplateParser] = LMTemplateParser,
|
meta_template: Optional[List[Dict]] = None, |
|
*, |
|
max_new_tokens: int = 512, |
|
top_p: float = 0.8, |
|
                 top_k: int = 40,
|
temperature: float = 0.8, |
|
repetition_penalty: float = 1.0, |
|
                 stop_words: Optional[Union[List[str], str]] = None):
|
self.path = path |
|
self.tokenizer_only = tokenizer_only |
|
|
|
self.template_parser = template_parser(meta_template) |
|
self.eos_token_id = None |
|
        if meta_template:
            for item in meta_template:
                if 'eos_token_id' in item:
                    self.eos_token_id = item['eos_token_id']
|
|
|
if isinstance(stop_words, str): |
|
stop_words = [stop_words] |
|
self.gen_params = dict( |
|
max_new_tokens=max_new_tokens, |
|
top_p=top_p, |
|
top_k=top_k, |
|
temperature=temperature, |
|
repetition_penalty=repetition_penalty, |
|
stop_words=stop_words) |
|
|
|
def generate(self, inputs: Union[str, List[str]], **gen_params) -> str: |
|
"""Generate results given a str (or list of) inputs. |
|
|
|
Args: |
|
            inputs (Union[str, List[str]]): The input prompt(s).
|
gen_params (dict): The input params for generation. |
|
|
|
Returns: |
|
Union[str, List[str]]: A (list of) generated strings. |
|
|
|
        Example:
            A typical implementation handles batched and single inputs like
            this::

                batched = True
                if isinstance(inputs, str):
                    inputs = [inputs]
                    batched = False
                response = ['']
                if batched:
                    return response
                return response[0]
|
""" |
|
raise NotImplementedError |
|
|
|
def stream_generate(self, inputs: str, **gen_params) -> List[str]: |
|
"""Generate results as streaming given a str inputs. |
|
|
|
Args: |
|
            inputs (str): The input prompt.
|
gen_params (dict): The input params for generation. |
|
|
|
Returns: |
|
str: A generated string. |
|
""" |
|
raise NotImplementedError |
|
|
|
def chat(self, |
|
inputs: Union[List[dict], List[List[dict]]], |
|
             session_ids: Optional[Union[int, List[int]]] = None,
|
**gen_params): |
|
"""Generate completion from a list of templates. |
|
|
|
Args: |
|
            inputs (Union[List[dict], List[List[dict]]]): A chat history, or
                a batch of chat histories, each given as a list of message
                dicts.
            session_ids (Union[int, List[int]], optional): Session id(s).
                Unused by the base implementation; reserved for subclasses.
|
gen_params (dict): The input params for generation. |
|
        Returns:
            Union[str, List[str]]: The generated completion(s).
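
        Example:
            A rough usage sketch for an instantiated subclass ``llm``; chat
            histories may be batched by passing a list of lists, and the
            contents below are placeholders::

                llm.chat([
                    [dict(role='user', content='Hello')],
                    [dict(role='user', content='How are you?')],
                ])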
|
""" |
|
if isinstance(inputs[0], list): |
|
_inputs = list() |
|
for msg in inputs: |
|
_inputs.append(self.template_parser(msg)) |
|
else: |
|
_inputs = self.template_parser(inputs) |
|
return self.generate(_inputs, **gen_params) |
|
|
|
def stream_chat(self, inputs: List[dict], **gen_params): |
|
"""Generate results as streaming given a list of templates. |
|
|
|
Args: |
|
            inputs (List[dict]): A chat history as a list of message dicts.
|
gen_params (dict): The input params for generation. |
|
Returns: |
|
""" |
|
raise NotImplementedError |
|
|
|
def tokenize(self, prompts: Union[str, List[str], List[dict], |
|
List[List[dict]]]): |
|
"""Tokenize the input prompts. |
|
|
|
Args: |
|
            prompts (str | List[str] | List[dict] | List[List[dict]]): A
                single prompt, a batch of prompts, or chat message(s) given
                as message dicts.

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The prompts'
                token ids, the lengths of those ids, and the requested output
                lengths.
|
""" |
|
raise NotImplementedError |
|
|
|
def update_gen_params(self, **kwargs): |
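        """Return a copy of the default generation parameters updated with ``kwargs``."""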
|
gen_params = copy(self.gen_params) |
|
gen_params.update(kwargs) |
|
return gen_params |
|
|
|
|
|
class AsyncLLMMixin: |
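    """Asynchronous counterparts of the generation APIs defined in :class:`BaseLLM`."""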
|
|
|
async def generate(self, |
|
inputs: Union[str, List[str]], |
|
                       session_ids: Optional[Union[int, List[int]]] = None,
|
**gen_params) -> str: |
|
"""Generate results given a str (or list of) inputs. |
|
|
|
Args: |
|
            inputs (Union[str, List[str]]): The input prompt(s).
            session_ids (Union[int, List[int]], optional): Session id(s)
                associated with the inputs, if the backend tracks sessions.
|
gen_params (dict): The input params for generation. |
|
|
|
Returns: |
|
Union[str, List[str]]: A (list of) generated strings. |
|
|
|
        Example:
            A typical implementation handles batched and single inputs like
            this::

                batched = True
                if isinstance(inputs, str):
                    inputs = [inputs]
                    batched = False
                response = ['']
                if batched:
                    return response
                return response[0]
|
""" |
|
raise NotImplementedError |
|
|
|
async def stream_generate(self, inputs: str, **gen_params) -> List[str]: |
|
"""Generate results as streaming given a str inputs. |
|
|
|
Args: |
|
            inputs (str): The input prompt.
|
gen_params (dict): The input params for generation. |
|
|
|
Returns: |
|
str: A generated string. |
|
""" |
|
raise NotImplementedError |
|
|
|
async def chat(self, |
|
inputs: Union[List[dict], List[List[dict]]], |
|
                     session_ids: Optional[Union[int, List[int]]] = None,
|
**gen_params): |
|
"""Generate completion from a list of templates. |
|
|
|
Args: |
|
            inputs (Union[List[dict], List[List[dict]]]): A chat history, or
                a batch of chat histories, each given as a list of message
                dicts.
            session_ids (Union[int, List[int]], optional): Session id(s)
                passed through to :meth:`generate`.
|
gen_params (dict): The input params for generation. |
|
        Returns:
            Union[str, List[str]]: The generated completion(s).
|
""" |
|
if isinstance(inputs[0], list): |
|
_inputs = list() |
|
for msg in inputs: |
|
_inputs.append(self.template_parser(msg)) |
|
else: |
|
_inputs = self.template_parser(inputs) |
|
return await self.generate(_inputs, session_ids, **gen_params) |
|
|
|
async def stream_chat(self, inputs: List[dict], **gen_params): |
|
"""Generate results as streaming given a list of templates. |
|
|
|
Args: |
|
            inputs (List[dict]): A chat history as a list of message dicts.
|
gen_params (dict): The input params for generation. |
|
Returns: |
|
""" |
|
raise NotImplementedError |
|
|
|
async def tokenize(self, prompts: Union[str, List[str], List[dict], |
|
List[List[dict]]]): |
|
"""Tokenize the input prompts. |
|
|
|
Args: |
|
            prompts (str | List[str] | List[dict] | List[List[dict]]): A
                single prompt, a batch of prompts, or chat message(s) given
                as message dicts.

        Returns:
            Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]: The prompts'
                token ids, the lengths of those ids, and the requested output
                lengths.
|
""" |
|
raise NotImplementedError |
|
|
|
|
|
class AsyncBaseLLM(AsyncLLMMixin, BaseLLM): |
|
    """Asynchronous base class for model wrappers, mixing :class:`AsyncLLMMixin` into :class:`BaseLLM`."""
|
|