# Lagent / lagent/llms/lmdeploy_wrapper.py
# (repository page header: yanyoyo, "update", ec878fd)
import asyncio
import copy
import logging
from dataclasses import asdict
from typing import List, Optional, Union
import aiohttp
from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix
class TritonClient(BaseLLM):
    """Wrap lmdeploy's Triton ``Chatbot`` client behind the ``BaseLLM`` API.

    Args:
        tritonserver_addr (str): the address in format "ip:port" of
            triton inference server
        model_name (str): the name of the model
        session_len (int): the context size
        log_level (str): log level handed to the chatbot
        **kwargs: forwarded to both ``BaseLLM.__init__`` and ``Chatbot``.

    Raises:
        RuntimeError: if ``lmdeploy.serve.turbomind.chatbot`` cannot be
            imported.
    """

    def __init__(self,
                 tritonserver_addr: str,
                 model_name: str,
                 session_len: int = 32768,
                 log_level: str = 'WARNING',
                 **kwargs):
        super().__init__(path=None, **kwargs)
        try:
            from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
        except Exception as e:
            logging.error(f'{e}')
            # Chain the original failure so the real import error stays
            # visible in the traceback.
            raise RuntimeError('DO NOT use turbomind.chatbot since it has '
                               'been removed by lmdeploy since v0.5.2') from e
        # Translate lmdeploy status codes into lagent's ModelStatusCode.
        self.state_map = {
            StatusCode.TRITON_STREAM_END: ModelStatusCode.END,
            StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR,
            StatusCode.TRITON_SESSION_CLOSED: ModelStatusCode.SESSION_CLOSED,
            StatusCode.TRITON_STREAM_ING: ModelStatusCode.STREAM_ING,
            StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
            ModelStatusCode.SESSION_OUT_OF_LIMIT,
            StatusCode.TRITON_SESSION_INVALID_ARG:
            ModelStatusCode.SESSION_INVALID_ARG,
            StatusCode.TRITON_SESSION_READY: ModelStatusCode.SESSION_READY
        }
        self.chatbot = Chatbot(
            tritonserver_addr=tritonserver_addr,
            model_name=model_name,
            session_len=session_len,
            log_level=log_level,
            **kwargs)

    def _start_round(self, session_id, request_id, sequence_start,
                     gen_kwargs):
        """Refresh the generation config and arm the session for a new round.

        Returns:
            ``(max_new_tokens, sequence_start)`` when the round may proceed,
            or ``None`` when the current session has already been ended
            (an error is logged in that case).
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger

        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**gen_kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            # No session yet: a fresh one always starts a new sequence.
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` be True if you want to restart it')
            return None

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''
        return max_new_tokens, sequence_start

    def _commit_history(self):
        """Append the finished round (prompt + response) to the history."""
        session = self.chatbot._session
        session.histories = (
            session.histories + session.prompt + session.response)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 request_id: str = '',
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the identical id of a session
            request_id (str): the identical id of this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            **kwargs: per-call generation parameter overrides.

        Returns:
            The text of the last streamed result with stop words stripped,
            or ``''`` when the session has ended or an error status is seen.
        """
        if isinstance(inputs, str):
            inputs = [inputs]

        round_args = self._start_round(session_id, request_id,
                                       sequence_start, kwargs)
        if round_args is None:
            return ''
        max_new_tokens, sequence_start = round_args

        res = ''
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                inputs,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            # NOTE(review): an unmapped status yields None here and the
            # comparison below would then fail - confirm the map is complete.
            status = self.state_map.get(status)
            if status < ModelStatusCode.END:
                return ''
            elif status == ModelStatusCode.END:
                self._commit_history()
        # remove stop_words
        res = filter_suffix(res, self.gen_params.get('stop_words'))
        return res

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 2967,
                    request_id: str = '',
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    skip_special_tokens: bool = False,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round conversation
            session_id (int): the identical id of a session
            request_id (str): the identical id of this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            **kwargs: per-call generation parameter overrides.

        Yields:
            tuple(Status, str, int): status, text/chat completion,
            generated token number. Iteration stops after the first tuple
            whose status is ``END`` or lower (i.e. an error status).
        """
        round_args = self._start_round(session_id, request_id,
                                       sequence_start, kwargs)
        if round_args is None:
            # This is a generator: the status must be yielded, a plain
            # `return value` would be invisible to `for` consumers.
            yield ModelStatusCode.SESSION_CLOSED, '', 0
            return
        max_new_tokens, sequence_start = round_args

        prompt = self.template_parser(inputs)
        for status, res, num_tokens in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
            # The stop symbol also appears in the output of the last
            # STREAM_ING state.
            res = filter_suffix(res, self.gen_params.get('stop_words'))
            if status == ModelStatusCode.END:
                self._commit_history()
            yield status, res, num_tokens
            if status <= ModelStatusCode.END:
                # END or an error status terminates the stream.
                return

    def _update_gen_params(self, **kwargs):
        """Build the chatbot config for one call.

        Merges ``kwargs`` into the generation parameters, stores the resulting
        ``stop_words`` on ``self.gen_params`` and returns an
        ``mmengine.Config`` carrying the session length, the converted stop
        words, the chatbot's current bad words and the remaining parameters.
        """
        import mmengine
        new_gen_params = self.update_gen_params(**kwargs)
        self.gen_params['stop_words'] = new_gen_params.pop('stop_words')
        stop_words = self.chatbot._stop_words(
            self.gen_params.get('stop_words'))
        cfg = mmengine.Config(
            dict(
                session_len=self.chatbot.model.session_len,
                stop_words=stop_words,
                bad_words=self.chatbot.cfg.bad_words,
                **new_gen_params))
        return cfg
class LMDeployPipeline(BaseLLM):
    """Run a model in-process through ``lmdeploy.pipeline``.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                    - i) A local directory path of a turbomind model which is
                        converted by `lmdeploy convert` command or download
                        from ii) and iii).
                    - ii) The model_id of a lmdeploy-quantized model hosted
                        inside a model repo on huggingface.co, such as
                        "InternLM/internlm-chat-20b-4bit",
                        "lmdeploy/llama2-chat-70b-4bit", etc.
                    - iii) The model_id of a model hosted inside a model repo
                        on huggingface.co, such as "internlm/internlm-chat-7b",
                        "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                        and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline. Only keys that are
            attributes of ``TurbomindEngineConfig`` are kept. ``None`` means
            an empty config.
        **kwargs: ``do_sample`` is consumed here as the instance-wide default;
            everything else goes to ``BaseLLM.__init__``.

    Raises:
        RuntimeError: if ``do_sample`` is given with lmdeploy older than
            v0.6.0.
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 tp: int = 1,
                 pipeline_cfg: Optional[dict] = None,
                 **kwargs):
        import lmdeploy
        from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, pipeline, version_info

        self.str_version = lmdeploy.__version__
        self.version = version_info
        self.do_sample = kwargs.pop('do_sample', None)
        self._ensure_do_sample_supported(self.do_sample)
        super().__init__(path=path, **kwargs)

        # Work on a private copy so the caller's dict is never modified.
        backend_config = (
            copy.deepcopy(pipeline_cfg) if pipeline_cfg is not None else {})
        backend_config.update(tp=tp)
        # Drop options the engine config does not know about.
        backend_config = {
            k: v
            for k, v in backend_config.items()
            if hasattr(TurbomindEngineConfig, k)
        }
        backend_config = TurbomindEngineConfig(**backend_config)
        chat_template_config = ChatTemplateConfig(
            model_name=model_name) if model_name else None
        self.model = pipeline(
            model_path=self.path,
            backend_config=backend_config,
            chat_template_config=chat_template_config,
            log_level='WARNING')

    def _ensure_do_sample_supported(self, do_sample):
        """Raise if ``do_sample`` is set but the installed lmdeploy is too old."""
        if do_sample is not None and self.version < (0, 6, 0):
            raise RuntimeError(
                '`do_sample` parameter is not supported by lmdeploy until '
                f'v0.6.0, but currently using lmdeloy {self.str_version}')

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: Optional[bool] = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether pre-process the messages. Passed
                through to ``batch_infer`` unchanged (``None`` by default).
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            return_dict (bool): return ``dataclasses.asdict`` of each response
                instead of only its text.
            **kwargs: per-call generation parameter overrides; ``do_sample``
                overrides the instance default.

        Returns:
            (a list of/batched) text/chat completion, or the corresponding
            dict(s) when ``return_dict`` is True.

        Raises:
            RuntimeError: if ``do_sample`` is set with lmdeploy older than
                v0.6.0.
        """
        from lmdeploy.messages import GenerationConfig

        batched = not isinstance(inputs, str)
        if not batched:
            inputs = [inputs]

        do_sample = kwargs.pop('do_sample', None)
        gen_params = self.update_gen_params(**kwargs)
        if do_sample is None:
            do_sample = self.do_sample
        self._ensure_do_sample_supported(do_sample)
        if self.version >= (0, 6, 0):
            if do_sample is None:
                # Nothing requested explicitly: sample whenever the
                # parameters ask for any randomness.
                do_sample = gen_params['top_k'] > 1 or gen_params[
                    'temperature'] > 0
            gen_params.update(do_sample=do_sample)

        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **gen_params)
        response = self.model.batch_infer(
            inputs, gen_config=gen_config, do_preprocess=do_preprocess)
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        # Write the cleaned text back so the dict form agrees with it.
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict(resp)
                    for resp in response] if return_dict else texts
        return asdict(response[0]) if return_dict else texts[0]
class LMDeployServer(BaseLLM):
    """Launch an lmdeploy api server via ``lmdeploy.serve`` and query it.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download from
                    ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
        serve_cfg (dict): extra keyword arguments for ``lmdeploy.serve``.
            ``None`` means no extra arguments.
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg: Optional[dict] = None,
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.model_name = model_name
        # TODO get_logger issue in multi processing
        import lmdeploy
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            **(serve_cfg or {}))

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: Optional[bool] = False,
                 timeout: int = 30,
                 **kwargs) -> Union[str, List[str]]:
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the identical id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            timeout (int): max time to wait for response
            **kwargs: per-call generation parameter overrides.

        Returns:
            (a list of/batched) text/chat completion: a single string for a
            string input, a list of strings for a list input.
        """
        batched = not isinstance(inputs, str)
        if not batched:
            inputs = [inputs]

        gen_params = self.update_gen_params(**kwargs)
        # The completions endpoint names this parameter `max_tokens`.
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.model_name,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # remove stop_words
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id=0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round conversation
            session_id (int): the identical id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            timeout (int): max time to wait for response
            **kwargs: per-call generation parameter overrides.

        Yields:
            tuple(Status, str, None): ``STREAM_ING`` with the text accumulated
            so far for every non-empty update, then one final ``END`` tuple.
        """
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        # `stop_words` may be unset (None); treat that as "no stop words"
        # instead of failing when iterating over it.
        stop_words = self.gen_params.get('stop_words') or []
        resp = ''
        for text in self.client.completions_v1(
                self.model_name,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # remove stop_words; the first hit ends the stream early
            finished = any(sw in resp for sw in stop_words)
            if finished:
                resp = filter_suffix(resp, stop_words)
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None
class LMDeployClient(LMDeployServer):
    """Talk to an lmdeploy api server that is already running.

    Reuses the request methods of ``LMDeployServer`` but, instead of
    launching a server, connects an ``APIClient`` to ``url``.

    Args:
        url (str): communicating address 'http://<ip>:<port>' of
            api_server
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
    """

    def __init__(self, url: str, model_name: str, **kwargs):
        # Call the base initialiser directly: LMDeployServer.__init__ would
        # start a server through `lmdeploy.serve`.
        BaseLLM.__init__(self, path=url, **kwargs)

        from lmdeploy.serve.openai.api_client import APIClient

        self.model_name = model_name
        self.client = APIClient(url)
class AsyncLMDeployPipeline(AsyncLLMMixin, LMDeployPipeline):
    """Asynchronous counterpart of ``LMDeployPipeline``.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                    - i) A local directory path of a turbomind model which is
                        converted by `lmdeploy convert` command or download
                        from ii) and iii).
                    - ii) The model_id of a lmdeploy-quantized model hosted
                        inside a model repo on huggingface.co, such as
                        "InternLM/internlm-chat-20b-4bit",
                        "lmdeploy/llama2-chat-70b-4bit", etc.
                    - iii) The model_id of a model hosted inside a model repo
                        on huggingface.co, such as "internlm/internlm-chat-7b",
                        "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                        and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline
    """

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       do_preprocess: bool = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Every prompt is streamed from the engine concurrently and the pieces
        are merged into one response per prompt.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            session_ids (Union[int, List[int]]): one id per input; when
                omitted the ids ``0..len(inputs)-1`` are used.
            do_preprocess (bool): whether pre-process the messages; handed to
                the engine unchanged.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            return_dict (bool): return ``dataclasses.asdict`` of each response
                instead of only its text.
            **kwargs: merged into the generation parameters and also passed
                to the engine's ``generate`` call.

        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.messages import GenerationConfig, Response

        single = isinstance(inputs, str)
        if single:
            inputs = [inputs]

        if session_ids is None:
            session_ids = list(range(len(inputs)))
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(inputs) == len(session_ids)

        # NOTE(review): unlike LMDeployPipeline.generate, no `do_sample`
        # default is derived here - confirm that this is intended.
        params = self.update_gen_params(**kwargs)
        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **params)

        async def _collect(sid, prompt_text):
            """Stream one prompt and fold all chunks into one Response."""
            merged = Response('', 0, 0, sid)
            stream = self.model.generate(
                prompt_text,
                sid,
                gen_config,
                stream_response=True,
                sequence_start=True,
                sequence_end=True,
                do_preprocess=do_preprocess,
                **kwargs)
            async for piece in stream:
                merged.text += piece.response
                # Counters and finish reason: the latest chunk wins.
                merged.generate_token_len = piece.generate_token_len
                merged.input_token_len = piece.input_token_len
                merged.finish_reason = piece.finish_reason
                if piece.token_ids:
                    merged.token_ids.extend(piece.token_ids)
                if piece.logprobs:
                    if merged.logprobs is None:
                        merged.logprobs = []
                    merged.logprobs.extend(piece.logprobs)
            return merged

        tasks = [
            _collect(sid, prompt_text)
            for sid, prompt_text in zip(session_ids, inputs)
        ]
        results = await asyncio.gather(*tasks)

        # remove stop_words
        cleaned = filter_suffix([item.text for item in results],
                                self.gen_params.get('stop_words'))
        for item, text in zip(results, cleaned):
            item.text = text

        if single:
            if return_dict:
                return asdict(results[0])
            return cleaned[0]
        if return_dict:
            return [asdict(item) for item in results]
        return cleaned
class AsyncLMDeployServer(AsyncLLMMixin, LMDeployServer):
    """Asynchronous counterpart of ``LMDeployServer``.

    Requests are posted with ``aiohttp`` to the client's
    ``completions_v1_url``.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download from
                    ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    async def generate(
        self,
        inputs: Union[str, List[str]],
        session_ids: Optional[Union[int, List[int]]] = None,
        sequence_start: bool = True,
        sequence_end: bool = True,
        ignore_eos: bool = False,
        skip_special_tokens: Optional[bool] = False,
        timeout: int = 30,
        **kwargs,
    ):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_ids (int, List[int]): session id(s). Accepted for
                interface compatibility; not sent with the request.
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            timeout (int): max time to wait for response; sent as part of
                the request payload.
            **kwargs: per-call generation parameter overrides.

        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.serve.openai.api_client import json_loads

        batched = not isinstance(inputs, str)
        if not batched:
            inputs = [inputs]

        gen_params = self.update_gen_params(**kwargs)
        # The completions endpoint names this parameter `max_tokens`.
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        responses = [''] * len(inputs)
        pload = dict(
            model=self.model_name,
            prompt=inputs,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            stream=False,
            ignore_eos=ignore_eos,
            skip_special_tokens=skip_special_tokens,
            timeout=timeout,
            **gen_params)
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
            async with session.post(
                    self.client.completions_v1_url,
                    headers=self.client.headers,
                    json=pload) as resp:
                async for chunk in resp.content:
                    if not chunk:
                        continue
                    output = json_loads(chunk.decode('utf-8'))
                    responses = [
                        response + item['text'] for response, item in zip(
                            responses, output['choices'])
                    ]
        # remove stop_words
        responses = filter_suffix(responses,
                                  self.gen_params.get('stop_words'))
        if not batched:
            return responses[0]
        return responses

    async def stream_chat(
        self,
        inputs: List[dict],
        session_id: Optional[int] = None,
        sequence_start: bool = True,
        sequence_end: bool = True,
        stream: bool = True,
        ignore_eos: bool = False,
        skip_special_tokens: Optional[bool] = False,
        timeout: int = 30,
        **kwargs,
    ):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            inputs (List[dict]): user's inputs in this round conversation
            session_id (int): session id. Accepted for interface
                compatibility; not sent with the request.
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Default to be False.
            timeout (int): max time to wait for response; sent as part of
                the request payload.
            **kwargs: per-call generation parameter overrides.

        Yields:
            tuple(Status, str, None): ``STREAM_ING`` with the text accumulated
            so far for every non-empty update, then one final ``END`` tuple.
        """
        from lmdeploy.serve.openai.api_client import json_loads

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        response = ''
        # `stop_words` may be unset (None); treat that as "no stop words"
        # instead of failing when iterating over it.
        stop_words = self.gen_params.get('stop_words') or []
        pload = dict(
            model=self.model_name,
            prompt=prompt,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            stream=stream,
            ignore_eos=ignore_eos,
            skip_special_tokens=skip_special_tokens,
            timeout=timeout,
            **gen_params)
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
            async with session.post(
                    self.client.completions_v1_url,
                    headers=self.client.headers,
                    json=pload) as resp:
                async for chunk in resp.content:
                    if not chunk:
                        continue
                    decoded = chunk.decode('utf-8')
                    # Skip blank keep-alive lines and the end-of-stream
                    # marker.
                    if not decoded.strip() or decoded.rstrip(
                    ) == 'data: [DONE]':
                        continue
                    if decoded[:6] == 'data: ':
                        decoded = decoded[6:]
                    output = json_loads(decoded)
                    response += output['choices'][0]['text']
                    if not response:
                        continue
                    # remove stop_words; the first hit ends the stream early
                    finished = any(sw in response for sw in stop_words)
                    if finished:
                        response = filter_suffix(response, stop_words)
                    yield ModelStatusCode.STREAM_ING, response, None
                    if finished:
                        break
                yield ModelStatusCode.END, response, None
class AsyncLMDeployClient(AsyncLMDeployServer):
    """Asynchronously talk to an lmdeploy api server that is already running.

    Reuses the request methods of ``AsyncLMDeployServer`` but, instead of
    launching a server, connects an ``APIClient`` to ``url``.

    Args:
        url (str): communicating address 'http://<ip>:<port>' of
            api_server
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
    """

    def __init__(self, url: str, model_name: str, **kwargs):
        # Call the base initialiser directly: the server class initialiser
        # would start a server through `lmdeploy.serve`.
        BaseLLM.__init__(self, path=url, **kwargs)

        from lmdeploy.serve.openai.api_client import APIClient

        self.model_name = model_name
        self.client = APIClient(url)