import copy
import logging
from typing import Dict, List, Optional, Union

from lagent.schema import ModelStatusCode
from .base_api import APITemplateParser
from .base_llm import BaseLLM

logger = logging.getLogger(__name__)


class HFTransformer(BaseLLM):
    """Model wrapper around HuggingFace general models.

    Adapted from InternLM (https://github.com/InternLM/InternLM/blob/main/
    chat/web_demo.py)

    Args:
        path (str): The name or path to HuggingFace's model.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        model_kwargs (dict): Keyword arguments for the model, used in the
            loader. Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt template,
            if needed, for injecting or wrapping meta instructions.
        stop_words_id (Union[List[int], int], optional): Token id(s) that
            stop generation in addition to the EOS token. Defaults to None.
    """

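    # Usage sketch (illustrative only, not from the original module; the
    # checkpoint name and stop token id below are example values and assume
    # a chat-style model whose remote code is available):
    #
    #     llm = HFTransformer(path='internlm/internlm2-chat-7b',
    #                         stop_words_id=[2])
    #     print(llm.generate('A brief introduction to Shanghai is'))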
    def __init__(self,
                 path: str,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 model_kwargs: dict = dict(device_map='auto'),
                 meta_template: Optional[Dict] = None,
                 stop_words_id: Union[List[int], int] = None,
                 **kwargs):
        super().__init__(
            path=path,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template,
            **kwargs)
        if isinstance(stop_words_id, int):
            stop_words_id = [stop_words_id]
        self.gen_params.update(stop_words_id=stop_words_id)
        if self.gen_params['stop_words'] is not None and \
                self.gen_params['stop_words_id'] is not None:
            logger.warning('Both stop_words and stop_words_id are specified, '
                           'only stop_words_id will be used.')

        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs)
        if not tokenizer_only:
            self._load_model(path=path, model_kwargs=model_kwargs)

        from transformers.generation.utils import (LogitsProcessorList,
                                                   StoppingCriteriaList)
        self.logits_processor = LogitsProcessorList()
        self.stopping_criteria = StoppingCriteriaList()
        self.prefix_allowed_tokens_fn = None

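        # Resolve stop words into token ids that the streaming loop treats
        # as additional EOS ids. Note that only the last sub-token of each
        # stop word is kept, which assumes the stop word ends in a
        # distinctive token.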
        stop_words_id = []
        if self.gen_params.get('stop_words_id'):
            stop_words_id = self.gen_params.get('stop_words_id')
        elif self.gen_params.get('stop_words'):
            for sw in self.gen_params.get('stop_words'):
                stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
        self.additional_eos_token_id = stop_words_id

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path if tokenizer_path else path,
            trust_remote_code=True,
            **tokenizer_kwargs)

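        # Ensure a pad token is available for batched (padded) encoding:
        # fall back to the EOS token, then to the pad_token_id declared in
        # the model's GenerationConfig, and fail loudly otherwise.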
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                logger.warning(
                    f'Using eos_token {self.tokenizer.eos_token} '
                    'as pad_token.')
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                from transformers.generation import GenerationConfig
                self.gcfg = GenerationConfig.from_pretrained(path)

                if self.gcfg.pad_token_id is not None:
                    logger.warning(
                        f'Using pad_token_id {self.gcfg.pad_token_id} '
                        'from the model generation config.')
                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
                else:
                    raise ValueError(
                        'pad_token_id is not set for this tokenizer. Try to '
                        'set pad_token_id via passing '
                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModel
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()

    def tokenize(self, inputs: str):
        assert isinstance(inputs, str)
        inputs = self.tokenizer(
            inputs, return_tensors='pt', return_length=True)
        return inputs['input_ids'].tolist()

    def generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): whether to sample the next token; greedy
                decoding is used otherwise. Defaults to True.
        Returns:
            (a list of/batched) text/chat completion
        """
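        # Non-stream mode simply drains the streaming generator and returns
        # the last (complete) chunk it produced.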
        for status, chunk, _ in self.stream_generate(inputs, do_sample,
                                                     **kwargs):
            response = chunk
        return response

    def stream_generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): whether to sample the next token; greedy
                decoding is used otherwise. Defaults to True.
        Returns:
            tuple(Status, str, int): status, text/chat completion,
                generated token number
        """
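        # A hand-rolled, token-by-token decoding loop (instead of a single
        # `model.generate` call) so that a partially decoded response can be
        # yielded after every step.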
        import torch
        from torch import nn
        with torch.no_grad():
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            inputs = self.tokenizer(
                inputs, padding=True, return_tensors='pt', return_length=True)
            input_length = inputs['length']
            for k, v in inputs.items():
                inputs[k] = v.cuda()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            batch_size = input_ids.shape[0]
            input_ids_seq_length = input_ids.shape[-1]
            generation_config = self.model.generation_config
            generation_config = copy.deepcopy(generation_config)
            new_gen_params = self.update_gen_params(**kwargs)
            generation_config.update(**new_gen_params)
            generation_config.update(**kwargs)
            model_kwargs = generation_config.to_dict()
            model_kwargs['attention_mask'] = attention_mask
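            # Collect every id that should terminate generation: the EOS ids
            # from the generation config (falling back to the tokenizer-side
            # GenerationConfig loaded in `_load_tokenizer`, if any) plus the
            # stop-word ids gathered in `__init__`.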
            eos_token_id = generation_config.eos_token_id
            if eos_token_id is None:
                gcfg = getattr(self, 'gcfg', None)
                if gcfg is not None and gcfg.eos_token_id is not None:
                    eos_token_id = gcfg.eos_token_id
                else:
                    eos_token_id = []
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            if self.additional_eos_token_id is not None:
                eos_token_id.extend(self.additional_eos_token_id)
            eos_token_id_tensor = torch.tensor(eos_token_id).to(
                input_ids.device) if eos_token_id is not None else None
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length)

            logits_processor = self.logits_processor
            stopping_criteria = self.stopping_criteria

            logits_processor = self.model._get_logits_processor(
                generation_config=generation_config,
                input_ids_seq_length=input_ids_seq_length,
                encoder_input_ids=input_ids,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                logits_processor=logits_processor,
            )

            stopping_criteria = self.model._get_stopping_criteria(
                generation_config=generation_config,
                stopping_criteria=stopping_criteria)
            logits_warper = self.model._get_logits_warper(generation_config)

            unfinished_sequences = input_ids.new(batch_size).fill_(1)
            scores = None
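            # Main decoding loop: run a forward pass, apply the logits
            # processors/warpers, pick the next token (sampling or argmax),
            # append it, and mark sequences as finished once they emit any
            # of the EOS ids.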
            while True:
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, **model_kwargs)

                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )

                next_token_logits = outputs.logits[:, -1, :]

                next_token_scores = logits_processor(input_ids,
                                                     next_token_logits)
                next_token_scores = logits_warper(input_ids, next_token_scores)

                probs = nn.functional.softmax(next_token_scores, dim=-1)
                if do_sample:
                    next_tokens = torch.multinomial(
                        probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(probs, dim=-1)

                input_ids = torch.cat([input_ids, next_tokens[:, None]],
                                      dim=-1)
                model_kwargs = self.model._update_model_kwargs_for_generation(
                    outputs,
                    model_kwargs,
                    is_encoder_decoder=False)
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
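                # Build the partial response for this step: drop the prompt
                # tokens, cut each sequence at its first EOS id, and decode.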
                output_token_ids = input_ids.cpu().tolist()
                for i in range(len(output_token_ids)):
                    output_token_ids[i] = output_token_ids[i][
                        input_length[i]:]

                    first_eos_idx = next(
                        (idx
                         for idx, token_id in enumerate(output_token_ids[i])
                         if token_id in eos_token_id), None)

                    if first_eos_idx is not None:
                        output_token_ids[i] = output_token_ids[
                            i][:first_eos_idx]

                response = self.tokenizer.batch_decode(output_token_ids)

                if not batched:
                    response = response[0]
                yield ModelStatusCode.STREAM_ING, response, None

                if (unfinished_sequences.max() == 0
                        or stopping_criteria(input_ids, scores)):
                    break
            yield ModelStatusCode.END, response, None

    def stream_chat(
        self,
        inputs: List[dict],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (List[dict]): input messages to be completed.
            do_sample (bool): whether to sample the next token; greedy
                decoding is used otherwise. Defaults to True.
        Returns:
            the text/chat completion
        """
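        # The template parser renders the message list into a single prompt
        # string, which is then streamed like a plain completion.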
        prompt = self.template_parser(inputs)
        yield from self.stream_generate(prompt, do_sample, **kwargs)


class HFTransformerCasualLM(HFTransformer):

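    # Same wrapper, but loads the checkpoint with `AutoModelForCausalLM`
    # (decoder-only LMs with an LM head) instead of the bare `AutoModel`.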
    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModelForCausalLM
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()


class HFTransformerChat(HFTransformerCasualLM):

    def __init__(self, template_parser=APITemplateParser, **kwargs):
        super().__init__(template_parser=template_parser, **kwargs)

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             do_sample: bool = True,
             **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): input messages
                to be completed.
            do_sample (bool): whether to sample the next token; greedy
                decoding is used otherwise. Defaults to True.
        Returns:
            the text/chat completion
        """

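        # A nested list means a batch of conversations: recurse per item.
        # A single conversation is delegated to the model's own `chat()`
        # helper (provided by remote code, e.g. InternLM-style chat models);
        # any failure is logged and an empty response is returned.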
        if isinstance(inputs[0], list):
            resps = []
            for input in inputs:
                resps.append(self.chat(input, do_sample, **kwargs))
            return resps
        prompt = self.template_parser(inputs)
        query = prompt[-1]['content']
        history = prompt[:-1]
        try:
            response, history = self.model.chat(
                self.tokenizer, query, history=history)
        except Exception as e:
            logger.warning(str(e))
            response = ''
        return response