# Source: Lagent repository, lagent/llms/base_api.py
# Last change: commit ec878fd ("update") by yanyoyo
import warnings
from typing import Dict, List, Optional, Tuple, Union
from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM
class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

    Args:
        meta_template (List[Dict], optional): The meta template for the
            model, given as a list of role configs. Every config must carry
            a unique ``role`` and an ``api_role``. It may also carry
            ``begin`` / ``end`` strings that are wrapped around the message
            content, and a ``fallback_role`` naming another configured role
            whose config is used in its place.
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Maps role name to its config. Always defined, so the attribute
        # exists even when the parser is built without a meta template.
        self.roles: Dict[str, dict] = dict()
        # Check meta template
        if meta_template:
            assert isinstance(meta_template, list)
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                # Copy so later changes to the caller's dict do not leak in.
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog: Union[str, List[Union[str, Dict, List]]]):
        """Parse the intermediate prompt template, and wrap it with the meta
        template if applicable.

        When the meta template is set and the input is a list, the return
        value is a list containing the full conversation history, where
        consecutive messages of the same API role are merged (joined by a
        newline). Each item looks like:

        .. code-block:: python

            {'role': 'user', 'content': '...'}

        Without a meta template the non-empty contents of the dialog are
        concatenated into one string.

        Args:
            dialog (str or List[str or dict or list]): An intermediate
                prompt template (potentially before being wrapped by the
                meta template). A plain string is returned unchanged.

        Returns:
            str or List[dict]: The finalized prompt or a conversation. An
            empty list is returned when the meta template is set but the
            dialog holds no role message.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog

        if self.meta_template:
            prompt = []
            for item in dialog:
                if isinstance(item, str):
                    if item.strip():
                        # TODO: logger
                        warnings.warn('Non-empty string in prompt template '
                                      'will be ignored in API models.')
                    continue
                api_prompts = self._prompt2api(item)
                # A nested list of messages converts to a list of API
                # messages; flatten it so every entry of ``prompt`` is a
                # single message dict.
                if isinstance(api_prompts, list):
                    prompt.extend(api_prompts)
                else:
                    prompt.append(api_prompts)
            return self._merge_same_role(prompt)

        # in case the model does not have any meta template
        prompt = ''
        last_sep = ''
        for item in dialog:
            if isinstance(item, str):
                if item:
                    prompt += last_sep + item
            elif item.get('content', ''):
                prompt += last_sep + item.get('content', '')
            # The separator is armed after the first item, even if that
            # item contributed nothing.
            last_sep = '\n'
        return prompt

    @staticmethod
    def _merge_same_role(prompt: List[Dict]) -> List[Dict]:
        """Merge consecutive messages assigned to the same role.

        Args:
            prompt (List[dict]): API-style messages with ``role`` and
                ``content`` keys.

        Returns:
            List[dict]: The messages with each run of equal roles collapsed
            into one message whose contents are joined by ``'\\n'``. Empty
            input yields an empty list.
        """
        merged: List[Dict] = []
        for item in prompt:
            if merged and merged[-1]['role'] == item['role']:
                merged[-1]['content'] += '\n' + item['content']
            else:
                merged.append(item)
        return merged

    def _prompt2api(
            self, prompts: Union[List, Dict, str]
    ) -> Union[str, Dict, List[Dict]]:
        """Convert the prompts to API-style prompts.

        Args:
            prompts (Union[List, Dict, str]): The prompts to be converted.

        Returns:
            Union[str, Dict, List[Dict]]: A string is returned unchanged, a
            single message dict is converted to one API-style message, and
            any other iterable of message dicts is converted to a list of
            API-style messages.

        Raises:
            TypeError: If an iterable of prompts contains a bare string.
        """
        if isinstance(prompts, str):
            return prompts
        if isinstance(prompts, dict):
            return self._role2api_role(prompts)

        res = []
        for prompt in prompts:
            if isinstance(prompt, str):
                raise TypeError('Mixing str without explicit role is not '
                                'allowed in API models!')
            res.append(self._role2api_role(prompt))
        return res

    def _role2api_role(self, role_prompt: Dict) -> Dict:
        """Convert one role message into an API-style message.

        Args:
            role_prompt (Dict): A message with a ``role`` key and an
                optional ``content`` key.

        Returns:
            Dict: A shallow copy of ``role_prompt`` whose ``role`` is the
            configured ``api_role`` and whose ``content`` is wrapped in the
            configured ``begin`` / ``end`` strings.

        Raises:
            KeyError: If the role (or its ``fallback_role``) is not part of
                the meta template, or the config lacks ``api_role``.
        """
        merged_prompt = self.roles[role_prompt['role']]
        fallback_role = merged_prompt.get('fallback_role')
        if fallback_role:
            # ``fallback_role`` is a role *name*: look it up once. Indexing
            # ``self.roles`` with the looked-up config dict would raise
            # "unhashable type: 'dict'".
            merged_prompt = self.roles[fallback_role]
        # Work on a copy so the caller's message is left untouched.
        res = role_prompt.copy()
        res['role'] = merged_prompt['api_role']
        res['content'] = merged_prompt.get('begin', '')
        res['content'] += role_prompt.get('content', '')
        res['content'] += merged_prompt.get('end', '')
        return res
class BaseAPILLM(BaseLLM):
    """Base class for API model wrapper.

    Args:
        model_type (str): The type of model.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        template_parser: Callable (the ``APITemplateParser`` class by
            default) that is invoked with ``meta_template`` to build
            ``self.template_parser``. When a falsy value is passed, this
            constructor does not assign ``self.template_parser``.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        max_new_tokens (int): Stored in ``self.gen_params``. Defaults to 512.
        top_p (float): Stored in ``self.gen_params``. Defaults to 0.8.
        top_k (int): Stored in ``self.gen_params``. Defaults to 40.
        temperature (float): Stored in ``self.gen_params``. Defaults to 0.8.
        repetition_penalty (float): Stored in ``self.gen_params``.
            Defaults to 0.0.
        stop_words (Union[List[str], str], optional): Stored in
            ``self.gen_params``; a single string is wrapped into a
            one-element list first. Defaults to None.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 retry: int = 2,
                 template_parser: 'APITemplateParser' = APITemplateParser,
                 meta_template: Optional[Dict] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 0.0,
                 stop_words: Union[List[str], str] = None):
        # Normalise a lone stop word into list form up front; this touches
        # only the local argument.
        normalized_stop_words = (
            [stop_words] if isinstance(stop_words, str) else stop_words)

        self.model_type = model_type
        self.meta_template = meta_template
        self.retry = retry
        if template_parser:
            self.template_parser = template_parser(meta_template)

        # Generation parameters collected from the keyword-only arguments;
        # ``skip_special_tokens`` is fixed to False here.
        self.gen_params = {
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'top_k': top_k,
            'temperature': temperature,
            'repetition_penalty': repetition_penalty,
            'stop_words': normalized_stop_words,
            'skip_special_tokens': False,
        }
class AsyncBaseAPILLM(AsyncLLMMixin, BaseAPILLM):
    """API model wrapper base combining ``AsyncLLMMixin`` with ``BaseAPILLM``.

    Defines no members of its own; all behaviour comes from the two bases.
    """