import torch
import transformers

from .get_device import get_device
from .streaming_generation_utils import Iteratorize, Stream


def generate(
    # model
    model,
    tokenizer,
    # input
    prompt,
    generation_config,
    max_new_tokens,
    stopping_criteria=None,
    # output options
    stream_output=False,
):
    """Generate text for ``prompt`` and yield ``(decoded_text, token_ids)`` pairs.

    This function is a generator in both modes, so callers must iterate over
    it (or call ``next``) even when ``stream_output`` is False.

    Args:
        model: Object exposing ``generate(**kwargs)``; it is called with the
            keyword arguments assembled below.
        tokenizer: Callable used as ``tokenizer(prompt, return_tensors="pt")``
            and providing ``decode`` and ``eos_token_id``.
        prompt: Input passed to the tokenizer.
        generation_config: Forwarded unchanged to ``model.generate``.
        max_new_tokens: Forwarded unchanged to ``model.generate``.
        stopping_criteria: Optional iterable of extra stopping criteria.
            ``None`` (the default) or an empty iterable means no extra
            criteria. The caller's object is never mutated.
        stream_output: When True, yield one pair per item produced by the
            streaming iterator; otherwise yield a single pair for the
            finished sequence.

    Yields:
        Tuples ``(decoded_output, output)`` where ``output`` is the token
        sequence and ``decoded_output`` is
        ``tokenizer.decode(output, skip_special_tokens=True)``.
    """
    device = get_device()

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    # Build a real StoppingCriteriaList. ``StoppingCriteriaList() + some_list``
    # would degrade to a plain ``list`` and reject non-list iterables. Copying
    # into a fresh list also keeps the streaming ``insert`` below from touching
    # the caller's object.
    criteria = transformers.StoppingCriteriaList()
    if stopping_criteria:
        criteria.extend(stopping_criteria)

    generate_params = {
        "input_ids": input_ids,
        "generation_config": generation_config,
        "return_dict_in_generate": True,
        "output_scores": True,
        "max_new_tokens": max_new_tokens,
        "stopping_criteria": criteria,
    }

    if stream_output:
        # Stream the reply 1 token at a time.
        # This is based on the trick of using 'stopping_criteria' to create an iterator,
        # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.

        def generate_with_callback(callback=None, **kwargs):
            # The Stream criterion goes first in the list and receives the
            # callback supplied by Iteratorize.
            kwargs["stopping_criteria"].insert(
                0,
                Stream(callback_func=callback),
            )
            with torch.no_grad():
                model.generate(**kwargs)

        def generate_with_streaming(**kwargs):
            return Iteratorize(
                generate_with_callback, kwargs, callback=None
            )

        with generate_with_streaming(**generate_params) as generator:
            for output in generator:
                decoded_output = tokenizer.decode(
                    output, skip_special_tokens=True
                )
                yield decoded_output, output
                # Stop consuming once the last emitted token is end-of-sequence.
                if output[-1] == tokenizer.eos_token_id:
                    break
        return  # early return for stream_output

    # Without streaming
    with torch.no_grad():
        generation_output = model.generate(**generate_params)
    output = generation_output.sequences[0]
    decoded_output = tokenizer.decode(output, skip_special_tokens=True)
    yield decoded_output, output
    return