# Tutorial: https://gradio.app/creating_a_chatbot/

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re

ckpt = 'armandnlp/gpt2-TOD_finetuned_SGD'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)



def format_resp(system_resp):
  # format Belief, Action and Response tags
  system_resp = system_resp.replace('<|belief|>', '*Belief State: ')
  system_resp = system_resp.replace('<|action|>', '*Actions: ')
  system_resp = system_resp.replace('<|response|>', '*System Response: ')
  return system_resp



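def predict(input, history=None):
  # avoid a mutable default argument; a missing state means a new conversation
  if history is None:
    history = []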

  if history != []:
    # the model expects only user and system utterances in its input, without belief or
    # action sequences, so we clean up the history first.

    # history holds one list of token ids (hence history[0]) covering all the previous turns,
    # i.e. tokenized user inputs + tokenized model outputs
    history_str = tokenizer.decode(history[0])
    turns = re.split(r'<\|system\|>|<\|user\|>', history_str)[1:]
    for i in range(0, len(turns)-1, 2):
      turns[i] = '<|user|>' + turns[i]
      # keep only the response part of each system_out in the history (no belief and action)
      # partition() yields '' instead of raising IndexError if no <|response|> tag was generated
      turns[i+1] = '<|system|>' + turns[i+1].partition('<|response|>')[2]
    history4input = tokenizer.encode(''.join(turns), return_tensors='pt') 
  else:
    history4input = torch.LongTensor(history)

  # format input for model by concatenating <|context|> + history4input + new_input + <|endofcontext|>
  new_user_input_ids = tokenizer.encode(' <|user|> '+input, return_tensors='pt')
  context = tokenizer.encode('<|context|>', return_tensors='pt')
  endofcontext = tokenizer.encode(' <|endofcontext|>', return_tensors='pt')
  model_input = torch.cat([context, history4input, new_user_input_ids, endofcontext], dim=-1)
  
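  # generate output; 50262 is a special-token id in this checkpoint's extended vocabulary
  # (presumably <|endofresponse|>), so generation stops once the response is complete
  out = model.generate(model_input, max_length=1024, eos_token_id=50262).tolist()[0]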

  # build the updated history, dropping the <|endof...|> marker tokens from the system output
  string_out = tokenizer.decode(out)
  system_out = string_out.split('<|endofcontext|>')[1].replace('<|endofbelief|>', '').replace('<|endofaction|>', '').replace('<|endofresponse|>', '')
  resp_tokenized = tokenizer.encode(' <|system|> '+system_out, return_tensors='pt')
  history = torch.cat([torch.LongTensor(history), new_user_input_ids, resp_tokenized], dim=-1).tolist()
  # history = history + last user input + <|system|> <|belief|> ... <|action|> ... <|response|>...

  # format responses for display
  # every turn is displayed, so the stored history keeps the belief and action text
  # even though it is stripped from the model input at the start of this function
  turns = tokenizer.decode(history[0])
  turns = re.split(r'<\|system\|>|<\|user\|>', turns)[1:] # list of all the user and system turns until now
  # list of tuples [(user, system), (user, system)...]
  # 1 tuple represents 1 exchange at 1 turn 
  # system resp is formatted with function above to make more readable
  resps = [(turns[i], format_resp(turns[i+1])) for i in range(0, len(turns)-1, 2)] 

  return resps, history



examples = [["I want to book a restaurant for 2 people on Saturday."],
            ["What's the weather in Cambridge today?"],
            ["I need to find a bus to Boston."],
            ["I want to add an event to my calendar."],
            ["I would like to book a plane ticket to New York."],
            ["I want to find a concert around LA."],
            ["Hi, I'd like to find an apartment in London please."],
            ["Can you find me a hotel room near Seattle please?"],
            ["I want to watch a film online, a comedy would be nice"],
            ["I want to transfer some money please."],
            ["I want to reserve a movie ticket for tomorrow evening"],
            ["Can you play the song Learning to Fly by Tom Petty?"],
            ["I need to rent a small car."]
            ]

description = """
This is an interactive window to chat with GPT-2 fine-tuned on the Schema-Guided Dialogue (SGD) dataset,
which covers domains such as travel, weather, media, calendar, banking and restaurant booking.
"""

article = """
### Model Outputs 
This task-oriented dialogue system is trained end-to-end, following the method detailed in 
[SimpleTOD](https://arxiv.org/pdf/2005.00796.pdf), where GPT-2 is trained as a single causal language 
model by casting task-oriented dialogue as one sequence prediction task.

From the dialogue history, composed of the previous user and system responses, the model is trained 
to output the belief state, the action decisions and the system response as a single sequence. This demo shows 
all three outputs: the belief state tracks the user goal (e.g. restaurant cuisine: Indian, or media 
genre: comedy), the action decisions show how the system should proceed (e.g. restaurants request city, 
or media offer title), and the natural language response is the reply the user actually reads.
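
The sketch below is a minimal, hypothetical illustration of that layout in plain Python: `build_context` assembles the history and the new user turn between `<|context|>` and `<|endofcontext|>`, and `extract` pulls the belief state, actions and response out of a generated sequence using the `<|belief|>`, `<|action|>` and `<|response|>` tags and their closing `<|endofbelief|>`, `<|endofaction|>` and `<|endofresponse|>` counterparts. Both helpers are written for this example only (they are not part of the model or of `transformers`), and the dialogue text is made up.

```python
def build_context(history, user_input):
    # history: list of (user_utterance, system_response) pairs, without belief or action text
    parts = ["<|context|>"]
    for user, system in history:
        parts += ["<|user|>", user, "<|system|>", system]
    parts += ["<|user|>", user_input, "<|endofcontext|>"]
    return " ".join(parts)


def extract(seq, start, end):
    # return the text between start and end, or "" if start never appears
    if start not in seq:
        return ""
    return seq.split(start, 1)[1].split(end, 1)[0].strip()


# hypothetical model output: the belief, action and response text are illustrative only
generated = build_context([], "i'd like indian food please") + (
    " <|belief|> restaurants cuisine indian <|endofbelief|>"
    " <|action|> restaurants request city <|endofaction|>"
    " <|response|> in which city should i look? <|endofresponse|>"
)

outputs = {
    "belief": extract(generated, "<|belief|>", "<|endofbelief|>"),
    "action": extract(generated, "<|action|>", "<|endofaction|>"),
    "response": extract(generated, "<|response|>", "<|endofresponse|>"),
}
for name, text in outputs.items():
    print(name + ": " + text)
```

Splitting on the start and end tags, rather than on fixed positions, keeps this parser from failing when one of the three parts is missing from a generation: that part simply comes back empty.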

The model responses are *de-lexicalized*: database values in the training set have been replaced with their 
slot names to make the learning process database-agnostic. These slots are meant to be filled in later with actual 
results from a database, queried using the belief state.
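
As a rough illustration of that last step, the sketch below turns a belief state into constraints, looks them up in a toy in-memory table and fills the slots of a de-lexicalized response. The `domain slot value` belief layout, the `[slot]` placeholder syntax, the table contents and the helper names (`parse_belief`, `query`, `lexicalize`) are all assumptions made for this example; the real checkpoint's formats may differ, and a real system would call an actual service instead of a Python list.

```python
# assumed belief layout: comma-separated "domain slot value" triples
belief = "restaurants cuisine indian, restaurants city san jose"

# assumed placeholder syntax: slot names wrapped in square brackets
delex_response = "how about [restaurant_name]? it serves [cuisine] food in [city]."

# toy in-memory table standing in for a real restaurant search service
DATABASE = [
    {"restaurant_name": "curry house", "cuisine": "indian", "city": "san jose"},
    {"restaurant_name": "pasta bar", "cuisine": "italian", "city": "san jose"},
]


def parse_belief(belief_str):
    # turn "domain slot value" triples into a {slot: value} dict (domain ignored for simplicity)
    constraints = {}
    for triple in belief_str.split(","):
        parts = triple.strip().split(" ", 2)
        if len(parts) == 3:
            _domain, slot, value = parts
            constraints[slot] = value
    return constraints


def query(constraints, rows):
    # keep only the rows that match every constraint from the belief state
    return [r for r in rows if all(r.get(k) == v for k, v in constraints.items())]


def lexicalize(response, row):
    # fill each "[slot]" placeholder with the row's value; unknown slots are left as-is
    for slot, value in row.items():
        response = response.replace("[" + slot + "]", value)
    return response


matches = query(parse_belief(belief), DATABASE)
if matches:
    print(lexicalize(delex_response, matches[0]))
```

Keeping the placeholders untouched when no value is available makes missing database fields easy to spot in the final reply.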

The model can handle multiple domains: a list of example inputs is provided to get the 
conversation going.

### Dataset
The SGD dataset ([blog post](https://ai.googleblog.com/2019/10/introducing-schema-guided-dialogue.html) and 
[paper](https://arxiv.org/pdf/1909.05855.pdf)) spans many task domains. Here are some 
of the services and their descriptions from the dataset:
* **Restaurants**: A leading provider for restaurant search and reservations
* **Weather**: Check the weather for any place and any date
* **Buses**: Find a bus to take you to the city you want
* **Calendar**: Calendar service to manage personal events and reservations
* **Flights**: Find your next flight
* **Events**: Get tickets for the coolest concerts and sports in your area
* **Homes**: A widely used service for finding apartments and scheduling visits
* **Hotels**: A popular service for searching and reserving rooms in hotels
* **Media**: A leading provider of movies for searching and watching on-demand
* **Banks**: Manage bank accounts and transfer money
* **Movies**: A go-to provider for finding movies, searching for show times and booking tickets
* **Music**: A popular provider of a wide range of music content for searching and listening
* **RentalCars**: Car rental service with extensive coverage of locations and cars
"""


import gradio as gr

gr.Interface(fn=predict,
             inputs=["text", "state"],
             outputs=["chatbot", "state"],
             title="Chatting with a multi-domain task-oriented GPT-2",
             examples=examples,
             description=description,
             article=article
             ).launch()