"""Gradio chat demo for the HIT-TMG/dialogue-bart-large-chinese model."""
from typing import List, Optional, Tuple

import gradio as gr
import torch
from transformers import AutoTokenizer, BertTokenizer, BartForConditionalGeneration

title = "HIT-TMG/dialogue-bart-large-chinese"
description = """
This is a seq2seq model pre-trained on several Chinese dialogue datasets, starting from bart-large-chinese.
It is only a simple demo of the pre-trained model; fine-tune it on downstream tasks for better performance. \n
See the model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese for more details. \n\n
Besides starting the conversation from scratch, you can also enter the whole dialogue history, with utterances separated by '[SEP]'. \n
"""

# tokenizer = BertTokenizer.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
tokenizer = AutoTokenizer.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
model = BartForConditionalGeneration.from_pretrained("HIT-TMG/dialogue-bart-large-chinese")
# Left-side truncation applies to direct `tokenizer(..., truncation=True)` calls.
# chat_func trims the history itself (see _encode_history) so that the prompt
# prefix is never the part that gets cut off.
tokenizer.truncation_side = 'left'
max_length = 512

# Prompt prefix ("dialogue history:") prepended to every model input.
history_prefix = "对话历史:"

examples = [
    ["你有什么爱好吗"],
    ["你好。[SEP]嘿嘿你好,请问你最近在忙什么呢?[SEP]我最近养了一只狗狗,我在训练它呢。"],
]


def _split_utterances(text: Optional[str]) -> List[str]:
    """Split user input on the tokenizer's separator token.

    Each segment is stripped of surrounding whitespace and blank segments are
    dropped, so inputs like ``"a[SEP]"`` or ``" a [SEP] b "`` do not inject
    empty turns into the dialogue history.
    """
    segments = (part.strip() for part in (text or "").split(tokenizer.sep_token))
    return [segment for segment in segments if segment]


def _encode_history(history: List[str]) -> torch.Tensor:
    """Build model input ids for the prompt ``history_prefix + [SEP].join(history)``.

    The prefix is always kept. If the prompt exceeds ``max_length`` tokens
    (special tokens included), the oldest history tokens are dropped instead.
    Whole-string left truncation would cut the prefix first.

    Returns:
        A ``(1, seq_len)`` LongTensor suitable for ``model.generate``.
    """
    prefix_ids = tokenizer.encode(history_prefix, add_special_tokens=False)
    history_ids = tokenizer.encode(
        tokenizer.sep_token.join(history), add_special_tokens=False
    )
    budget = max_length - tokenizer.num_special_tokens_to_add() - len(prefix_ids)
    # Keep only the most recent `budget` history tokens. Slice by start index
    # so a non-positive budget cannot turn `[-0:]` into "keep everything".
    history_ids = history_ids[max(len(history_ids) - max(budget, 0), 0):]
    ids = tokenizer.build_inputs_with_special_tokens(prefix_ids + history_ids)
    return torch.tensor([ids], dtype=torch.long)


def _history_to_pairs(history: List[str]) -> List[Tuple[str, str]]:
    """Pair a flat utterance list into ``(user, bot)`` tuples for ``gr.Chatbot``.

    The last utterance is treated as the bot's reply and roles alternate
    backwards from it. With an odd number of utterances, the first one has no
    preceding user turn. It is shown as a bot message with an empty user side.

    Roles are inferred from list parity alone and recomputed on every call.
    A multi-utterance input that flips the parity therefore re-labels the
    earlier turns in the display.
    """
    if len(history) % 2 == 0:
        start, pairs = 0, []
    else:
        start, pairs = 1, [("", history[0])]
    pairs.extend(zip(history[start::2], history[start + 1::2]))
    return pairs


def chat_func(
    input_utterance: str, history: Optional[List[str]] = None
) -> Tuple[List[Tuple[str, str]], List[str]]:
    """Generate the bot's next reply and return the updated chat.

    Args:
        input_utterance: The current user input. It may hold several
            utterances separated by the tokenizer's separator token
            (``'[SEP]'`` per the description).
        history: Flat list of previous utterances held in Gradio state, or
            ``None`` on the first call. It is not modified in place.

    Returns:
        ``(display_pairs, new_history)``: the ``(user, bot)`` tuples rendered
        by the chatbot output, and the updated flat history stored back into
        Gradio state.
    """
    history = list(history) if history is not None else []

    new_turns = _split_utterances(input_utterance)
    if not new_turns:
        # Nothing to respond to (e.g. an empty submit): re-render unchanged.
        return _history_to_pairs(history), history
    history.extend(new_turns)

    input_ids = _encode_history(history)
    # NOTE: `top_k` was dropped from these kwargs. It only applies when
    # `do_sample=True` and is ignored (with a warning) by plain beam search.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=30,
        num_beams=4,
        repetition_penalty=1.2,
        no_repeat_ngram_size=4,
    )[0]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    history.append(response)

    return _history_to_pairs(history), history


demo = gr.Interface(
    fn=chat_func,
    title=title,
    description=description,
    inputs=[gr.Textbox(lines=1, placeholder="Input current utterance"), "state"],
    examples=examples,
    outputs=["chatbot", "state"],
)

if __name__ == "__main__":
    demo.launch()