ChrisGeishauser commited on
Commit
9008d50
·
1 Parent(s): 43fdfbe

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. config_saved.json +1 -0
  3. supervised.pol.mdl +3 -0
  4. train_INFO.log +345 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ supervised.pol.mdl filter=lfs diff=lfs merge=lfs -text
config_saved.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"args": {"seed": 0, "eval_freq": 2, "dataset_name": "multiwoz21", "model_path": "experiments/seed0/save/supervised.pol.mdl"}, "config": {"batchsz": 64, "epoch": 40, "gamma": 0.99, "policy_lr": 5e-06, "supervised_lr": 1e-05, "entropy_weight": 0.01, "value_lr": 0.0001, "save_dir": "save", "log_dir": "log", "save_per_epoch": 5000, "hidden_size": 256, "load": "save/best", "logging_mode": "INFO", "use_cer": true, "memory_size": 5000, "behaviour_cloning_weight": 0.1, "supervised_weight": 0.0, "online_offline_ratio": 0.2, "smoothed_value_function": false, "use_reservoir_sampling": false, "seed": 0, "lambda": 1, "tau": 0.001, "policy_freq": 1, "print_per_batch": 400, "c": 1.0, "rho_bar": 1, "max_length": 10, "noisy_linear": false, "dataset_name": "multiwoz21", "data_percentage": 0.01, "dialogue_order": 0, "multiwoz_like": false, "regularization_weight": 0.0, "enc_input_dim": 128, "enc_nhead": 2, "enc_d_hid": 128, "enc_nlayers": 4, "enc_dropout": 0.1, "dec_input_dim": 128, "dec_nhead": 2, "dec_d_hid": 128, "dec_nlayers": 2, "dec_dropout": 0.0, "action_embedding_dim": 128, "domain_embedding_dim": 64, "value_embedding_dim": 12, "node_embedding_dim": 128, "roberta_path": "", "node_attention": true, "semantic_descriptions": true, "freeze_roberta": true, "use_pooled": false, "mean": true, "roberta_actions": true, "independent_descriptions": true, "random_matrix": false, "distance_metric": false, "verbose": false, "ignore_features": [], "domains_removed": ["hospital", "police", "train", "hotel", "attraction", "taxi"], "only_active_values": false, "permuted_data": false, "need_weights": false, "cls_dim": 128, "independent": true, "old_critic": false, "pos_weight": 5, "weight_decay": 1e-05}, "policy_config": null}
supervised.pol.mdl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:167f64fd660907849c157f0423778600b6613ce4c6fc98247484c0e279b36206
3
+ size 9331458
train_INFO.log ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Visible device: cuda
2
+ Seed used: 0
3
+ Batch size: 64
4
+ Epochs: 40
5
+ Learning rate: 1e-05
6
+ Entropy weight: 0.01
7
+ Regularization weight: 0.0
8
+ Only use multiwoz like domains: False
9
+ We use: 1.0% of the data
10
+ Dialogue order used: 0
11
+ Vectorizer: Data set used is multiwoz21
12
+ We filter state by active domains: True
13
+ Vectorizer: Data set used is multiwoz21
14
+ Embedding semantic descriptions: True
15
+ Embedded descriptions successfully. Size: torch.Size([338, 768])
16
+ Data set used for descriptions: multiwoz21
17
+ We use Roberta to embed actions.
18
+ Loaded model from experiments/seed0/save/supervised.pol.mdl
19
+ Start training
20
+ Epoch: 0
21
+ Average actions: 1.9973957538604736
22
+ Average target actions: 2.5520834922790527
23
+ Precision: 0.09615384615384616
24
+ Recall: 0.07462686567164178
25
+ F1: 0.08403361344537816
26
+ <<dialog policy>> epoch 0: saved network to mdl
27
+ Best Precision: 0.09615384615384616
28
+ Best Recall: 0.07462686567164178
29
+ Best F1: 0.08403361344537816
30
+ Epoch: 1
31
+ Precision: 0.09615384615384616
32
+ Recall: 0.07462686567164178
33
+ F1: 0.08403361344537816
34
+ Best Precision: 0.09615384615384616
35
+ Best Recall: 0.07462686567164178
36
+ Best F1: 0.08403361344537816
37
+ Epoch: 2
38
+ Average actions: 2.3515625
39
+ Average target actions: 2.6197917461395264
40
+ Precision: 0.10526315789473684
41
+ Recall: 0.08955223880597014
42
+ F1: 0.0967741935483871
43
+ <<dialog policy>> epoch 2: saved network to mdl
44
+ Best Precision: 0.10526315789473684
45
+ Best Recall: 0.08955223880597014
46
+ Best F1: 0.0967741935483871
47
+ Epoch: 3
48
+ Precision: 0.10526315789473684
49
+ Recall: 0.08955223880597014
50
+ F1: 0.0967741935483871
51
+ Best Precision: 0.10526315789473684
52
+ Best Recall: 0.08955223880597014
53
+ Best F1: 0.0967741935483871
54
+ Epoch: 4
55
+ Average actions: 1.6770832538604736
56
+ Average target actions: 2.8567709922790527
57
+ Precision: 0.1347517730496454
58
+ Recall: 0.0945273631840796
59
+ F1: 0.11111111111111112
60
+ <<dialog policy>> epoch 4: saved network to mdl
61
+ Best Precision: 0.1347517730496454
62
+ Best Recall: 0.0945273631840796
63
+ Best F1: 0.11111111111111112
64
+ Epoch: 5
65
+ Precision: 0.1347517730496454
66
+ Recall: 0.0945273631840796
67
+ F1: 0.11111111111111112
68
+ Best Precision: 0.1347517730496454
69
+ Best Recall: 0.0945273631840796
70
+ Best F1: 0.11111111111111112
71
+ Epoch: 6
72
+ Average actions: 1.9088542461395264
73
+ Average target actions: 2.7213542461395264
74
+ Precision: 0.12080536912751678
75
+ Recall: 0.08955223880597014
76
+ F1: 0.10285714285714286
77
+ Best Precision: 0.1347517730496454
78
+ Best Recall: 0.0945273631840796
79
+ Best F1: 0.11111111111111112
80
+ Epoch: 7
81
+ Precision: 0.12080536912751678
82
+ Recall: 0.08955223880597014
83
+ F1: 0.10285714285714286
84
+ Best Precision: 0.1347517730496454
85
+ Best Recall: 0.0945273631840796
86
+ Best F1: 0.11111111111111112
87
+ Epoch: 8
88
+ Average actions: 2.0572915077209473
89
+ Average target actions: 2.8229167461395264
90
+ Precision: 0.12903225806451613
91
+ Recall: 0.09950248756218906
92
+ F1: 0.11235955056179776
93
+ <<dialog policy>> epoch 8: saved network to mdl
94
+ Best Precision: 0.1347517730496454
95
+ Best Recall: 0.09950248756218906
96
+ Best F1: 0.11235955056179776
97
+ Epoch: 9
98
+ Precision: 0.12903225806451613
99
+ Recall: 0.09950248756218906
100
+ F1: 0.11235955056179776
101
+ Best Precision: 0.1347517730496454
102
+ Best Recall: 0.09950248756218906
103
+ Best F1: 0.11235955056179776
104
+ Epoch: 10
105
+ Average actions: 2.0911459922790527
106
+ Average target actions: 2.6875
107
+ Precision: 0.11612903225806452
108
+ Recall: 0.08955223880597014
109
+ F1: 0.10112359550561797
110
+ Best Precision: 0.1347517730496454
111
+ Best Recall: 0.09950248756218906
112
+ Best F1: 0.11235955056179776
113
+ Epoch: 11
114
+ Precision: 0.11612903225806452
115
+ Recall: 0.08955223880597014
116
+ F1: 0.10112359550561797
117
+ Best Precision: 0.1347517730496454
118
+ Best Recall: 0.09950248756218906
119
+ Best F1: 0.11235955056179776
120
+ Epoch: 12
121
+ Average actions: 2.0833332538604736
122
+ Average target actions: 2.5859375
123
+ Precision: 0.11976047904191617
124
+ Recall: 0.09950248756218906
125
+ F1: 0.10869565217391305
126
+ Best Precision: 0.1347517730496454
127
+ Best Recall: 0.09950248756218906
128
+ Best F1: 0.11235955056179776
129
+ Epoch: 13
130
+ Precision: 0.11976047904191617
131
+ Recall: 0.09950248756218906
132
+ F1: 0.10869565217391305
133
+ Best Precision: 0.1347517730496454
134
+ Best Recall: 0.09950248756218906
135
+ Best F1: 0.11235955056179776
136
+ Epoch: 14
137
+ Average actions: 2.1119790077209473
138
+ Average target actions: 2.7213542461395264
139
+ Precision: 0.16778523489932887
140
+ Recall: 0.12437810945273632
141
+ F1: 0.14285714285714285
142
+ <<dialog policy>> epoch 14: saved network to mdl
143
+ Best Precision: 0.16778523489932887
144
+ Best Recall: 0.12437810945273632
145
+ Best F1: 0.14285714285714285
146
+ Epoch: 15
147
+ Precision: 0.16778523489932887
148
+ Recall: 0.12437810945273632
149
+ F1: 0.14285714285714285
150
+ Best Precision: 0.16778523489932887
151
+ Best Recall: 0.12437810945273632
152
+ Best F1: 0.14285714285714285
153
+ Epoch: 16
154
+ Average actions: 1.7994792461395264
155
+ Average target actions: 2.5520834922790527
156
+ Precision: 0.10135135135135136
157
+ Recall: 0.07462686567164178
158
+ F1: 0.08595988538681948
159
+ Best Precision: 0.16778523489932887
160
+ Best Recall: 0.12437810945273632
161
+ Best F1: 0.14285714285714285
162
+ Epoch: 17
163
+ Precision: 0.10135135135135136
164
+ Recall: 0.07462686567164178
165
+ F1: 0.08595988538681948
166
+ Best Precision: 0.16778523489932887
167
+ Best Recall: 0.12437810945273632
168
+ Best F1: 0.14285714285714285
169
+ Epoch: 18
170
+ Average actions: 2.0572915077209473
171
+ Average target actions: 2.7552084922790527
172
+ Precision: 0.13548387096774195
173
+ Recall: 0.1044776119402985
174
+ F1: 0.11797752808988765
175
+ Best Precision: 0.16778523489932887
176
+ Best Recall: 0.12437810945273632
177
+ Best F1: 0.14285714285714285
178
+ Epoch: 19
179
+ Precision: 0.13548387096774195
180
+ Recall: 0.1044776119402985
181
+ F1: 0.11797752808988765
182
+ Best Precision: 0.16778523489932887
183
+ Best Recall: 0.12437810945273632
184
+ Best F1: 0.14285714285714285
185
+ Epoch: 20
186
+ Average actions: 1.9661457538604736
187
+ Average target actions: 2.7213542461395264
188
+ Precision: 0.1118421052631579
189
+ Recall: 0.0845771144278607
190
+ F1: 0.0963172804532578
191
+ Best Precision: 0.16778523489932887
192
+ Best Recall: 0.12437810945273632
193
+ Best F1: 0.14285714285714285
194
+ Epoch: 21
195
+ Precision: 0.1118421052631579
196
+ Recall: 0.0845771144278607
197
+ F1: 0.0963172804532578
198
+ Best Precision: 0.16778523489932887
199
+ Best Recall: 0.12437810945273632
200
+ Best F1: 0.14285714285714285
201
+ Epoch: 22
202
+ Average actions: 1.9557292461395264
203
+ Average target actions: 2.5520834922790527
204
+ Precision: 0.07741935483870968
205
+ Recall: 0.05970149253731343
206
+ F1: 0.06741573033707865
207
+ Best Precision: 0.16778523489932887
208
+ Best Recall: 0.12437810945273632
209
+ Best F1: 0.14285714285714285
210
+ Epoch: 23
211
+ Precision: 0.07741935483870968
212
+ Recall: 0.05970149253731343
213
+ F1: 0.06741573033707865
214
+ Best Precision: 0.16778523489932887
215
+ Best Recall: 0.12437810945273632
216
+ Best F1: 0.14285714285714285
217
+ Epoch: 24
218
+ Average actions: 2.0833334922790527
219
+ Average target actions: 2.8229167461395264
220
+ Precision: 0.09090909090909091
221
+ Recall: 0.06965174129353234
222
+ F1: 0.07887323943661972
223
+ Best Precision: 0.16778523489932887
224
+ Best Recall: 0.12437810945273632
225
+ Best F1: 0.14285714285714285
226
+ Epoch: 25
227
+ Precision: 0.09090909090909091
228
+ Recall: 0.06965174129353234
229
+ F1: 0.07887323943661972
230
+ Best Precision: 0.16778523489932887
231
+ Best Recall: 0.12437810945273632
232
+ Best F1: 0.14285714285714285
233
+ Epoch: 26
234
+ Average actions: 1.7135417461395264
235
+ Average target actions: 2.6197917461395264
236
+ Precision: 0.145985401459854
237
+ Recall: 0.09950248756218906
238
+ F1: 0.1183431952662722
239
+ Best Precision: 0.16778523489932887
240
+ Best Recall: 0.12437810945273632
241
+ Best F1: 0.14285714285714285
242
+ Epoch: 27
243
+ Precision: 0.145985401459854
244
+ Recall: 0.09950248756218906
245
+ F1: 0.1183431952662722
246
+ Best Precision: 0.16778523489932887
247
+ Best Recall: 0.12437810945273632
248
+ Best F1: 0.14285714285714285
249
+ Epoch: 28
250
+ Average actions: 2.0364584922790527
251
+ Average target actions: 2.5520834922790527
252
+ Precision: 0.16891891891891891
253
+ Recall: 0.12437810945273632
254
+ F1: 0.14326647564469916
255
+ <<dialog policy>> epoch 28: saved network to mdl
256
+ Best Precision: 0.16891891891891891
257
+ Best Recall: 0.12437810945273632
258
+ Best F1: 0.14326647564469916
259
+ Epoch: 29
260
+ Precision: 0.16891891891891891
261
+ Recall: 0.12437810945273632
262
+ F1: 0.14326647564469916
263
+ Best Precision: 0.16891891891891891
264
+ Best Recall: 0.12437810945273632
265
+ Best F1: 0.14326647564469916
266
+ Epoch: 30
267
+ Average actions: 2.0026040077209473
268
+ Average target actions: 2.3828125
269
+ Precision: 0.16216216216216217
270
+ Recall: 0.11940298507462686
271
+ F1: 0.13753581661891118
272
+ Best Precision: 0.16891891891891891
273
+ Best Recall: 0.12437810945273632
274
+ Best F1: 0.14326647564469916
275
+ Epoch: 31
276
+ Precision: 0.16216216216216217
277
+ Recall: 0.11940298507462686
278
+ F1: 0.13753581661891118
279
+ Best Precision: 0.16891891891891891
280
+ Best Recall: 0.12437810945273632
281
+ Best F1: 0.14326647564469916
282
+ Epoch: 32
283
+ Average actions: 1.8046875
284
+ Average target actions: 2.6875
285
+ Precision: 0.12142857142857143
286
+ Recall: 0.0845771144278607
287
+ F1: 0.09970674486803519
288
+ Best Precision: 0.16891891891891891
289
+ Best Recall: 0.12437810945273632
290
+ Best F1: 0.14326647564469916
291
+ Epoch: 33
292
+ Precision: 0.12142857142857143
293
+ Recall: 0.0845771144278607
294
+ F1: 0.09970674486803519
295
+ Best Precision: 0.16891891891891891
296
+ Best Recall: 0.12437810945273632
297
+ Best F1: 0.14326647564469916
298
+ Epoch: 34
299
+ Average actions: 1.9348957538604736
300
+ Average target actions: 2.6875
301
+ Precision: 0.12162162162162163
302
+ Recall: 0.08955223880597014
303
+ F1: 0.10315186246418337
304
+ Best Precision: 0.16891891891891891
305
+ Best Recall: 0.12437810945273632
306
+ Best F1: 0.14326647564469916
307
+ Epoch: 35
308
+ Precision: 0.12162162162162163
309
+ Recall: 0.08955223880597014
310
+ F1: 0.10315186246418337
311
+ Best Precision: 0.16891891891891891
312
+ Best Recall: 0.12437810945273632
313
+ Best F1: 0.14326647564469916
314
+ Epoch: 36
315
+ Average actions: 2.0989584922790527
316
+ Average target actions: 2.484375
317
+ Precision: 0.14743589743589744
318
+ Recall: 0.11442786069651742
319
+ F1: 0.1288515406162465
320
+ Best Precision: 0.16891891891891891
321
+ Best Recall: 0.12437810945273632
322
+ Best F1: 0.14326647564469916
323
+ Epoch: 37
324
+ Precision: 0.14743589743589744
325
+ Recall: 0.11442786069651742
326
+ F1: 0.1288515406162465
327
+ Best Precision: 0.16891891891891891
328
+ Best Recall: 0.12437810945273632
329
+ Best F1: 0.14326647564469916
330
+ Epoch: 38
331
+ Average actions: 2.0260415077209473
332
+ Average target actions: 2.5520834922790527
333
+ Precision: 0.1456953642384106
334
+ Recall: 0.10945273631840796
335
+ F1: 0.12499999999999997
336
+ Best Precision: 0.16891891891891891
337
+ Best Recall: 0.12437810945273632
338
+ Best F1: 0.14326647564469916
339
+ Epoch: 39
340
+ Precision: 0.1456953642384106
341
+ Recall: 0.10945273631840796
342
+ F1: 0.12499999999999997
343
+ Best Precision: 0.16891891891891891
344
+ Best Recall: 0.12437810945273632
345
+ Best F1: 0.14326647564469916