# Lagent / lagent/llms/huggingface.py
# Uploaded by Raymd9 in commit cfc816f ("Add files").
import copy
import logging
from typing import Dict, List, Optional, Union
from lagent.schema import ModelStatusCode
from .base_api import APITemplateParser
from .base_llm import BaseLLM
# Module-level logger, used below for warnings about stop-word configuration,
# pad-token fallbacks and chat failures.
logger = logging.getLogger(__name__)
class HFTransformer(BaseLLM):
    """Model wrapper around HuggingFace general models.

    Adapted from Internlm (https://github.com/InternLM/InternLM/blob/main/
    chat/web_demo.py)

    Args:
        path (str): The name or path to HuggingFace's model.
        tokenizer_path (str, optional): The path to the tokenizer. Defaults
            to None, in which case ``path`` is used.
        tokenizer_kwargs (dict, optional): Keyword arguments for the
            tokenizer. Defaults to None (treated as ``{}``).
        tokenizer_only (bool): If True, only the tokenizer will be
            initialized. Defaults to False.
        model_kwargs (dict, optional): Keyword arguments for the model, used
            in loader. Defaults to None (treated as
            ``dict(device_map='auto')``). The dict is copied, so the
            caller's object is never mutated.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        stop_words_id (Union[List[int], int], optional): Token id(s) that
            terminate generation in addition to the model's EOS token(s).
            Takes precedence over ``stop_words`` when both are given.
    """

    def __init__(self,
                 path: str,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: Optional[dict] = None,
                 tokenizer_only: bool = False,
                 model_kwargs: Optional[dict] = None,
                 meta_template: Optional[Dict] = None,
                 stop_words_id: Union[List[int], int, None] = None,
                 **kwargs):
        super().__init__(
            path=path,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template,
            **kwargs)
        # Build fresh dicts per instance. Mutable defaults would be shared
        # across instances, and ``_load_model`` mutates ``model_kwargs``
        # through ``setdefault``.
        tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
        model_kwargs = (dict(device_map='auto')
                        if model_kwargs is None else dict(model_kwargs))

        if isinstance(stop_words_id, int):
            stop_words_id = [stop_words_id]
        self.gen_params.update(stop_words_id=stop_words_id)
        if self.gen_params.get('stop_words') is not None and \
                self.gen_params['stop_words_id'] is not None:
            logger.warning('Both stop_words and stop_words_id are specified, '
                           'only stop_words_id will be used.')

        # ``_load_tokenizer`` only assigns ``gcfg`` on one fallback path.
        # Initialise it here so ``stream_generate`` can always read it.
        self.gcfg = None
        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs)
        if not tokenizer_only:
            self._load_model(path=path, model_kwargs=model_kwargs)

        from transformers.generation.utils import (  # noqa: E501
            LogitsProcessorList, StoppingCriteriaList)
        self.logits_processor = LogitsProcessorList()
        self.stopping_criteria = StoppingCriteriaList()
        self.prefix_allowed_tokens_fn = None

        # Extra token ids treated as EOS during generation. Explicit ids win.
        # Otherwise each stop word contributes the last token id it
        # tokenizes to.
        if self.gen_params.get('stop_words_id'):
            additional_ids = list(self.gen_params['stop_words_id'])
        elif self.gen_params.get('stop_words'):
            additional_ids = [
                self.tokenizer(sw)['input_ids'][-1]
                for sw in self.gen_params['stop_words']
            ]
        else:
            additional_ids = []
        self.additional_eos_token_id = additional_ids

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        """Load ``self.tokenizer`` and make sure it has a pad token.

        The EOS token is used as the pad token when one exists. Otherwise
        ``pad_token_id`` is read from the model's ``GenerationConfig``, which
        is kept as ``self.gcfg``.

        Raises:
            ValueError: If no pad token id can be determined.
        """
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path if tokenizer_path else path,
            trust_remote_code=True,
            **tokenizer_kwargs)
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                logger.warning(
                    f'Using eos_token_id {self.tokenizer.eos_token} '
                    'as pad_token_id.')
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                from transformers.generation import GenerationConfig
                self.gcfg = GenerationConfig.from_pretrained(path)

                if self.gcfg.pad_token_id is not None:
                    logger.warning(
                        f'Using pad_token_id {self.gcfg.pad_token_id} '
                        'as pad_token_id.')
                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
                else:
                    raise ValueError(
                        'pad_token_id is not set for this tokenizer. Try to '
                        'set pad_token_id via passing '
                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

    def _load_model(self, path: str, model_kwargs: dict):
        """Load ``self.model`` with ``AutoModel`` and put it in eval mode.

        ``model_kwargs`` is copied before ``torch_dtype`` is defaulted to
        float16, so the caller's dict is not mutated.
        """
        import torch
        from transformers import AutoModel
        model_kwargs = dict(model_kwargs)
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()

    def tokenize(self, inputs: str):
        """Tokenize a single string.

        Returns:
            list: A nested list ``[[token ids]]`` (a batch of one).
        """
        assert isinstance(inputs, str)
        inputs = self.tokenizer(
            inputs, return_tensors='pt', return_length=True)
        return inputs['input_ids'].tolist()

    def generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled

        Returns:
            (a list of/batched) text/chat completion
        """
        response = None
        # Drain the stream. The last chunk is the final completion.
        for _status, chunk, _ in self.stream_generate(inputs, do_sample,
                                                      **kwargs):
            response = chunk
        return response

    def stream_generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled

        Returns:
            tuple(Status, str, int): status, text/chat completion,
                generated token number
        """
        import torch
        from torch import nn
        with torch.no_grad():
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            encoded = self.tokenizer(
                inputs, padding=True, return_tensors='pt')
            # Put tensors on the model's own device rather than the current
            # CUDA device, so CPU-only and non-default-GPU setups also work.
            device = self.model.device
            input_ids = encoded['input_ids'].to(device)
            attention_mask = encoded['attention_mask'].to(device)
            batch_size = input_ids.shape[0]
            # Padded prompt width. Every generated token lands after it.
            input_ids_seq_length = input_ids.shape[-1]

            generation_config = copy.deepcopy(self.model.generation_config)
            new_gen_params = self.update_gen_params(**kwargs)
            generation_config.update(**new_gen_params)
            generation_config.update(**kwargs)
            model_kwargs = generation_config.to_dict()
            model_kwargs['attention_mask'] = attention_mask

            eos_token_id = generation_config.eos_token_id
            if eos_token_id is None:
                gcfg = getattr(self, 'gcfg', None)
                if gcfg is not None and gcfg.eos_token_id is not None:
                    eos_token_id = gcfg.eos_token_id
                else:
                    eos_token_id = []
            # Always extend a private copy. Extending ``gcfg``'s list in
            # place would accumulate stop ids across calls.
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            else:
                eos_token_id = list(eos_token_id)
            if self.additional_eos_token_id:
                eos_token_id.extend(self.additional_eos_token_id)
            eos_token_id_tensor = torch.tensor(
                eos_token_id, dtype=torch.long, device=input_ids.device)
            eos_token_set = set(eos_token_id)

            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length)
            # Set generation parameters if not already defined
            logits_processor = self.model._get_logits_processor(
                generation_config=generation_config,
                input_ids_seq_length=input_ids_seq_length,
                encoder_input_ids=input_ids,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                logits_processor=self.logits_processor,
            )
            stopping_criteria = self.model._get_stopping_criteria(
                generation_config=generation_config,
                stopping_criteria=self.stopping_criteria)
            logits_warper = self.model._get_logits_warper(generation_config)

            # 1 while a row is still generating, 0 once it has emitted EOS.
            unfinished_sequences = input_ids.new(batch_size).fill_(1)
            scores = None
            while True:
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, **model_kwargs)
                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )
                next_token_logits = outputs.logits[:, -1, :]

                # pre-process distribution
                next_token_scores = logits_processor(input_ids,
                                                     next_token_logits)
                next_token_scores = logits_warper(input_ids,
                                                  next_token_scores)

                # sample
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                if do_sample:
                    next_tokens = torch.multinomial(
                        probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(probs, dim=-1)

                # update generated ids, model inputs, and length for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]],
                                      dim=-1)
                model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501
                    outputs,
                    model_kwargs,
                    is_encoder_decoder=False)
                # Clear a row's flag when its new token is any EOS id.
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))

                # Generated tokens start right after the padded prompt in
                # every row, whatever the padding side or tokenizer type.
                output_token_ids = input_ids[:, input_ids_seq_length:].tolist()
                for i, token_ids in enumerate(output_token_ids):
                    # Keep only the part before the first EOS token.
                    first_eos_idx = next(
                        (idx for idx, token_id in enumerate(token_ids)
                         if token_id in eos_token_set), None)
                    if first_eos_idx is not None:
                        output_token_ids[i] = token_ids[:first_eos_idx]

                response = self.tokenizer.batch_decode(output_token_ids)
                if not batched:
                    response = response[0]
                yield ModelStatusCode.STREAM_ING, response, None
                # stop when each sentence is finished,
                # or if we exceed the maximum length
                if (unfinished_sequences.max() == 0
                        or stopping_criteria(input_ids, scores)):
                    break
            yield ModelStatusCode.END, response, None

    def stream_chat(
        self,
        inputs: List[dict],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        The messages are rendered to a prompt with ``self.template_parser``
        and then streamed through ``stream_generate``.

        Args:
            inputs (List[dict]): input messages to be completed.
            do_sample (bool): do sampling if enabled

        Returns:
            the text/chat completion
        """
        prompt = self.template_parser(inputs)
        yield from self.stream_generate(prompt, do_sample, **kwargs)
class HFTransformerCasualLM(HFTransformer):
    """HFTransformer variant that loads the model with
    ``AutoModelForCausalLM`` instead of ``AutoModel``."""

    def _load_model(self, path: str, model_kwargs: dict):
        """Load a causal-LM checkpoint and put it in eval mode.

        Args:
            path (str): Model name or local path passed to
                ``from_pretrained``.
            model_kwargs (dict): Extra ``from_pretrained`` keyword arguments.
                The dict is copied before ``torch_dtype`` is defaulted to
                float16, so neither the caller's dict nor a shared default
                is mutated.
        """
        import torch
        from transformers import AutoModelForCausalLM
        model_kwargs = dict(model_kwargs)
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()
class HFTransformerChat(HFTransformerCasualLM):
    """Causal-LM wrapper that delegates chatting to the model's own ``chat``.

    The loaded model must expose
    ``chat(tokenizer, query, history=...)`` returning a
    ``(response, history)`` pair (as ``trust_remote_code`` chat models
    typically do). Prompts are built with ``APITemplateParser`` by default.
    """

    def __init__(self, template_parser=APITemplateParser, **kwargs):
        super().__init__(template_parser=template_parser, **kwargs)

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             do_sample: bool = True,
             **kwargs):
        """Return the chat completion(s) in non-stream mode.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): One conversation
                (a list of message dicts) or a batch of conversations.
            do_sample (bool): do sampling if enabled

        Returns:
            str or List[str]: the completion, or one completion per
                conversation for batched input. An empty string is returned
                for a conversation whose ``model.chat`` call raised.
        """
        # Batched input: answer each conversation independently.
        if isinstance(inputs[0], list):
            return [
                self.chat(messages, do_sample, **kwargs)
                for messages in inputs
            ]

        prompt = self.template_parser(inputs)
        # The last message is the query. Earlier ones form the history.
        query = prompt[-1]['content']
        history = prompt[:-1]
        # NOTE(review): ``do_sample`` and ``kwargs`` are not forwarded to
        # ``model.chat`` here. Confirm whether that is intentional.
        try:
            response, _ = self.model.chat(
                self.tokenizer, query, history=history)
        except Exception as e:
            # Best effort, e.g. an over-length input makes the model raise.
            # Log it and return an empty completion instead of crashing.
            logger.warning(str(e))
            response = ''
        return response