zjowowen committed on
Commit ab6e4ac · 1 Parent(s): 3dfe8fb

init space

LightZero/.gitignore CHANGED
@@ -741,8 +741,8 @@ develop-eggs/
  downloads/
  eggs/
  .eggs/
- lib/
- lib64/
+
+
  parts/
  sdist/
  var/
@@ -982,11 +982,6 @@ dist
  ### VirtualEnv template
  # Virtualenv
  # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
- [Bb]in
- [Ii]nclude
- [Ll]ib
- [Ll]ib64
- [Ll]ocal
  pyvenv.cfg
  pip-selfcheck.json

@@ -1050,7 +1045,7 @@ Temporary Items
  *.gch

  # Libraries
- *.lib
+
  *.a
  *.la
  *.lo
LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp ADDED
@@ -0,0 +1,792 @@
1
+ // C++11
2
+
3
+ #include <iostream>
4
+ #include "cnode.h"
5
+ #include <algorithm>
6
+ #include <map>
7
+ #include <cassert>
8
+
9
+ #ifdef _WIN32
10
+ #include "..\..\common_lib\utils.cpp"
11
+ #else
12
+ #include "../../common_lib/utils.cpp"
13
+ #endif
14
+
15
+
16
+ namespace tree
17
+ {
18
+
19
+ CSearchResults::CSearchResults()
20
+ {
21
+ /*
22
+ Overview:
23
+ Initialization of CSearchResults, the default result number is set to 0.
24
+ */
25
+ this->num = 0;
26
+ }
27
+
28
+ CSearchResults::CSearchResults(int num)
29
+ {
30
+ /*
31
+ Overview:
32
+ Initialization of CSearchResults with result number.
33
+ */
34
+ this->num = num;
35
+ for (int i = 0; i < num; ++i)
36
+ {
37
+ this->search_paths.push_back(std::vector<CNode *>());
38
+ }
39
+ }
40
+
41
+ CSearchResults::~CSearchResults() {}
42
+
43
+ //*********************************************************
44
+
45
+ CNode::CNode()
46
+ {
47
+ /*
48
+ Overview:
49
+ Initialization of CNode.
50
+ */
51
+ this->prior = 0;
52
+ // NOTE: legal_actions is left empty by the default constructor; it is filled in expand().
53
+
54
+ this->is_reset = 0;
55
+ this->visit_count = 0;
56
+ this->value_sum = 0;
57
+ this->best_action = -1;
58
+ this->to_play = 0;
59
+ this->value_prefix = 0.0;
60
+ this->parent_value_prefix = 0.0;
61
+ }
62
+
63
+ CNode::CNode(float prior, std::vector<int> &legal_actions)
64
+ {
65
+ /*
66
+ Overview:
67
+ Initialization of CNode with prior value and legal actions.
68
+ Arguments:
69
+ - prior: the prior value of this node.
70
+ - legal_actions: a vector of legal actions of this node.
71
+ */
72
+ this->prior = prior;
73
+ this->legal_actions = legal_actions;
74
+
75
+ this->is_reset = 0;
76
+ this->visit_count = 0;
77
+ this->value_sum = 0;
78
+ this->best_action = -1;
79
+ this->to_play = 0;
80
+ this->value_prefix = 0.0;
81
+ this->parent_value_prefix = 0.0;
82
+ this->current_latent_state_index = -1;
83
+ this->batch_index = -1;
84
+ }
85
+
86
+ CNode::~CNode() {}
87
+
88
+ void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits)
89
+ {
90
+ /*
91
+ Overview:
92
+ Expand the child nodes of the current node.
93
+ Arguments:
94
+ - to_play: which player to play the game in the current node.
95
+ - current_latent_state_index: the x/first index of the hidden state vector of the current node, i.e. the search depth.
96
+ - batch_index: the y/second index of the hidden state vector of the current node, i.e. the index of the batch root node; its maximum is ``batch_size``/``env_num``.
97
+ - value_prefix: the value prefix of the current node.
98
+ - policy_logits: the policy logit of the child nodes.
99
+ */
100
+ this->to_play = to_play;
101
+ this->current_latent_state_index = current_latent_state_index;
102
+ this->batch_index = batch_index;
103
+ this->value_prefix = value_prefix;
104
+
105
+ int action_num = policy_logits.size();
106
+ if (this->legal_actions.size() == 0)
107
+ {
108
+ for (int i = 0; i < action_num; ++i)
109
+ {
110
+ this->legal_actions.push_back(i);
111
+ }
112
+ }
113
+ float temp_policy;
114
+ float policy_sum = 0.0;
115
+
116
+ #ifdef _WIN32
117
+ // dynamically allocate the policy array (MSVC does not support variable-length arrays)
118
+ float* policy = new float[action_num];
119
+ #else
120
+ float policy[action_num];
121
+ #endif
122
+
123
+ float policy_max = FLOAT_MIN;
124
+ for (auto a : this->legal_actions)
125
+ {
126
+ if (policy_max < policy_logits[a])
127
+ {
128
+ policy_max = policy_logits[a];
129
+ }
130
+ }
131
+
132
+ for (auto a : this->legal_actions)
133
+ {
134
+ temp_policy = exp(policy_logits[a] - policy_max);
135
+ policy_sum += temp_policy;
136
+ policy[a] = temp_policy;
137
+ }
138
+
139
+ float prior;
140
+ for (auto a : this->legal_actions)
141
+ {
142
+ prior = policy[a] / policy_sum;
143
+ std::vector<int> tmp_empty;
144
+ this->children[a] = CNode(prior, tmp_empty); // only for MuZero/EfficientZero; AlphaZero is not supported
145
+ }
146
+ #ifdef _WIN32
147
+ // free the dynamically allocated policy array
148
+ delete[] policy;
149
+ #else
150
+ #endif
151
+ }
152
+
153
+ void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
154
+ {
155
+ /*
156
+ Overview:
157
+ Add a noise to the prior of the child nodes.
158
+ Arguments:
159
+ - exploration_fraction: the fraction to add noise.
160
+ - noises: the vector of noises added to each child node.
161
+ */
162
+ float noise, prior;
163
+ for (int i = 0; i < this->legal_actions.size(); ++i)
164
+ {
165
+ noise = noises[i];
166
+ CNode *child = this->get_child(this->legal_actions[i]);
167
+
168
+ prior = child->prior;
169
+ child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
170
+ }
171
+ }
172
+
173
+ float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
174
+ {
175
+ /*
176
+ Overview:
177
+ Compute the mean q value of the current node.
178
+ Arguments:
179
+ - isRoot: whether the current node is a root node.
180
+ - parent_q: the q value of the parent node.
181
+ - discount_factor: the discount_factor of reward.
182
+ */
183
+ float total_unsigned_q = 0.0;
184
+ int total_visits = 0;
185
+ float parent_value_prefix = this->value_prefix;
186
+ for (auto a : this->legal_actions)
187
+ {
188
+ CNode *child = this->get_child(a);
189
+ if (child->visit_count > 0)
190
+ {
191
+ float true_reward = child->value_prefix - parent_value_prefix;
192
+ if (this->is_reset == 1)
193
+ {
194
+ true_reward = child->value_prefix;
195
+ }
196
+ float qsa = true_reward + discount_factor * child->value();
197
+ total_unsigned_q += qsa;
198
+ total_visits += 1;
199
+ }
200
+ }
201
+
202
+ float mean_q = 0.0;
203
+ if (isRoot && total_visits > 0)
204
+ {
205
+ mean_q = (total_unsigned_q) / (total_visits);
206
+ }
207
+ else
208
+ {
209
+ mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
210
+ }
211
+ return mean_q;
212
+ }
213
+
214
+ void CNode::print_out()
215
+ {
216
+ return;
217
+ }
218
+
219
+ int CNode::expanded()
220
+ {
221
+ /*
222
+ Overview:
223
+ Return whether the current node is expanded.
224
+ */
225
+ return this->children.size() > 0;
226
+ }
227
+
228
+ float CNode::value()
229
+ {
230
+ /*
231
+ Overview:
232
+ Return the estimated value of the current tree.
233
+ */
234
+ float true_value = 0.0;
235
+ if (this->visit_count == 0)
236
+ {
237
+ return true_value;
238
+ }
239
+ else
240
+ {
241
+ true_value = this->value_sum / this->visit_count;
242
+ return true_value;
243
+ }
244
+ }
245
+
246
+ std::vector<int> CNode::get_trajectory()
247
+ {
248
+ /*
249
+ Overview:
250
+ Find the current best trajectory starting from the current node.
251
+ Outputs:
252
+ - traj: a vector of node index, which is the current best trajectory from this node.
253
+ */
254
+ std::vector<int> traj;
255
+
256
+ CNode *node = this;
257
+ int best_action = node->best_action;
258
+ while (best_action >= 0)
259
+ {
260
+ traj.push_back(best_action);
261
+
262
+ node = node->get_child(best_action);
263
+ best_action = node->best_action;
264
+ }
265
+ return traj;
266
+ }
267
+
268
+ std::vector<int> CNode::get_children_distribution()
269
+ {
270
+ /*
271
+ Overview:
272
+ Get the distribution of child nodes in the format of visit_count.
273
+ Outputs:
274
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
275
+ */
276
+ std::vector<int> distribution;
277
+ if (this->expanded())
278
+ {
279
+ for (auto a : this->legal_actions)
280
+ {
281
+ CNode *child = this->get_child(a);
282
+ distribution.push_back(child->visit_count);
283
+ }
284
+ }
285
+ return distribution;
286
+ }
287
+
288
+ CNode *CNode::get_child(int action)
289
+ {
290
+ /*
291
+ Overview:
292
+ Get the child node corresponding to the input action.
293
+ Arguments:
294
+ - action: the action to get child.
295
+ */
296
+ return &(this->children[action]);
297
+ }
298
+
299
+ //*********************************************************
300
+
301
+ CRoots::CRoots()
302
+ {
303
+ /*
304
+ Overview:
305
+ The initialization of CRoots.
306
+ */
307
+ this->root_num = 0;
308
+ }
309
+
310
+ CRoots::CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list)
311
+ {
312
+ /*
313
+ Overview:
314
+ The initialization of CRoots with root num and legal action lists.
315
+ Arguments:
316
+ - root_num: the number of root nodes.
317
+ - legal_actions_list: the vector of legal actions for each root.
318
+ */
319
+ this->root_num = root_num;
320
+ this->legal_actions_list = legal_actions_list;
321
+
322
+ for (int i = 0; i < root_num; ++i)
323
+ {
324
+ this->roots.push_back(CNode(0, this->legal_actions_list[i]));
325
+ }
326
+ }
327
+
328
+ CRoots::~CRoots() {}
329
+
330
+ void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
331
+ {
332
+ /*
333
+ Overview:
334
+ Expand the roots and add noises.
335
+ Arguments:
336
+ - root_noise_weight: the exploration fraction of roots.
337
+ - noises: the vector of noises added to the roots.
338
+ - value_prefixs: the vector of value prefixs of each root.
339
+ - policies: the vector of policy logits of each root.
340
+ - to_play_batch: the vector of the player side of each root.
341
+ */
342
+ for (int i = 0; i < this->root_num; ++i)
343
+ {
344
+ this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]);
345
+ this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
346
+ this->roots[i].visit_count += 1;
347
+ }
348
+ }
349
+
350
+ void CRoots::prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
351
+ {
352
+ /*
353
+ Overview:
354
+ Expand the roots without noise.
355
+ Arguments:
356
+ - value_prefixs: the vector of value prefixs of each root.
357
+ - policies: the vector of policy logits of each root.
358
+ - to_play_batch: the vector of the player side of each root.
359
+ */
360
+ for (int i = 0; i < this->root_num; ++i)
361
+ {
362
+ this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]);
363
+ this->roots[i].visit_count += 1;
364
+ }
365
+ }
366
+
367
+ void CRoots::clear()
368
+ {
369
+ /*
370
+ Overview:
371
+ Clear the roots vector.
372
+ */
373
+ this->roots.clear();
374
+ }
375
+
376
+ std::vector<std::vector<int> > CRoots::get_trajectories()
377
+ {
378
+ /*
379
+ Overview:
380
+ Find the current best trajectory starting from each root.
381
+ Outputs:
382
+ - traj: a vector of node index, which is the current best trajectory from each root.
383
+ */
384
+ std::vector<std::vector<int> > trajs;
385
+ trajs.reserve(this->root_num);
386
+
387
+ for (int i = 0; i < this->root_num; ++i)
388
+ {
389
+ trajs.push_back(this->roots[i].get_trajectory());
390
+ }
391
+ return trajs;
392
+ }
393
+
394
+ std::vector<std::vector<int> > CRoots::get_distributions()
395
+ {
396
+ /*
397
+ Overview:
398
+ Get the children distribution of each root.
399
+ Outputs:
400
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
401
+ */
402
+ std::vector<std::vector<int> > distributions;
403
+ distributions.reserve(this->root_num);
404
+
405
+ for (int i = 0; i < this->root_num; ++i)
406
+ {
407
+ distributions.push_back(this->roots[i].get_children_distribution());
408
+ }
409
+ return distributions;
410
+ }
411
+
412
+ std::vector<float> CRoots::get_values()
413
+ {
414
+ /*
415
+ Overview:
416
+ Return the estimated value of each root.
417
+ */
418
+ std::vector<float> values;
419
+ for (int i = 0; i < this->root_num; ++i)
420
+ {
421
+ values.push_back(this->roots[i].value());
422
+ }
423
+ return values;
424
+ }
425
+
426
+ //*********************************************************
427
+ //
428
+ void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
429
+ {
430
+ /*
431
+ Overview:
432
+ Update the q value of the root and its child nodes.
433
+ Arguments:
434
+ - root: the root that update q value from.
435
+ - min_max_stats: a tool used to min-max normalize the q value.
436
+ - discount_factor: the discount factor of reward.
437
+ - players: the number of players.
438
+ */
439
+ std::stack<CNode *> node_stack;
440
+ node_stack.push(root);
441
+ float parent_value_prefix = 0.0;
442
+ int is_reset = 0;
443
+ while (node_stack.size() > 0)
444
+ {
445
+ CNode *node = node_stack.top();
446
+ node_stack.pop();
447
+
448
+ if (node != root)
449
+ {
450
+ // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
451
+ // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
452
+ // true_reward = node.value_prefix - (- parent_value_prefix)
453
+ float true_reward = node->value_prefix - node->parent_value_prefix;
454
+
455
+ if (is_reset == 1)
456
+ {
457
+ true_reward = node->value_prefix;
458
+ }
459
+ float qsa;
460
+ if (players == 1)
461
+ {
462
+ qsa = true_reward + discount_factor * node->value();
463
+ }
464
+ else if (players == 2)
465
+ {
466
+ // TODO(pu): why is only the last reward multiplied by the discount_factor?
467
+ qsa = true_reward + discount_factor * (-1) * node->value();
468
+ }
469
+
470
+ min_max_stats.update(qsa);
471
+ }
472
+
473
+ for (auto a : node->legal_actions)
474
+ {
475
+ CNode *child = node->get_child(a);
476
+ if (child->expanded())
477
+ {
478
+ child->parent_value_prefix = node->value_prefix;
479
+ node_stack.push(child);
480
+ }
481
+ }
482
+
483
+ is_reset = node->is_reset;
484
+ }
485
+ }
486
+
487
+ void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
488
+ {
489
+ /*
490
+ Overview:
491
+ Update the value sum and visit count of nodes along the search path.
492
+ Arguments:
493
+ - search_path: a vector of nodes on the search path.
494
+ - min_max_stats: a tool used to min-max normalize the q value.
495
+ - to_play: which player to play the game in the current node.
496
+ - value: the value to propagate along the search path.
497
+ - discount_factor: the discount factor of reward.
498
+ */
499
+ assert(to_play == -1 || to_play == 1 || to_play == 2);
500
+ if (to_play == -1)
501
+ {
502
+ // for play-with-bot-mode
503
+ float bootstrap_value = value;
504
+ int path_len = search_path.size();
505
+ for (int i = path_len - 1; i >= 0; --i)
506
+ {
507
+ CNode *node = search_path[i];
508
+ node->value_sum += bootstrap_value;
509
+ node->visit_count += 1;
510
+
511
+ float parent_value_prefix = 0.0;
512
+ int is_reset = 0;
513
+ if (i >= 1)
514
+ {
515
+ CNode *parent = search_path[i - 1];
516
+ parent_value_prefix = parent->value_prefix;
517
+ is_reset = parent->is_reset;
518
+ }
519
+
520
+ float true_reward = node->value_prefix - parent_value_prefix;
521
+ min_max_stats.update(true_reward + discount_factor * node->value());
522
+
523
+ if (is_reset == 1)
524
+ {
525
+ // parent is reset
526
+ true_reward = node->value_prefix;
527
+ }
528
+
529
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
530
+ }
531
+ }
532
+ else
533
+ {
534
+ // for self-play-mode
535
+ float bootstrap_value = value;
536
+ int path_len = search_path.size();
537
+ for (int i = path_len - 1; i >= 0; --i)
538
+ {
539
+ CNode *node = search_path[i];
540
+ if (node->to_play == to_play)
541
+ {
542
+ node->value_sum += bootstrap_value;
543
+ }
544
+ else
545
+ {
546
+ node->value_sum += -bootstrap_value;
547
+ }
548
+ node->visit_count += 1;
549
+
550
+ float parent_value_prefix = 0.0;
551
+ int is_reset = 0;
552
+ if (i >= 1)
553
+ {
554
+ CNode *parent = search_path[i - 1];
555
+ parent_value_prefix = parent->value_prefix;
556
+ is_reset = parent->is_reset;
557
+ }
558
+
559
+ // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
560
+ // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
561
+ float true_reward = node->value_prefix - parent_value_prefix;
562
+
563
+ min_max_stats.update(true_reward + discount_factor * node->value());
564
+
565
+ if (is_reset == 1)
566
+ {
567
+ // parent is reset
568
+ true_reward = node->value_prefix;
569
+ }
570
+ if (node->to_play == to_play)
571
+ {
572
+ bootstrap_value = -true_reward + discount_factor * bootstrap_value;
573
+ }
574
+ else
575
+ {
576
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
577
+ }
578
+ }
579
+ }
580
+ }
581
+
582
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch)
583
+ {
584
+ /*
585
+ Overview:
586
+ Expand the nodes along the search path and update the infos.
587
+ Arguments:
588
+ - current_latent_state_index: The index of latent state of the leaf node in the search path.
589
+ - discount_factor: the discount factor of reward.
590
+ - value_prefixs: the value prefixs of nodes along the search path.
591
+ - values: the values to propagate along the search path.
592
+ - policies: the policy logits of nodes along the search path.
593
+ - min_max_stats: a tool used to min-max normalize the q value.
594
+ - results: the search results.
595
+ - is_reset_list: the vector of is_reset flags of the nodes along the search path, where is_reset indicates whether the parent value prefix needs to be reset.
596
+ - to_play_batch: the batch of which player is playing on this node.
597
+ */
598
+ for (int i = 0; i < results.num; ++i)
599
+ {
600
+ results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]);
601
+ // reset
602
+ results.nodes[i]->is_reset = is_reset_list[i];
603
+
604
+ cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
605
+ }
606
+ }
607
+
608
+ int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
609
+ {
610
+ /*
611
+ Overview:
612
+ Select the child node of the roots according to ucb scores.
613
+ Arguments:
614
+ - root: the roots to select the child node.
615
+ - min_max_stats: a tool used to min-max normalize the score.
616
+ - pb_c_base: the constant c2 in MuZero's UCB formula.
617
+ - pb_c_init: the constant c1 in MuZero's UCB formula.
618
+ - discount_factor: the discount factor of reward.
619
+ - mean_q: the mean q value of the parent node.
620
+ - players: the number of players.
621
+ Outputs:
622
+ - action: the action to select.
623
+ */
624
+ float max_score = FLOAT_MIN;
625
+ const float epsilon = 0.000001;
626
+ std::vector<int> max_index_lst;
627
+ for (auto a : root->legal_actions)
628
+ {
629
+ CNode *child = root->get_child(a);
630
+ float temp_score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players);
631
+
632
+ if (max_score < temp_score)
633
+ {
634
+ max_score = temp_score;
635
+
636
+ max_index_lst.clear();
637
+ max_index_lst.push_back(a);
638
+ }
639
+ else if (temp_score >= max_score - epsilon)
640
+ {
641
+ max_index_lst.push_back(a);
642
+ }
643
+ }
644
+
645
+ int action = 0;
646
+ if (max_index_lst.size() > 0)
647
+ {
648
+ int rand_index = rand() % max_index_lst.size();
649
+ action = max_index_lst[rand_index];
650
+ }
651
+ return action;
652
+ }
653
+
654
+ float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players)
655
+ {
656
+ /*
657
+ Overview:
658
+ Compute the ucb score of the child.
659
+ Arguments:
660
+ - child: the child node to compute ucb score.
661
+ - min_max_stats: a tool used to min-max normalize the score.
662
+ - parent_mean_q: the mean q value of the parent node.
663
+ - is_reset: whether the value prefix needs to be reset.
664
+ - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
665
+ - parent_value_prefix: the value prefix of parent node.
666
+ - pb_c_base: the constant c2 in MuZero's UCB formula.
667
+ - pb_c_init: the constant c1 in MuZero's UCB formula.
668
+ - discount_factor: the discount factor of reward.
669
+ - players: the number of players.
670
+ Outputs:
671
+ - ucb_value: the ucb score of the child.
672
+ */
673
+ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0;
674
+ pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
675
+ pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));
676
+
677
+ prior_score = pb_c * child->prior;
678
+ if (child->visit_count == 0)
679
+ {
680
+ value_score = parent_mean_q;
681
+ }
682
+ else
683
+ {
684
+ float true_reward = child->value_prefix - parent_value_prefix;
685
+ if (is_reset == 1)
686
+ {
687
+ true_reward = child->value_prefix;
688
+ }
689
+
690
+ if (players == 1)
691
+ {
692
+ value_score = true_reward + discount_factor * child->value();
693
+ }
694
+ else if (players == 2)
695
+ {
696
+ value_score = true_reward + discount_factor * (-child->value());
697
+ }
698
+ }
699
+
700
+ value_score = min_max_stats.normalize(value_score);
701
+
702
+ if (value_score < 0)
703
+ {
704
+ value_score = 0;
705
+ }
706
+ else if (value_score > 1)
707
+ {
708
+ value_score = 1;
709
+ }
710
+
711
+ return prior_score + value_score; // ucb_value
712
+ }
713
+
714
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
715
+ {
716
+ /*
717
+ Overview:
718
+ Search node path from the roots.
719
+ Arguments:
720
+ - roots: the roots that search from.
721
+ - pb_c_base: the constant c2 in MuZero's UCB formula.
722
+ - pb_c_init: the constant c1 in MuZero's UCB formula.
723
+ - discount_factor: the discount factor of reward.
724
+ - min_max_stats: a tool used to min-max normalize the score.
725
+ - results: the search results.
726
+ - virtual_to_play_batch: the batch of which player is playing on this node.
727
+ */
728
+ // set seed
729
+ get_time_and_set_rand_seed();
730
+
731
+ int last_action = -1;
732
+ float parent_q = 0.0;
733
+ results.search_lens = std::vector<int>();
734
+
735
+ int players = 0;
736
+ int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // -1 in play-with-bot-mode, otherwise 1 or 2 in self-play-mode
737
+ if (largest_element == -1)
738
+ {
739
+ players = 1;
740
+ }
741
+ else
742
+ {
743
+ players = 2;
744
+ }
745
+
746
+ for (int i = 0; i < results.num; ++i)
747
+ {
748
+ CNode *node = &(roots->roots[i]);
749
+ int is_root = 1;
750
+ int search_len = 0;
751
+ results.search_paths[i].push_back(node);
752
+
753
+ while (node->expanded())
754
+ {
755
+ float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
756
+ is_root = 0;
757
+ parent_q = mean_q;
758
+
759
+ int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players);
760
+ if (players > 1)
761
+ {
762
+ assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
763
+ if (virtual_to_play_batch[i] == 1)
764
+ {
765
+ virtual_to_play_batch[i] = 2;
766
+ }
767
+ else
768
+ {
769
+ virtual_to_play_batch[i] = 1;
770
+ }
771
+ }
772
+
773
+ node->best_action = action;
774
+ // next
775
+ node = node->get_child(action);
776
+ last_action = action;
777
+ results.search_paths[i].push_back(node);
778
+ search_len += 1;
779
+ }
780
+
781
+ CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];
782
+
783
+ results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
784
+ results.latent_state_index_in_batch.push_back(parent->batch_index);
785
+
786
+ results.last_actions.push_back(last_action);
787
+ results.search_lens.push_back(search_len);
788
+ results.nodes.push_back(node);
789
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
790
+ }
791
+ }
792
+ }
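Editor's note: the file above exposes a batched MCTS interface that LightZero normally drives from Python through Cython bindings. Below is a minimal, illustrative C++ sketch of one simulation step over a batch of trees, using only the functions declared in cnode.h; evaluate_batch() is a hypothetical stub standing in for the neural-network inference that happens elsewhere, and the zero-filled outputs and is_reset handling are placeholder assumptions, not part of the commit.

// Illustrative sketch only, not part of the commit.
#include <vector>
#include "cnode.h"

// Hypothetical stand-in for model inference on the selected leaves.
struct EvalOut {
    std::vector<float> value_prefixs, values;
    std::vector<std::vector<float> > policies;
};
static EvalOut evaluate_batch(const tree::CSearchResults &results, int action_num) {
    EvalOut out;
    out.value_prefixs.assign(results.num, 0.0f);   // placeholder outputs
    out.values.assign(results.num, 0.0f);
    out.policies.assign(results.num, std::vector<float>(action_num, 0.0f));
    return out;
}

// One simulation over a batch of trees: select leaves, evaluate, expand, back up.
static void run_one_simulation(tree::CRoots &roots, tools::CMinMaxStatsList &stats,
                               std::vector<int> &to_play_batch, int simulation_index,
                               int action_num, int pb_c_base, float pb_c_init, float discount) {
    tree::CSearchResults results(roots.root_num);
    std::vector<int> virtual_to_play(to_play_batch);   // traversal flips players in place
    tree::cbatch_traverse(&roots, pb_c_base, pb_c_init, discount, &stats, results, virtual_to_play);
    EvalOut out = evaluate_batch(results, action_num);
    std::vector<int> is_reset_list(results.num, 0);    // assume no LSTM reset in this sketch
    tree::cbatch_backpropagate(simulation_index, discount, out.value_prefixs, out.values,
                               out.policies, &stats, results, is_reset_list, virtual_to_play);
}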
LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.h ADDED
@@ -0,0 +1,91 @@
1
+ // C++11
2
+
3
+ #ifndef CNODE_H
4
+ #define CNODE_H
5
+
6
+ #include "../../common_lib/cminimax.h"
7
+ #include <math.h>
8
+ #include <vector>
9
+ #include <stack>
10
+ #include <stdlib.h>
11
+ #include <time.h>
12
+ #include <cmath>
13
+ #include <sys/timeb.h>
14
+ #include <time.h>
15
+ #include <map>
16
+
17
+ const int DEBUG_MODE = 0;
18
+
19
+ namespace tree {
20
+ class CNode {
21
+ public:
22
+ int visit_count, to_play, current_latent_state_index, batch_index, best_action, is_reset;
23
+ float value_prefix, prior, value_sum;
24
+ float parent_value_prefix;
25
+ std::vector<int> children_index;
26
+ std::map<int, CNode> children;
27
+
28
+ std::vector<int> legal_actions;
29
+
30
+ CNode();
31
+ CNode(float prior, std::vector<int> &legal_actions);
32
+ ~CNode();
33
+
34
+ void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits);
35
+ void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
36
+ float compute_mean_q(int isRoot, float parent_q, float discount_factor);
37
+ void print_out();
38
+
39
+ int expanded();
40
+
41
+ float value();
42
+
43
+ std::vector<int> get_trajectory();
44
+ std::vector<int> get_children_distribution();
45
+ CNode* get_child(int action);
46
+ };
47
+
48
+ class CRoots{
49
+ public:
50
+ int root_num;
51
+ std::vector<CNode> roots;
52
+ std::vector<std::vector<int> > legal_actions_list;
53
+
54
+ CRoots();
55
+ CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list);
56
+ ~CRoots();
57
+
58
+ void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
59
+ void prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
60
+ void clear();
61
+ std::vector<std::vector<int> > get_trajectories();
62
+ std::vector<std::vector<int> > get_distributions();
63
+ std::vector<float> get_values();
64
+ CNode* get_root(int index);
65
+ };
66
+
67
+ class CSearchResults{
68
+ public:
69
+ int num;
70
+ std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens;
71
+ std::vector<int> virtual_to_play_batchs;
72
+ std::vector<CNode*> nodes;
73
+ std::vector<std::vector<CNode*> > search_paths;
74
+
75
+ CSearchResults();
76
+ CSearchResults(int num);
77
+ ~CSearchResults();
78
+
79
+ };
80
+
81
+
82
+ //*********************************************************
83
+ void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
84
+ void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
85
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch);
86
+ int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players);
87
+ float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players);
88
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
89
+ }
90
+
91
+ #endif
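Editor's note: for readers cross-checking this header against the implementation above, the score computed by cucb_score is the standard MuZero PUCT rule. A minimal standalone restatement, assuming the value term has already been min-max normalized and clipped to [0, 1] as the implementation does, is:

// Standalone restatement of the PUCT score used by cucb_score; illustrative only.
#include <cmath>

float puct_score(float prior, int child_visit_count, float parent_visit_count,
                 float normalized_value, float pb_c_base, float pb_c_init) {
    // pb_c grows slowly with the parent visit count (the c1/c2 schedule from MuZero).
    float pb_c = std::log((parent_visit_count + pb_c_base + 1) / pb_c_base) + pb_c_init;
    pb_c *= std::sqrt(parent_visit_count) / (child_visit_count + 1);
    return pb_c * prior + normalized_value;  // prior_score + value_score
}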
LightZero/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.cpp ADDED
@@ -0,0 +1,1154 @@
1
+ // C++11
2
+
3
+ #include <iostream>
4
+ #include "cnode.h"
5
+ #include <algorithm>
6
+ #include <map>
7
+ #include <cmath>
8
+ #include <random>
9
+ #include <numeric>
10
+
11
+ #ifdef _WIN32
12
+ #include "..\..\common_lib\utils.cpp"
13
+ #else
14
+ #include "../../common_lib/utils.cpp"
15
+ #endif
16
+
17
+ namespace tree{
18
+
19
+ CSearchResults::CSearchResults()
20
+ {
21
+ /*
22
+ Overview:
23
+ Initialization of CSearchResults, the default result number is set to 0.
24
+ */
25
+ this->num = 0;
26
+ }
27
+
28
+ CSearchResults::CSearchResults(int num)
29
+ {
30
+ /*
31
+ Overview:
32
+ Initialization of CSearchResults with result number.
33
+ */
34
+ this->num = num;
35
+ for (int i = 0; i < num; ++i)
36
+ {
37
+ this->search_paths.push_back(std::vector<CNode *>());
38
+ }
39
+ }
40
+
41
+ CSearchResults::~CSearchResults(){}
42
+
43
+ //*********************************************************
44
+
45
+ CNode::CNode()
46
+ {
47
+ /*
48
+ Overview:
49
+ Initialization of CNode.
50
+ */
51
+ this->prior = 0;
52
+ // NOTE: legal_actions is left empty by the default constructor; it is filled in expand().
53
+
54
+ this->visit_count = 0;
55
+ this->value_sum = 0;
56
+ this->raw_value = 0; // the value network approximation of value
57
+ this->best_action = -1;
58
+ this->to_play = 0;
59
+ this->reward = 0.0;
60
+
61
+ // gumbel muzero related code
62
+ this->gumbel_scale = 10.0;
63
+ this->gumbel_rng=0.0;
64
+ }
65
+
66
+ CNode::CNode(float prior, std::vector<int> &legal_actions)
67
+ {
68
+ /*
69
+ Overview:
70
+ Initialization of CNode with prior value and legal actions.
71
+ Arguments:
72
+ - prior: the prior value of this node.
73
+ - legal_actions: a vector of legal actions of this node.
74
+ */
75
+ this->prior = prior;
76
+ this->legal_actions = legal_actions;
77
+
78
+ this->visit_count = 0;
79
+ this->value_sum = 0;
80
+ this->raw_value = 0; // the value network approximation of value
81
+ this->best_action = -1;
82
+ this->to_play = 0;
83
+ this->current_latent_state_index = -1;
84
+ this->batch_index = -1;
85
+
86
+ // gumbel muzero related code
87
+ this->gumbel_scale = 10.0;
88
+ this->gumbel_rng=0.0;
89
+ this->gumbel = generate_gumbel(this->gumbel_scale, this->gumbel_rng, legal_actions.size());
90
+ }
91
+
92
+ CNode::~CNode(){}
93
+
94
+ void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float reward, float value, const std::vector<float> &policy_logits)
95
+ {
96
+ /*
97
+ Overview:
98
+ Expand the child nodes of the current node.
99
+ Arguments:
100
+ - to_play: which player to play the game in the current node.
101
+ - current_latent_state_index: the index of the latent state of the leaf node in the search path of the current node.
102
+ - batch_index: the index of the batch root node that this node belongs to (the y/second index of the hidden state vector).
103
+ - reward: the reward of the current node.
104
+ - value: the value network approximation of current node.
105
+ - policy_logits: the logit of the child nodes.
106
+ */
107
+ this->to_play = to_play;
108
+ this->current_latent_state_index = current_latent_state_index;
109
+ this->batch_index = batch_index;
110
+ this->reward = reward;
111
+ this->raw_value = value;
112
+
113
+ int action_num = policy_logits.size();
114
+ if (this->legal_actions.size() == 0)
115
+ {
116
+ for (int i = 0; i < action_num; ++i)
117
+ {
118
+ this->legal_actions.push_back(i);
119
+ }
120
+ }
121
+ float temp_policy;
122
+ float policy_sum = 0.0;
123
+
124
+ #ifdef _WIN32
125
+ // dynamically allocate the policy array (MSVC does not support variable-length arrays)
126
+ float* policy = new float[action_num];
127
+ #else
128
+ float policy[action_num];
129
+ #endif
130
+
131
+ float policy_max = FLOAT_MIN;
132
+ for(auto a: this->legal_actions){
133
+ if(policy_max < policy_logits[a]){
134
+ policy_max = policy_logits[a];
135
+ }
136
+ }
137
+
138
+ for(auto a: this->legal_actions){
139
+ temp_policy = exp(policy_logits[a] - policy_max);
140
+ policy_sum += temp_policy;
141
+ policy[a] = temp_policy;
142
+ }
143
+
144
+ float prior;
145
+ for(auto a: this->legal_actions){
146
+ prior = policy[a] / policy_sum;
147
+ std::vector<int> tmp_empty;
148
+ this->children[a] = CNode(prior, tmp_empty); // only for MuZero/EfficientZero; AlphaZero is not supported
149
+ }
150
+
151
+ #ifdef _WIN32
152
+ // free the dynamically allocated policy array
153
+ delete[] policy;
154
+ #else
155
+ #endif
156
+ }
157
+
158
+ void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
159
+ {
160
+ /*
161
+ Overview:
162
+ Add a noise to the prior of the child nodes.
163
+ Arguments:
164
+ - exploration_fraction: the fraction to add noise.
165
+ - noises: the vector of noises added to each child node.
166
+ */
167
+ float noise, prior;
168
+ for(int i =0; i<this->legal_actions.size(); ++i){
169
+ noise = noises[i];
170
+ CNode* child = this->get_child(this->legal_actions[i]);
171
+
172
+ prior = child->prior;
173
+ child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
174
+ }
175
+ }
176
+
177
+ //*********************************************************
178
+ // Gumbel Muzero related code
179
+ //*********************************************************
180
+
181
+ std::vector<float> CNode::get_q(float discount_factor)
182
+ {
183
+ /*
184
+ Overview:
185
+ Compute the q value of the current node.
186
+ Arguments:
187
+ - discount_factor: the discount_factor of reward.
188
+ */
189
+ std::vector<float> child_value;
190
+ for(auto a: this->legal_actions){
191
+ CNode* child = this->get_child(a);
192
+ float true_reward = child->reward;
193
+ float qsa = true_reward + discount_factor * child->value();
194
+ child_value.push_back(qsa);
195
+ }
196
+ return child_value;
197
+ }
198
+
199
+ float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
200
+ {
201
+ /*
202
+ Overview:
203
+ Compute the mean q value of the current node.
204
+ Arguments:
205
+ - isRoot: whether the current node is a root node.
206
+ - parent_q: the q value of the parent node.
207
+ - discount_factor: the discount_factor of reward.
208
+ */
209
+ float total_unsigned_q = 0.0;
210
+ int total_visits = 0;
211
+ for(auto a: this->legal_actions){
212
+ CNode* child = this->get_child(a);
213
+ if(child->visit_count > 0){
214
+ float true_reward = child->reward;
215
+ float qsa = true_reward + discount_factor * child->value();
216
+ total_unsigned_q += qsa;
217
+ total_visits += 1;
218
+ }
219
+ }
220
+
221
+ float mean_q = 0.0;
222
+ if(isRoot && total_visits > 0){
223
+ mean_q = (total_unsigned_q) / (total_visits);
224
+ }
225
+ else{
226
+ mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
227
+ }
228
+ return mean_q;
229
+ }
230
+
231
+ void CNode::print_out()
232
+ {
233
+ return;
234
+ }
235
+
236
+ int CNode::expanded()
237
+ {
238
+ /*
239
+ Overview:
240
+ Return whether the current node is expanded.
241
+ */
242
+ return this->children.size() > 0;
243
+ }
244
+
245
+ float CNode::value()
246
+ {
247
+ /*
248
+ Overview:
249
+ Return the estimated value of the current tree.
250
+ */
251
+ float true_value = 0.0;
252
+ if (this->visit_count == 0)
253
+ {
254
+ return true_value;
255
+ }
256
+ else
257
+ {
258
+ true_value = this->value_sum / this->visit_count;
259
+ return true_value;
260
+ }
261
+ }
262
+
263
+ std::vector<int> CNode::get_trajectory()
264
+ {
265
+ /*
266
+ Overview:
267
+ Find the current best trajectory starting from the current node.
268
+ Outputs:
269
+ - traj: a vector of node index, which is the current best trajectory from this node.
270
+ */
271
+ std::vector<int> traj;
272
+
273
+ CNode *node = this;
274
+ int best_action = node->best_action;
275
+ while (best_action >= 0)
276
+ {
277
+ traj.push_back(best_action);
278
+
279
+ node = node->get_child(best_action);
280
+ best_action = node->best_action;
281
+ }
282
+ return traj;
283
+ }
284
+
285
+ std::vector<int> CNode::get_children_distribution()
286
+ {
287
+ /*
288
+ Overview:
289
+ Get the distribution of child nodes in the format of visit_count.
290
+ Outputs:
291
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
292
+ */
293
+ std::vector<int> distribution;
294
+ if (this->expanded())
295
+ {
296
+ for (auto a : this->legal_actions)
297
+ {
298
+ CNode *child = this->get_child(a);
299
+ distribution.push_back(child->visit_count);
300
+ }
301
+ }
302
+ return distribution;
303
+ }
304
+
305
+ //*********************************************************
306
+ // Gumbel Muzero related code
307
+ //*********************************************************
308
+
309
+ std::vector<float> CNode::get_children_value(float discount_factor, int action_space_size)
310
+ {
311
+ /*
312
+ Overview:
313
+ Get the completed value of child nodes.
314
+ Arguments:
315
+ - discount_factor: the discount_factor of reward.
316
+ - action_space_size: the size of action space.
317
+ */
318
+ float infymin = -std::numeric_limits<float>::infinity();
319
+ std::vector<int> child_visit_count;
320
+ std::vector<float> child_prior;
321
+ for(auto a: this->legal_actions){
322
+ CNode* child = this->get_child(a);
323
+ child_visit_count.push_back(child->visit_count);
324
+ child_prior.push_back(child->prior);
325
+ }
326
+ assert(child_visit_count.size()==child_prior.size());
327
+ // compute the completed value
328
+ std::vector<float> completed_qvalues = qtransform_completed_by_mix_value(this, child_visit_count, child_prior, discount_factor);
329
+ std::vector<float> values;
330
+ for (int i=0;i<action_space_size;i++){
331
+ values.push_back(infymin);
332
+ }
333
+ for (int i=0;i<child_prior.size();i++){
334
+ values[this->legal_actions[i]] = completed_qvalues[i];
335
+ }
336
+
337
+ return values;
338
+ }
339
+
340
+ CNode *CNode::get_child(int action)
341
+ {
342
+ /*
343
+ Overview:
344
+ Get the child node corresponding to the input action.
345
+ Arguments:
346
+ - action: the action to get child.
347
+ */
348
+ return &(this->children[action]);
349
+ }
350
+
351
+ //*********************************************************
352
+ // Gumbel Muzero related code
353
+ //*********************************************************
354
+
355
+ std::vector<float> CNode::get_policy(float discount_factor, int action_space_size){
356
+ /*
357
+ Overview:
358
+ Compute the improved policy of the current node.
359
+ Arguments:
360
+ - discount_factor: the discount_factor of reward.
361
+ - action_space_size: the action space size of environment.
362
+ */
363
+ float infymin = -std::numeric_limits<float>::infinity();
364
+ std::vector<int> child_visit_count;
365
+ std::vector<float> child_prior;
366
+ for(auto a: this->legal_actions){
367
+ CNode* child = this->get_child(a);
368
+ child_visit_count.push_back(child->visit_count);
369
+ child_prior.push_back(child->prior);
370
+ }
371
+ assert(child_visit_count.size()==child_prior.size());
372
+ // compute the completed value
373
+ std::vector<float> completed_qvalues = qtransform_completed_by_mix_value(this, child_visit_count, child_prior, discount_factor);
374
+ std::vector<float> probs;
375
+ for (int i=0;i<action_space_size;i++){
376
+ probs.push_back(infymin);
377
+ }
378
+ for (int i=0;i<child_prior.size();i++){
379
+ probs[this->legal_actions[i]] = child_prior[i] + completed_qvalues[i];
380
+ }
381
+
382
+ csoftmax(probs, probs.size());
383
+
384
+ return probs;
385
+ }
386
+
387
+ //*********************************************************
388
+
389
+ CRoots::CRoots()
390
+ {
391
+ /*
392
+ Overview:
393
+ The initialization of CRoots.
394
+ */
395
+ this->root_num = 0;
396
+ }
397
+
398
+ CRoots::CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list)
399
+ {
400
+ /*
401
+ Overview:
402
+ The initialization of CRoots with root num and legal action lists.
403
+ Arguments:
404
+ - root_num: the number of root nodes.
405
+ - legal_actions_list: the vector of legal actions for each root.
406
+ */
407
+ this->root_num = root_num;
408
+ this->legal_actions_list = legal_actions_list;
409
+
410
+ for (int i = 0; i < root_num; ++i)
411
+ {
412
+ this->roots.push_back(CNode(0, this->legal_actions_list[i]));
413
+ }
414
+ }
415
+
416
+ CRoots::~CRoots() {}
417
+
418
+ void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
419
+ {
420
+ /*
421
+ Overview:
422
+ Expand the roots and add noises.
423
+ Arguments:
424
+ - root_noise_weight: the exploration fraction of roots.
425
+ - noises: the vector of noises added to the roots.
426
+ - rewards: the vector of rewards of each root.
427
+ - values: the vector of values of each root.
428
+ - policies: the vector of policy logits of each root.
429
+ - to_play_batch: the vector of the player side of each root.
430
+ */
431
+ for(int i = 0; i < this->root_num; ++i){
432
+ this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], values[i], policies[i]);
433
+ this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
434
+
435
+ this->roots[i].visit_count += 1;
436
+ }
437
+ }
438
+
439
+ void CRoots::prepare_no_noise(const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
440
+ {
441
+ /*
442
+ Overview:
443
+ Expand the roots without noise.
444
+ Arguments:
445
+ - rewards: the vector of rewards of each root.
446
+ - values: the vector of values of each root.
447
+ - policies: the vector of policy logits of each root.
448
+ - to_play_batch: the vector of the player side of each root.
449
+ */
450
+ for(int i = 0; i < this->root_num; ++i){
451
+ this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], values[i], policies[i]);
452
+
453
+ this->roots[i].visit_count += 1;
454
+ }
455
+ }
456
+
457
+ void CRoots::clear()
458
+ {
459
+ /*
460
+ Overview:
461
+ Clear the roots vector.
462
+ */
463
+ this->roots.clear();
464
+ }
465
+
466
+ std::vector<std::vector<int> > CRoots::get_trajectories()
467
+ {
468
+ /*
469
+ Overview:
470
+ Find the current best trajectory starting from each root.
471
+ Outputs:
472
+ - traj: a vector of node index, which is the current best trajectory from each root.
473
+ */
474
+ std::vector<std::vector<int> > trajs;
475
+ trajs.reserve(this->root_num);
476
+
477
+ for (int i = 0; i < this->root_num; ++i)
478
+ {
479
+ trajs.push_back(this->roots[i].get_trajectory());
480
+ }
481
+ return trajs;
482
+ }
483
+
484
+ std::vector<std::vector<int> > CRoots::get_distributions()
485
+ {
486
+ /*
487
+ Overview:
488
+ Get the children distribution of each root.
489
+ Outputs:
490
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
491
+ */
492
+ std::vector<std::vector<int> > distributions;
493
+ distributions.reserve(this->root_num);
494
+
495
+ for (int i = 0; i < this->root_num; ++i)
496
+ {
497
+ distributions.push_back(this->roots[i].get_children_distribution());
498
+ }
499
+ return distributions;
500
+ }
501
+
502
+ //*********************************************************
503
+ // Gumbel Muzero related code
504
+ //*********************************************************
505
+
506
+ std::vector<std::vector<float> > CRoots::get_children_values(float discount_factor, int action_space_size)
507
+ {
508
+ /*
509
+ Overview:
510
+ Compute the completed value of each root.
511
+ Arguments:
512
+ - discount_factor: the discount_factor of reward.
513
+ - action_space_size: the action space size of environment.
514
+ */
515
+ std::vector<std::vector<float> > values;
516
+ values.reserve(this->root_num);
517
+
518
+ for (int i = 0; i < this->root_num; ++i)
519
+ {
520
+ values.push_back(this->roots[i].get_children_value(discount_factor, action_space_size));
521
+ }
522
+ return values;
523
+ }
524
+
525
+ std::vector<std::vector<float> > CRoots::get_policies(float discount_factor, int action_space_size)
526
+ {
527
+ /*
528
+ Overview:
529
+ Compute the improved policy of each root.
530
+ Arguments:
531
+ - discount_factor: the discount_factor of reward.
532
+ - action_space_size: the action space size of environment.
533
+ */
534
+ std::vector<std::vector<float> > probs;
535
+ probs.reserve(this->root_num);
536
+
537
+ for(int i = 0; i < this->root_num; ++i){
538
+ probs.push_back(this->roots[i].get_policy(discount_factor, action_space_size));
539
+ }
540
+ return probs;
541
+ }
542
+
543
+ std::vector<float> CRoots::get_values()
544
+ {
545
+ /*
546
+ Overview:
547
+ Return the estimated value of each root.
548
+ */
549
+ std::vector<float> values;
550
+ for (int i = 0; i < this->root_num; ++i)
551
+ {
552
+ values.push_back(this->roots[i].value());
553
+ }
554
+ return values;
555
+ }
556
+
557
+ //*********************************************************
558
+ //
559
+ void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
560
+ {
561
+ /*
562
+ Overview:
563
+ Update the q value of the root and its child nodes.
564
+ Arguments:
565
+ - root: the root that update q value from.
566
+ - min_max_stats: a tool used to min-max normalize the q value.
567
+ - discount_factor: the discount factor of reward.
568
+ - players: the number of players.
569
+ */
570
+ std::stack<CNode*> node_stack;
571
+ node_stack.push(root);
572
+ // float parent_value_prefix = 0.0;
573
+ while(node_stack.size() > 0){
574
+ CNode* node = node_stack.top();
575
+ node_stack.pop();
576
+
577
+ if(node != root){
578
+ // # NOTE: in 2 player mode, value_prefix is not calculated according to the perspective of current player of node,
579
+ // # but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
580
+ // # true_reward = node.value_prefix - (- parent_value_prefix)
581
+ // float true_reward = node->value_prefix - node->parent_value_prefix;
582
+ float true_reward = node->reward;
583
+
584
+ float qsa;
585
+ if(players == 1)
586
+ qsa = true_reward + discount_factor * node->value();
587
+ else if(players == 2)
588
+ // TODO(pu):
589
+ qsa = true_reward + discount_factor * (-1) * node->value();
590
+
591
+ min_max_stats.update(qsa);
592
+ }
593
+
594
+ for(auto a: node->legal_actions){
595
+ CNode* child = node->get_child(a);
596
+ if(child->expanded()){
597
+ // child->parent_value_prefix = node->value_prefix;
598
+ node_stack.push(child);
599
+ }
600
+ }
601
+
602
+ }
603
+ }
604
+
605
+ void cback_propagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
606
+ {
607
+ /*
608
+ Overview:
609
+ Update the value sum and visit count of nodes along the search path.
610
+ Arguments:
611
+ - search_path: a vector of nodes on the search path.
612
+ - min_max_stats: a tool used to min-max normalize the q value.
613
+ - to_play: which player to play the game in the current node.
614
+ - value: the value to propagate along the search path.
615
+ - discount_factor: the discount factor of reward.
616
+ */
617
+ assert(to_play == -1);
618
+ float bootstrap_value = value;
619
+ int path_len = search_path.size();
620
+ for(int i = path_len - 1; i >= 0; --i){
621
+ CNode* node = search_path[i];
622
+ node->value_sum += bootstrap_value;
623
+ node->visit_count += 1;
624
+
625
+ float true_reward = node->reward;
626
+
627
+ min_max_stats.update(true_reward + discount_factor * node->value());
628
+
629
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
630
+ }
631
+ }
632
+
633
+ void cbatch_back_propagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch)
634
+ {
635
+ /*
636
+ Overview:
637
+ Expand the nodes along the search path and update the infos.
638
+ Arguments:
639
+ - current_latent_state_index: The index of latent state of the leaf node in the search path.
640
+ - discount_factor: the discount factor of reward.
641
+ - value_prefixs: the value prefixs of nodes along the search path.
642
+ - values: the values to propagate along the search path.
643
+ - policies: the policy logits of nodes along the search path.
644
+ - min_max_stats: a tool used to min-max normalize the q value.
645
+ - results: the search results.
646
+ - to_play_batch: the batch of which player is playing on this node.
647
+ */
648
+ for(int i = 0; i < results.num; ++i){
649
+ results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], values[i], policies[i]);
650
+ cback_propagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
651
+ }
652
+ }
653
+
654
+ int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
655
+ {
656
+ /*
657
+ Overview:
658
+ Select the child node of the roots according to ucb scores.
659
+ Arguments:
660
+ - root: the roots to select the child node.
661
+ - min_max_stats: a tool used to min-max normalize the score.
662
+ - pb_c_base: the constant c2 in MuZero's UCB formula.
663
+ - pb_c_init: the constant c1 in MuZero's UCB formula.
664
+ - discount_factor: the discount factor of reward.
665
+ - mean_q: the mean q value of the parent node.
666
+ - players: the number of players.
667
+ Outputs:
668
+ - action: the action to select.
669
+ */
670
+ float max_score = FLOAT_MIN;
671
+ const float epsilon = 0.000001;
672
+ std::vector<int> max_index_lst;
673
+ for(auto a: root->legal_actions){
674
+
675
+ CNode* child = root->get_child(a);
676
+ float temp_score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);
677
+
678
+ if(max_score < temp_score){
679
+ max_score = temp_score;
680
+
681
+ max_index_lst.clear();
682
+ max_index_lst.push_back(a);
683
+ }
684
+ else if(temp_score >= max_score - epsilon){
685
+ max_index_lst.push_back(a);
686
+ }
687
+ }
688
+
689
+ int action = 0;
690
+ if(max_index_lst.size() > 0){
691
+ int rand_index = rand() % max_index_lst.size();
692
+ action = max_index_lst[rand_index];
693
+ }
694
+ return action;
695
+ }
696
+
697
+ //*********************************************************
698
+ // Gumbel Muzero related code
699
+ //*********************************************************
700
+
701
+ int cselect_root_child(CNode* root, float discount_factor, int num_simulations, int max_num_considered_actions)
702
+ {
703
+ /*
704
+ Overview:
705
+ Select the child node of the roots in gumbel muzero.
706
+ Arguments:
707
+ - root: the roots to select the child node.
708
+ - discount_factor: the discount factor of reward.
709
+ - num_simulations: the upper limit number of simulations.
710
+ - max_num_considered_actions: the maximum number of considered actions.
711
+ Outputs:
712
+ - action: the action to select.
713
+ */
714
+ std::vector<int> child_visit_count;
715
+ std::vector<float> child_prior;
716
+ for(auto a: root->legal_actions){
717
+ CNode* child = root->get_child(a);
718
+ child_visit_count.push_back(child->visit_count);
719
+ child_prior.push_back(child->prior);
720
+ }
721
+ assert(child_visit_count.size()==child_prior.size());
722
+
723
+ std::vector<float> completed_qvalues = qtransform_completed_by_mix_value(root, child_visit_count, child_prior, discount_factor);
724
+ std::vector<std::vector<int> > visit_table = get_table_of_considered_visits(max_num_considered_actions, num_simulations);
725
+
726
+ int num_valid_actions = root->legal_actions.size();
727
+ int num_considered = std::min(max_num_considered_actions, num_simulations);
728
+ int simulation_index = std::accumulate(child_visit_count.begin(), child_visit_count.end(), 0);
729
+ int considered_visit = visit_table[num_considered][simulation_index];
730
+
731
+ std::vector<float> score = score_considered(considered_visit, root->gumbel, child_prior, completed_qvalues, child_visit_count);
732
+
733
+ float argmax = -std::numeric_limits<float>::infinity();
734
+ int max_action = root->legal_actions[0];
735
+ int index = 0;
736
+ for(auto a: root->legal_actions){
737
+ if(score[index] > argmax){
738
+ argmax = score[index];
739
+ max_action = a;
740
+ }
741
+ index += 1;
742
+ }
743
+
744
+ return max_action;
745
+ }
746
+
747
+ int cselect_interior_child(CNode* root, float discount_factor)
748
+ {
749
+ /*
750
+ Overview:
751
+ Select the child node of the interior node in gumbel muzero.
752
+ Arguments:
753
+ - root: the roots to select the child node.
754
+ - discount_factor: the discount factor of reward.
755
+ Outputs:
756
+ - action: the action to select.
757
+ */
758
+ std::vector<int> child_visit_count;
759
+ std::vector<float> child_prior;
760
+ for(auto a: root->legal_actions){
761
+ CNode* child = root->get_child(a);
762
+ child_visit_count.push_back(child->visit_count);
763
+ child_prior.push_back(child->prior);
764
+ }
765
+ assert(child_visit_count.size()==child_prior.size());
766
+ std::vector<float> completed_qvalues = qtransform_completed_by_mix_value(root, child_visit_count, child_prior, discount_factor);
767
+ std::vector<float> probs;
768
+ for (int i=0;i<child_prior.size();i++){
769
+ probs.push_back(child_prior[i] + completed_qvalues[i]);
770
+ }
771
+ csoftmax(probs, probs.size());
772
+ int visit_count_sum = std::accumulate(child_visit_count.begin(), child_visit_count.end(), 0);
773
+ std::vector<float> to_argmax;
774
+ for (int i=0;i<probs.size();i++){
775
+ to_argmax.push_back(probs[i] - (float)child_visit_count[i]/(float)(1+visit_count_sum));
776
+ }
777
+
778
+ float argmax = -std::numeric_limits<float>::infinity();
779
+ int max_action = root->legal_actions[0];
780
+ int index = 0;
781
+ for(auto a: root->legal_actions){
782
+ if(to_argmax[index] > argmax){
783
+ argmax = to_argmax[index];
784
+ max_action = a;
785
+ }
786
+ index += 1;
787
+ }
788
+
789
+ return max_action;
790
+ }
791
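For intuition, the interior rule above selects the argmax of softmax(prior + completed Q) minus each child's empirical visit fraction, which steers further visits toward the improved policy. A small standalone sketch of that arithmetic follows; the numbers are hypothetical and this snippet is not part of the diffed file.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Hypothetical child statistics for one interior node.
    std::vector<float> prior       = {0.2f, 1.0f, -0.3f};  // child priors (logits)
    std::vector<float> completed_q = {0.9f, 0.1f, 0.4f};   // completed Q values
    std::vector<int>   visit_count = {3, 5, 0};

    // probs = softmax(prior + completed_q), mirroring cselect_interior_child.
    std::vector<float> probs(prior.size());
    float max_val = -1e9f, sum = 0.0f;
    for (size_t i = 0; i < prior.size(); ++i) max_val = std::max(max_val, prior[i] + completed_q[i]);
    for (size_t i = 0; i < prior.size(); ++i) { probs[i] = std::exp(prior[i] + completed_q[i] - max_val); sum += probs[i]; }
    for (float &p : probs) p /= sum;

    int visit_sum = 0;
    for (int v : visit_count) visit_sum += v;

    // Pick the child whose target probability most exceeds its visit fraction.
    int best = 0; float best_val = -1e9f;
    for (size_t i = 0; i < probs.size(); ++i) {
        float v = probs[i] - (float)visit_count[i] / (float)(1 + visit_sum);
        if (v > best_val) { best_val = v; best = (int)i; }
    }
    std::printf("selected action = %d\n", best);
    return 0;
}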
+
792
+ float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
793
+ {
794
+ /*
795
+ Overview:
796
+ Compute the ucb score of the child.
797
+ Arguments:
798
+ - child: the child node to compute ucb score.
799
+ - min_max_stats: a tool used to min-max normalize the score.
800
+ - mean_q: the mean q value of the parent node.
801
+ - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
802
+ - pb_c_base: constant c2 in muzero.
803
+ - pb_c_init: constant c1 in muzero.
804
+ - discount_factor: the discount factor of reward.
805
+ - players: the number of players.
806
+ Outputs:
807
+ - ucb_value: the ucb score of the child.
808
+ */
809
+ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0;
810
+ pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
811
+ pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));
812
+
813
+ prior_score = pb_c * child->prior;
814
+ if (child->visit_count == 0){
815
+ value_score = parent_mean_q;
816
+ }
817
+ else {
818
+ float true_reward = child->reward;
819
+ if(players == 1)
820
+ value_score = true_reward + discount_factor * child->value();
821
+ else if(players == 2)
822
+ value_score = true_reward + discount_factor * (-child->value());
823
+ }
824
+
825
+ value_score = min_max_stats.normalize(value_score);
826
+
827
+ if (value_score < 0) value_score = 0;
828
+ if (value_score > 1) value_score = 1;
829
+
830
+ float ucb_value = prior_score + value_score;
831
+ return ucb_value;
832
+ }
833
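As a quick sanity check of the PUCT terms computed in cucb_score above, the standalone sketch below evaluates the exploration bonus for one toy child; the constants 19652 and 1.25 are the commonly used MuZero defaults, and the child statistics are hypothetical.

#include <cmath>
#include <cstdio>

int main() {
    float pb_c_base = 19652.0f, pb_c_init = 1.25f;        // typical c2 and c1
    float total_children_visit_counts = 50.0f;            // parent visit count - 1
    float child_visit_count = 10.0f, child_prior = 0.2f;  // hypothetical child statistics

    // Same form as the pb_c computation in cucb_score.
    float pb_c = std::log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
    pb_c *= std::sqrt(total_children_visit_counts) / (child_visit_count + 1);
    float prior_score = pb_c * child_prior;               // shrinks as the child accumulates visits

    std::printf("prior_score = %f\n", prior_score);       // added to the normalized value score
    return 0;
}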
+
834
+ void cbatch_traverse(CRoots *roots, int num_simulations, int max_num_considered_actions, float discount_factor, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
835
+ {
836
+ /*
837
+ Overview:
838
+ Search node path from the roots.
839
+ Arguments:
840
+ - roots: the roots that search from.
841
+ - num_simulations: the upper limit number of simulations.
842
+ - max_num_considered_actions: the maximum number of considered actions.
843
+ - discount_factor: the discount factor of reward.
844
+ - results: the search results.
845
+ - virtual_to_play_batch: the batch of which player is playing on this node.
846
+ */
847
+ // set seed
848
+ timeval t1;
849
+ gettimeofday(&t1, NULL);
850
+ srand(t1.tv_usec);
851
+
852
+ int last_action = -1;
853
+ float parent_q = 0.0;
854
+ results.search_lens = std::vector<int>();
855
+
856
+ int players = 0;
857
+ int largest_element = *max_element(virtual_to_play_batch.begin(),virtual_to_play_batch.end()); // -1 (play-with-bot mode) or 1/2 (self-play mode)
858
+ if(largest_element==-1)
859
+ players = 1;
860
+ else
861
+ players = 2;
862
+
863
+ for(int i = 0; i < results.num; ++i){
864
+ CNode *node = &(roots->roots[i]);
865
+ int is_root = 1;
866
+ int search_len = 0;
867
+ int action = 0;
868
+ results.search_paths[i].push_back(node);
869
+
870
+ while(node->expanded()){
871
+ if(is_root){
872
+ action = cselect_root_child(node, discount_factor, num_simulations, max_num_considered_actions);
873
+ }
874
+ else{
875
+ action = cselect_interior_child(node, discount_factor);
876
+ }
877
+ is_root = 0;
878
+
879
+ node->best_action = action;
880
+ // next
881
+ node = node->get_child(action);
882
+ last_action = action;
883
+ results.search_paths[i].push_back(node);
884
+ search_len += 1;
885
+ }
886
+
887
+ CNode* parent = results.search_paths[i][results.search_paths[i].size() - 2];
888
+
889
+ results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
890
+ results.latent_state_index_in_batch.push_back(parent->batch_index);
891
+
892
+ results.last_actions.push_back(last_action);
893
+ results.search_lens.push_back(search_len);
894
+ results.nodes.push_back(node);
895
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
896
+
897
+ }
898
+ }
899
+
900
+ //*********************************************************
901
+ // Gumbel Muzero related code
902
+ //*********************************************************
903
+
904
+ void csoftmax(std::vector<float> &input, int input_len)
905
+ {
906
+ /*
907
+ Overview:
908
+ Softmax transformation.
909
+ Arguments:
910
+ - input: the vector to be transformed.
911
+ - input_len: the length of input vector.
912
+ */
913
+ assert (!input.empty());
914
+ assert (input_len != 0);
915
+ int i;
916
+ float m;
917
+ // Find maximum value from input array
918
+ m = input[0];
919
+ for (i = 1; i < input_len; i++) {
920
+ if (input[i] > m) {
921
+ m = input[i];
922
+ }
923
+ }
924
+
925
+ float sum = 0;
926
+ for (i = 0; i < input_len; i++) {
927
+ sum += expf(input[i]-m);
928
+ }
929
+
930
+ for (i = 0; i < input_len; i++) {
931
+ input[i] = expf(input[i] - m - log(sum));
932
+ }
933
+ }
934
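A standalone check of the max-subtraction trick used by csoftmax; the logits are deliberately large to show why the naive exponential would overflow. This is illustrative only and not part of the file above.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> logits = {1000.0f, 1001.0f, 1002.0f};  // naive expf(1000) would overflow
    float m = *std::max_element(logits.begin(), logits.end());

    float sum = 0.0f;
    for (float x : logits) sum += std::exp(x - m);

    // Same normalization as csoftmax: exp(x - m - log(sum)).
    for (float &x : logits) x = std::exp(x - m - std::log(sum));
    for (float x : logits) std::printf("%f ", x);             // ~0.09 0.24 0.67
    std::printf("\n");
    return 0;
}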
+
935
+ float compute_mixed_value(float raw_value, std::vector<float> q_values, std::vector<int> &child_visit, std::vector<float> &child_prior)
936
+ {
937
+ /*
938
+ Overview:
939
+ Compute the mixed Q value.
940
+ Arguments:
941
+ - raw_value: the approximated value of the current node from the value network.
942
+ - q_value: the q value of the current node.
943
+ - child_visit: the visit counts of the child nodes.
944
+ - child_prior: the prior of the child nodes.
945
+ Outputs:
946
+ - mixed Q value.
947
+ */
948
+ float visit_count_sum = 0.0;
949
+ float probs_sum = 0.0;
950
+ float weighted_q_sum = 0.0;
951
+ float min_num = -10e7;
952
+
953
+ for(unsigned int i = 0;i < child_visit.size();i++)
954
+ visit_count_sum += child_visit[i];
955
+
956
+ for(unsigned int i = 0;i < child_prior.size();i++)
957
+ // Ensuring non-nan prior
958
+ child_prior[i] = std::max(child_prior[i], min_num);
959
+
960
+ for(unsigned int i = 0;i < child_prior.size();i++)
961
+ if (child_visit[i] > 0)
962
+ probs_sum += child_prior[i];
963
+
964
+ for (unsigned int i = 0;i < child_prior.size();i++)
965
+ if (child_visit[i] > 0){
966
+ weighted_q_sum += child_prior[i] * q_values[i] / probs_sum;
967
+ }
968
+
969
+ return (raw_value + visit_count_sum * weighted_q_sum) / (visit_count_sum+1);
970
+ }
971
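To make the interpolation in compute_mixed_value concrete, here is a standalone hand computation with hypothetical inputs; only visited children contribute to the weighted Q term.

#include <cstdio>
#include <vector>

int main() {
    float raw_value = 0.5f;                              // network value of the node (hypothetical)
    std::vector<int>   child_visit = {2, 0};
    std::vector<float> child_prior = {0.6f, 0.4f};       // already softmax-normalized
    std::vector<float> q_values    = {1.0f, 0.0f};

    float visit_sum = 0.0f, probs_sum = 0.0f, weighted_q = 0.0f;
    for (int v : child_visit) visit_sum += v;
    for (size_t i = 0; i < child_prior.size(); ++i)
        if (child_visit[i] > 0) probs_sum += child_prior[i];
    for (size_t i = 0; i < child_prior.size(); ++i)
        if (child_visit[i] > 0) weighted_q += child_prior[i] * q_values[i] / probs_sum;

    // (0.5 + 2 * 1.0) / (2 + 1) = 0.8333: empirical Q of visited children anchored by the raw value.
    std::printf("mixed value = %f\n", (raw_value + visit_sum * weighted_q) / (visit_sum + 1));
    return 0;
}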
+
972
+ void rescale_qvalues(std::vector<float> &value, float epsilon){
973
+ /*
974
+ Overview:
975
+ Rescale the q value with max-min normalization.
976
+ Arguments:
977
+ - value: the value vector to be rescaled.
978
+ - epsilon: the lower limit of gap.
979
+ */
980
+ float max_value = *max_element(value.begin(), value.end());
981
+ float min_value = *min_element(value.begin(), value.end());
982
+ float gap = max_value - min_value;
983
+ gap = std::max(gap, epsilon);
984
+ for (unsigned int i = 0;i < value.size();i++){
985
+ value[i] = (value[i]-min_value)/gap;
986
+ }
987
+ }
988
+
989
+ std::vector<float> qtransform_completed_by_mix_value(CNode *root, std::vector<int> & child_visit, \
990
+ std::vector<float> & child_prior, float discount_factor, float maxvisit_init, float value_scale, \
991
+ bool rescale_values, float epsilon)
992
+ {
993
+ /*
994
+ Overview:
995
+ Calculate the q value with mixed value.
996
+ Arguments:
997
+ - root: the roots that search from.
998
+ - child_visit: the visit counts of the child nodes.
999
+ - child_prior: the prior of the child nodes.
1000
+ - discount_factor: the discount factor of reward.
1001
+ - maxvisit_init: the init of the maximization of visit counts.
1002
+ - value_scale: the scale of value.
1003
+ - rescale_values: whether to rescale the values.
1004
+ - epsilon: the lower limit of gap in max-min normalization
1005
+ Outputs:
1006
+ - completed Q value.
1007
+ */
1008
+ assert (child_visit.size() == child_prior.size());
1009
+ std::vector<float> qvalues;
1010
+ std::vector<float> child_prior_tmp;
1011
+
1012
+ child_prior_tmp.assign(child_prior.begin(), child_prior.end());
1013
+ qvalues = root->get_q(discount_factor);
1014
+ csoftmax(child_prior_tmp, child_prior_tmp.size());
1015
+ // TODO: should be raw_value here
1016
+ float value = compute_mixed_value(root->raw_value, qvalues, child_visit, child_prior_tmp);
1017
+ std::vector<float> completed_qvalue;
1018
+
1019
+ for (unsigned int i = 0;i < child_prior_tmp.size();i++){
1020
+ if (child_visit[i] > 0){
1021
+ completed_qvalue.push_back(qvalues[i]);
1022
+ }
1023
+ else{
1024
+ completed_qvalue.push_back(value);
1025
+ }
1026
+ }
1027
+
1028
+ if (rescale_values){
1029
+ rescale_qvalues(completed_qvalue, epsilon);
1030
+ }
1031
+
1032
+ float max_visit = *max_element(child_visit.begin(), child_visit.end());
1033
+ float visit_scale = maxvisit_init + max_visit;
1034
+
1035
+ for (unsigned int i=0;i < completed_qvalue.size();i++){
1036
+ completed_qvalue[i] = completed_qvalue[i] * visit_scale * value_scale;
1037
+ }
1038
+ return completed_qvalue;
1039
+
1040
+ }
1041
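The transform above substitutes the mixed value for every unvisited child before rescaling; a condensed standalone sketch of that completion step with hypothetical numbers follows (the final visit/value scaling is only summarized in the comment).

#include <cstdio>
#include <vector>

int main() {
    std::vector<int>   visits  = {3, 1, 0};              // hypothetical per-child visit counts
    std::vector<float> qvalues = {0.9f, 0.4f, 0.0f};     // Q of the unvisited child is unused
    float mixed_value = 0.6f;                            // as returned by compute_mixed_value

    std::vector<float> completed;
    for (size_t i = 0; i < visits.size(); ++i)
        completed.push_back(visits[i] > 0 ? qvalues[i] : mixed_value);  // completion step

    // qtransform_completed_by_mix_value then min-max rescales these values and multiplies
    // them by (maxvisit_init + max_visit) * value_scale before they enter the Gumbel scores.
    for (float q : completed) std::printf("%f ", q);
    std::printf("\n");
    return 0;
}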
+
1042
+ std::vector<int> get_sequence_of_considered_visits(int max_num_considered_actions, int num_simulations)
1043
+ {
1044
+ /*
1045
+ Overview:
1046
+ Calculate the considered visit sequence.
1047
+ Arguments:
1048
+ - max_num_considered_actions: the maximum number of considered actions.
1049
+ - num_simulations: the upper limit number of simulations.
1050
+ Outputs:
1051
+ - the considered visit sequence.
1052
+ */
1053
+ std::vector<int> visit_seq;
1054
+ if(max_num_considered_actions <= 1){
1055
+ for (int i=0;i < num_simulations;i++)
1056
+ visit_seq.push_back(i);
1057
+ return visit_seq;
1058
+ }
1059
+
1060
+ int log2max = std::ceil(std::log2(max_num_considered_actions));
1061
+ std::vector<int> visits;
1062
+ for (int i = 0;i < max_num_considered_actions;i++)
1063
+ visits.push_back(0);
1064
+ int num_considered = max_num_considered_actions;
1065
+ while (visit_seq.size() < num_simulations){
1066
+ int num_extra_visits = std::max(1, (int)(num_simulations / (log2max * num_considered)));
1067
+ for (int i = 0;i < num_extra_visits;i++){
1068
+ visit_seq.insert(visit_seq.end(), visits.begin(), visits.begin() + num_considered);
1069
+ for (int j = 0;j < num_considered;j++)
1070
+ visits[j] += 1;
1071
+ }
1072
+ num_considered = std::max(2, num_considered/2);
1073
+ }
1074
+ std::vector<int> visit_seq_slice;
1075
+ visit_seq_slice.assign(visit_seq.begin(), visit_seq.begin() + num_simulations);
1076
+ return visit_seq_slice;
1077
+ }
1078
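To make the sequential-halving schedule concrete: tracing the loop above by hand for max_num_considered_actions = 4 and num_simulations = 16 gives the sequence below; the snippet simply prints that hand-derived result.

#include <cstdio>

int main() {
    // Phase 1: all 4 candidates are visited twice          -> 0 0 0 0 1 1 1 1
    // Phase 2: the surviving 2 candidates get 4 more visits -> 2 2 3 3 4 4 5 5
    int expected[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5};
    for (int v : expected) std::printf("%d ", v);
    std::printf("\n");
    return 0;
}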
+
1079
+ std::vector<std::vector<int> > get_table_of_considered_visits(int max_num_considered_actions, int num_simulations)
1080
+ {
1081
+ /*
1082
+ Overview:
1083
+ Calculate the table of considered visits.
1084
+ Arguments:
1085
+ - max_num_considered_actions: the maximum number of considered actions.
1086
+ - num_simulations: the upper limit number of simulations.
1087
+ Outputs:
1088
+ - the table of considered visits.
1089
+ */
1090
+ std::vector<std::vector<int> > table;
1091
+ for (int m=0;m < max_num_considered_actions+1;m++){
1092
+ table.push_back(get_sequence_of_considered_visits(m, num_simulations));
1093
+ }
1094
+ return table;
1095
+ }
1096
+
1097
+ std::vector<float> score_considered(int considered_visit, std::vector<float> gumbel, std::vector<float> logits, std::vector<float> normalized_qvalues, std::vector<int> visit_counts)
1098
+ {
1099
+ /*
1100
+ Overview:
1101
+ Calculate the score of nodes to be considered according to the considered visit.
1102
+ Arguments:
1103
+ - considered_visit: the visit counts of node to be considered.
1104
+ - gumbel: the gumbel vector.
1105
+ - logits: the logits vector of child nodes.
1106
+ - normalized_qvalues: the normalized Q values of child nodes.
1107
+ - visit_counts: the visit counts of child nodes.
1108
+ Outputs:
1109
+ - the score of nodes to be considered.
1110
+ */
1111
+ float low_logit = -1e9;
1112
+ float max_logit = *max_element(logits.begin(), logits.end());
1113
+ for (unsigned int i=0;i < logits.size();i++){
1114
+ logits[i] -= max_logit;
1115
+ }
1116
+ std::vector<float> penalty;
1117
+ for (unsigned int i=0;i < visit_counts.size();i++){
1118
+ // Only consider the nodes with specific visit counts
1119
+ if (visit_counts[i]==considered_visit)
1120
+ penalty.push_back(0);
1121
+ else
1122
+ penalty.push_back(-std::numeric_limits<float>::infinity());
1123
+ }
1124
+
1125
+ assert(gumbel.size() == logits.size() && logits.size() == normalized_qvalues.size() && normalized_qvalues.size() == penalty.size());
1126
+ std::vector<float> score;
1127
+ for (unsigned int i=0;i < visit_counts.size();i++){
1128
+ score.push_back(std::max(low_logit, gumbel[i] + logits[i] + normalized_qvalues[i]) + penalty[i]);
1129
+ }
1130
+
1131
+ return score;
1132
+ }
1133
+
1134
+ std::vector<float> generate_gumbel(float gumbel_scale, float gumbel_rng, int shape){
1135
+ /*
1136
+ Overview:
1137
+ Generate gumbel vectors.
1138
+ Arguments:
1139
+ - gumbel_scale: the scale of gumbel.
1140
+ - gumbel_rng: the seed to generate gumbel.
1141
+ - shape: the shape of gumbel vectors to be generated
1142
+ Outputs:
1143
+ - gumbel vectors.
1144
+ */
1145
+ std::mt19937 gen(static_cast<unsigned int>(gumbel_rng));
1146
+ std::extreme_value_distribution<float> d(0, 1);
1147
+
1148
+ std::vector<float> gumbel;
1149
+ for (int i = 0;i < shape;i++)
1150
+ gumbel.push_back(gumbel_scale * d(gen));
1151
+ return gumbel;
1152
+ }
1153
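std::extreme_value_distribution(0, 1) draws standard Gumbel noise; the same distribution can also be sampled with the inverse-CDF construction shown in this standalone sketch (the seed and sample count are arbitrary).

#include <cmath>
#include <cstdio>
#include <random>

int main() {
    std::mt19937 gen(42);                                          // arbitrary seed
    std::uniform_real_distribution<float> uniform(1e-7f, 1.0f);    // avoid log(0)
    for (int i = 0; i < 5; ++i) {
        float u = uniform(gen);
        float g = -std::log(-std::log(u));                         // Gumbel(0, 1) via inverse CDF
        std::printf("%f\n", g);
    }
    return 0;
}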
+
1154
+ }
LightZero/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.h ADDED
@@ -0,0 +1,109 @@
1
+ // C++11
2
+
3
+ #ifndef CNODE_H
4
+ #define CNODE_H
5
+
6
+ #include "./../common_lib/cminimax.h"
7
+ #include <math.h>
8
+ #include <vector>
9
+ #include <stack>
10
+ #include <stdlib.h>
11
+ #include <time.h>
12
+ #include <cmath>
13
+ #include <sys/timeb.h>
14
+ #include <sys/time.h>
15
+ #include <map>
16
+
17
+ const int DEBUG_MODE = 0;
18
+
19
+ namespace tree {
20
+
21
+ class CNode {
22
+ public:
23
+ int visit_count, to_play, current_latent_state_index, batch_index, best_action;
24
+ float reward, prior, value_sum, raw_value, gumbel_scale, gumbel_rng;
25
+ std::vector<int> children_index;
26
+ std::map<int, CNode> children;
27
+
28
+ std::vector<int> legal_actions;
29
+ std::vector<float> gumbel;
30
+
31
+ CNode();
32
+ CNode(float prior, std::vector<int> &legal_actions);
33
+ ~CNode();
34
+
35
+ void expand(int to_play, int current_latent_state_index, int batch_index, float reward, float value, const std::vector<float> &policy_logits);
36
+ void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
37
+ std::vector<float> get_q(float discount);
38
+ float compute_mean_q(int isRoot, float parent_q, float discount);
39
+ void print_out();
40
+
41
+ int expanded();
42
+
43
+ float value();
44
+
45
+ std::vector<int> get_trajectory();
46
+ std::vector<int> get_children_distribution();
47
+ std::vector<float> get_children_value(float discount_factor, int action_space_size);
48
+ std::vector<float> get_policy(float discount, int action_space_size);
49
+ CNode* get_child(int action);
50
+ };
51
+
52
+ class CRoots{
53
+ public:
54
+ int root_num;
55
+ std::vector<CNode> roots;
56
+ std::vector<std::vector<int> > legal_actions_list;
57
+
58
+ CRoots();
59
+ CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list);
60
+ ~CRoots();
61
+
62
+ void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
63
+ void prepare_no_noise(const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
64
+ void clear();
65
+ std::vector<std::vector<int> > get_trajectories();
66
+ std::vector<std::vector<int> > get_distributions();
67
+ std::vector<std::vector<float> > get_children_values(float discount, int action_space_size);
68
+ std::vector<std::vector<float> > get_policies(float discount, int action_space_size);
69
+ std::vector<float> get_values();
70
+
71
+ };
72
+
73
+ class CSearchResults{
74
+ public:
75
+ int num;
76
+ std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens;
77
+ std::vector<int> virtual_to_play_batchs;
78
+ std::vector<CNode*> nodes;
79
+ std::vector<std::vector<CNode*> > search_paths;
80
+
81
+ CSearchResults();
82
+ CSearchResults(int num);
83
+ ~CSearchResults();
84
+
85
+ };
86
+
87
+
88
+ //*********************************************************
89
+ void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount, int players);
90
+ void cback_propagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount);
91
+ void cbatch_back_propagate(int current_latent_state_index, float discount, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch);
92
+ int cselect_root_child(CNode* root, float discount, int num_simulations, int max_num_considered_actions);
93
+ int cselect_interior_child(CNode* root, float discount);
94
+ int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount, float mean_q, int players);
95
+ float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount, int players);
96
+ void cbatch_traverse(CRoots *roots, int num_simulations, int max_num_considered_actions, float discount, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
97
+ void csoftmax(std::vector<float> &input, int input_len);
98
+ float compute_mixed_value(float raw_value, std::vector<float> q_values, std::vector<int> &child_visit, std::vector<float> &child_prior);
99
+ void rescale_qvalues(std::vector<float> &value, float epsilon);
100
+ std::vector<float> qtransform_completed_by_mix_value(CNode *root, std::vector<int> & child_visit, \
101
+ std::vector<float> & child_prior, float discount= 0.99, float maxvisit_init = 50.0, float value_scale = 0.1, \
102
+ bool rescale_values = true, float epsilon = 1e-8);
103
+ std::vector<int> get_sequence_of_considered_visits(int max_num_considered_actions, int num_simulations);
104
+ std::vector<std::vector<int> > get_table_of_considered_visits(int max_num_considered_actions, int num_simulations);
105
+ std::vector<float> score_considered(int considered_visit, std::vector<float> gumbel, std::vector<float> logits, std::vector<float> normalized_qvalues, std::vector<int> visit_counts);
106
+ std::vector<float> generate_gumbel(float gumbel_scale, float gumbel_rng, int shape);
107
+ }
108
+
109
+ #endif
LightZero/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp ADDED
@@ -0,0 +1,715 @@
1
+ // C++11
2
+
3
+ #include <iostream>
4
+ #include "cnode.h"
5
+ #include <algorithm>
6
+ #include <map>
7
+ #include <cassert>
8
+
9
+ #ifdef _WIN32
10
+ #include "..\..\common_lib\utils.cpp"
11
+ #else
12
+ #include "../../common_lib/utils.cpp"
13
+ #endif
14
+
15
+
16
+ namespace tree
17
+ {
18
+
19
+ CSearchResults::CSearchResults()
20
+ {
21
+ /*
22
+ Overview:
23
+ Initialization of CSearchResults, the default result number is set to 0.
24
+ */
25
+ this->num = 0;
26
+ }
27
+
28
+ CSearchResults::CSearchResults(int num)
29
+ {
30
+ /*
31
+ Overview:
32
+ Initialization of CSearchResults with result number.
33
+ */
34
+ this->num = num;
35
+ for (int i = 0; i < num; ++i)
36
+ {
37
+ this->search_paths.push_back(std::vector<CNode *>());
38
+ }
39
+ }
40
+
41
+ CSearchResults::~CSearchResults() {}
42
+
43
+ //*********************************************************
44
+
45
+ CNode::CNode()
46
+ {
47
+ /*
48
+ Overview:
49
+ Initialization of CNode.
50
+ */
51
+ this->prior = 0;
52
+ this->legal_actions = legal_actions;
53
+
54
+ this->visit_count = 0;
55
+ this->value_sum = 0;
56
+ this->best_action = -1;
57
+ this->to_play = 0;
58
+ this->reward = 0.0;
59
+ }
60
+
61
+ CNode::CNode(float prior, std::vector<int> &legal_actions)
62
+ {
63
+ /*
64
+ Overview:
65
+ Initialization of CNode with prior value and legal actions.
66
+ Arguments:
67
+ - prior: the prior value of this node.
68
+ - legal_actions: a vector of legal actions of this node.
69
+ */
70
+ this->prior = prior;
71
+ this->legal_actions = legal_actions;
72
+
73
+ this->visit_count = 0;
74
+ this->value_sum = 0;
75
+ this->best_action = -1;
76
+ this->to_play = 0;
77
+ this->current_latent_state_index = -1;
78
+ this->batch_index = -1;
79
+ }
80
+
81
+ CNode::~CNode() {}
82
+
83
+ void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector<float> &policy_logits)
84
+ {
85
+ /*
86
+ Overview:
87
+ Expand the child nodes of the current node.
88
+ Arguments:
89
+ - to_play: which player to play the game in the current node.
90
+ - current_latent_state_index: The index of latent state of the leaf node in the search path of the current node.
91
+ - batch_index: The index of latent state of the leaf node in the search path of the current node.
92
+ - reward: the reward of the current node.
93
+ - policy_logits: the logit of the child nodes.
94
+ */
95
+ this->to_play = to_play;
96
+ this->current_latent_state_index = current_latent_state_index;
97
+ this->batch_index = batch_index;
98
+ this->reward = reward;
99
+
100
+ int action_num = policy_logits.size();
101
+ if (this->legal_actions.size() == 0)
102
+ {
103
+ for (int i = 0; i < action_num; ++i)
104
+ {
105
+ this->legal_actions.push_back(i);
106
+ }
107
+ }
108
+ float temp_policy;
109
+ float policy_sum = 0.0;
110
+
111
+ #ifdef _WIN32
112
+ // Create a dynamic array
113
+ float* policy = new float[action_num];
114
+ #else
115
+ float policy[action_num];
116
+ #endif
117
+
118
+ float policy_max = FLOAT_MIN;
119
+ for (auto a : this->legal_actions)
120
+ {
121
+ if (policy_max < policy_logits[a])
122
+ {
123
+ policy_max = policy_logits[a];
124
+ }
125
+ }
126
+
127
+ for (auto a : this->legal_actions)
128
+ {
129
+ temp_policy = exp(policy_logits[a] - policy_max);
130
+ policy_sum += temp_policy;
131
+ policy[a] = temp_policy;
132
+ }
133
+
134
+ float prior;
135
+ for (auto a : this->legal_actions)
136
+ {
137
+ prior = policy[a] / policy_sum;
138
+ std::vector<int> tmp_empty;
139
+ this->children[a] = CNode(prior, tmp_empty); // only for muzero/efficientzero; alphazero is not supported
140
+ }
141
+
142
+ #ifdef _WIN32
143
+ // Free the dynamically allocated array
144
+ delete[] policy;
145
+ #else
146
+ #endif
147
+ }
148
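The expansion above turns raw policy logits into per-child priors with a softmax restricted to the legal actions; a standalone sketch of that normalization with hypothetical logits and legal-action mask follows.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> policy_logits = {2.0f, 0.5f, -1.0f};  // hypothetical network output
    std::vector<int>   legal_actions = {0, 2};                // action 1 is illegal and ignored

    float policy_max = -1e9f;
    for (int a : legal_actions) policy_max = std::max(policy_max, policy_logits[a]);

    std::vector<float> prior(policy_logits.size(), 0.0f);
    float policy_sum = 0.0f;
    for (int a : legal_actions) {
        prior[a] = std::exp(policy_logits[a] - policy_max);
        policy_sum += prior[a];
    }
    for (int a : legal_actions) prior[a] /= policy_sum;       // priors of legal children sum to 1

    for (int a : legal_actions) std::printf("action %d prior %f\n", a, prior[a]);
    return 0;
}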
+
149
+ void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
150
+ {
151
+ /*
152
+ Overview:
153
+ Add a noise to the prior of the child nodes.
154
+ Arguments:
155
+ - exploration_fraction: the fraction to add noise.
156
+ - noises: the vector of noises added to each child node.
157
+ */
158
+ float noise, prior;
159
+ for (int i = 0; i < this->legal_actions.size(); ++i)
160
+ {
161
+ noise = noises[i];
162
+ CNode *child = this->get_child(this->legal_actions[i]);
163
+
164
+ prior = child->prior;
165
+ child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
166
+ }
167
+ }
168
+
169
+ float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
170
+ {
171
+ /*
172
+ Overview:
173
+ Compute the mean q value of the current node.
174
+ Arguments:
175
+ - isRoot: whether the current node is a root node.
176
+ - parent_q: the q value of the parent node.
177
+ - discount_factor: the discount_factor of reward.
178
+ */
179
+ float total_unsigned_q = 0.0;
180
+ int total_visits = 0;
181
+ for (auto a : this->legal_actions)
182
+ {
183
+ CNode *child = this->get_child(a);
184
+ if (child->visit_count > 0)
185
+ {
186
+ float true_reward = child->reward;
187
+ float qsa = true_reward + discount_factor * child->value();
188
+ total_unsigned_q += qsa;
189
+ total_visits += 1;
190
+ }
191
+ }
192
+
193
+ float mean_q = 0.0;
194
+ if (isRoot && total_visits > 0)
195
+ {
196
+ mean_q = (total_unsigned_q) / (total_visits);
197
+ }
198
+ else
199
+ {
200
+ mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
201
+ }
202
+ return mean_q;
203
+ }
204
+
205
+ void CNode::print_out()
206
+ {
207
+ return;
208
+ }
209
+
210
+ int CNode::expanded()
211
+ {
212
+ /*
213
+ Overview:
214
+ Return whether the current node is expanded.
215
+ */
216
+ return this->children.size() > 0;
217
+ }
218
+
219
+ float CNode::value()
220
+ {
221
+ /*
222
+ Overview:
223
+ Return the real value of the current tree.
224
+ */
225
+ float true_value = 0.0;
226
+ if (this->visit_count == 0)
227
+ {
228
+ return true_value;
229
+ }
230
+ else
231
+ {
232
+ true_value = this->value_sum / this->visit_count;
233
+ return true_value;
234
+ }
235
+ }
236
+
237
+ std::vector<int> CNode::get_trajectory()
238
+ {
239
+ /*
240
+ Overview:
241
+ Find the current best trajectory starts from the current node.
242
+ Outputs:
243
+ - traj: a vector of node index, which is the current best trajectory from this node.
244
+ */
245
+ std::vector<int> traj;
246
+
247
+ CNode *node = this;
248
+ int best_action = node->best_action;
249
+ while (best_action >= 0)
250
+ {
251
+ traj.push_back(best_action);
252
+
253
+ node = node->get_child(best_action);
254
+ best_action = node->best_action;
255
+ }
256
+ return traj;
257
+ }
258
+
259
+ std::vector<int> CNode::get_children_distribution()
260
+ {
261
+ /*
262
+ Overview:
263
+ Get the distribution of child nodes in the format of visit_count.
264
+ Outputs:
265
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
266
+ */
267
+ std::vector<int> distribution;
268
+ if (this->expanded())
269
+ {
270
+ for (auto a : this->legal_actions)
271
+ {
272
+ CNode *child = this->get_child(a);
273
+ distribution.push_back(child->visit_count);
274
+ }
275
+ }
276
+ return distribution;
277
+ }
278
+
279
+ CNode *CNode::get_child(int action)
280
+ {
281
+ /*
282
+ Overview:
283
+ Get the child node corresponding to the input action.
284
+ Arguments:
285
+ - action: the action to get child.
286
+ */
287
+ return &(this->children[action]);
288
+ }
289
+
290
+ //*********************************************************
291
+
292
+ CRoots::CRoots()
293
+ {
294
+ /*
295
+ Overview:
296
+ The initialization of CRoots.
297
+ */
298
+ this->root_num = 0;
299
+ }
300
+
301
+ CRoots::CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list)
302
+ {
303
+ /*
304
+ Overview:
305
+ The initialization of CRoots with root num and legal action lists.
306
+ Arguments:
307
+ - root_num: the number of the current root.
308
+ - legal_action_list: the vector of the legal action of this root.
309
+ */
310
+ this->root_num = root_num;
311
+ this->legal_actions_list = legal_actions_list;
312
+
313
+ for (int i = 0; i < root_num; ++i)
314
+ {
315
+ this->roots.push_back(CNode(0, this->legal_actions_list[i]));
316
+ }
317
+ }
318
+
319
+ CRoots::~CRoots() {}
320
+
321
+ void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
322
+ {
323
+ /*
324
+ Overview:
325
+ Expand the roots and add noises.
326
+ Arguments:
327
+ - root_noise_weight: the exploration fraction of roots
328
+ - noises: the vector of noise add to the roots.
329
+ - rewards: the vector of rewards of each root.
330
+ - policies: the vector of policy logits of each root.
331
+ - to_play_batch: the vector of the player side of each root.
332
+ */
333
+ for (int i = 0; i < this->root_num; ++i)
334
+ {
335
+ this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i]);
336
+ this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
337
+
338
+ this->roots[i].visit_count += 1;
339
+ }
340
+ }
341
+
342
+ void CRoots::prepare_no_noise(const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
343
+ {
344
+ /*
345
+ Overview:
346
+ Expand the roots without noise.
347
+ Arguments:
348
+ - rewards: the vector of rewards of each root.
349
+ - policies: the vector of policy logits of each root.
350
+ - to_play_batch: the vector of the player side of each root.
351
+ */
352
+ for (int i = 0; i < this->root_num; ++i)
353
+ {
354
+ this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i]);
355
+
356
+ this->roots[i].visit_count += 1;
357
+ }
358
+ }
359
+
360
+ void CRoots::clear()
361
+ {
362
+ /*
363
+ Overview:
364
+ Clear the roots vector.
365
+ */
366
+ this->roots.clear();
367
+ }
368
+
369
+ std::vector<std::vector<int> > CRoots::get_trajectories()
370
+ {
371
+ /*
372
+ Overview:
373
+ Find the current best trajectory starts from each root.
374
+ Outputs:
375
+ - traj: a vector of node index, which is the current best trajectory from each root.
376
+ */
377
+ std::vector<std::vector<int> > trajs;
378
+ trajs.reserve(this->root_num);
379
+
380
+ for (int i = 0; i < this->root_num; ++i)
381
+ {
382
+ trajs.push_back(this->roots[i].get_trajectory());
383
+ }
384
+ return trajs;
385
+ }
386
+
387
+ std::vector<std::vector<int> > CRoots::get_distributions()
388
+ {
389
+ /*
390
+ Overview:
391
+ Get the children distribution of each root.
392
+ Outputs:
393
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
394
+ */
395
+ std::vector<std::vector<int> > distributions;
396
+ distributions.reserve(this->root_num);
397
+
398
+ for (int i = 0; i < this->root_num; ++i)
399
+ {
400
+ distributions.push_back(this->roots[i].get_children_distribution());
401
+ }
402
+ return distributions;
403
+ }
404
+
405
+ std::vector<float> CRoots::get_values()
406
+ {
407
+ /*
408
+ Overview:
409
+ Return the real value of each root.
410
+ */
411
+ std::vector<float> values;
412
+ for (int i = 0; i < this->root_num; ++i)
413
+ {
414
+ values.push_back(this->roots[i].value());
415
+ }
416
+ return values;
417
+ }
418
+
419
+ //*********************************************************
420
+ //
421
+ void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
422
+ {
423
+ /*
424
+ Overview:
425
+ Update the q value of the root and its child nodes.
426
+ Arguments:
427
+ - root: the root that update q value from.
428
+ - min_max_stats: a tool used to min-max normalize the q value.
429
+ - discount_factor: the discount factor of reward.
430
+ - players: the number of players.
431
+ */
432
+ std::stack<CNode *> node_stack;
433
+ node_stack.push(root);
434
+ while (node_stack.size() > 0)
435
+ {
436
+ CNode *node = node_stack.top();
437
+ node_stack.pop();
438
+
439
+ if (node != root)
440
+ {
441
+ // # NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
442
+ // # but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
443
+ // # true_reward = node.value_prefix - (- parent_value_prefix)
444
+ // float true_reward = node->value_prefix - node->parent_value_prefix;
445
+ float true_reward = node->reward;
446
+
447
+ float qsa;
448
+ if (players == 1)
449
+ qsa = true_reward + discount_factor * node->value();
450
+ else if (players == 2)
451
+ // TODO(pu):
452
+ qsa = true_reward + discount_factor * (-1) * node->value();
453
+
454
+ min_max_stats.update(qsa);
455
+ }
456
+
457
+ for (auto a : node->legal_actions)
458
+ {
459
+ CNode *child = node->get_child(a);
460
+ if (child->expanded())
461
+ {
462
+ node_stack.push(child);
463
+ }
464
+ }
465
+ }
466
+ }
467
+
468
+ void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
469
+ {
470
+ /*
471
+ Overview:
472
+ Update the value sum and visit count of nodes along the search path.
473
+ Arguments:
474
+ - search_path: a vector of nodes on the search path.
475
+ - min_max_stats: a tool used to min-max normalize the q value.
476
+ - to_play: which player to play the game in the current node.
477
+ - value: the value to propagate along the search path.
478
+ - discount_factor: the discount factor of reward.
479
+ */
480
+ assert(to_play == -1 || to_play == 1 || to_play == 2);
481
+ if (to_play == -1)
482
+ {
483
+ // for play-with-bot-mode
484
+ float bootstrap_value = value;
485
+ int path_len = search_path.size();
486
+ for (int i = path_len - 1; i >= 0; --i)
487
+ {
488
+ CNode *node = search_path[i];
489
+ node->value_sum += bootstrap_value;
490
+ node->visit_count += 1;
491
+
492
+ float true_reward = node->reward;
493
+
494
+ min_max_stats.update(true_reward + discount_factor * node->value());
495
+
496
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
497
+ }
498
+ }
499
+ else
500
+ {
501
+ // for self-play-mode
502
+ float bootstrap_value = value;
503
+ int path_len = search_path.size();
504
+ for (int i = path_len - 1; i >= 0; --i)
505
+ {
506
+ CNode *node = search_path[i];
507
+ if (node->to_play == to_play)
508
+ node->value_sum += bootstrap_value;
509
+ else
510
+ node->value_sum += -bootstrap_value;
511
+ node->visit_count += 1;
512
+
513
+ // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
514
+ // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
515
+ // float true_reward = node->value_prefix - parent_value_prefix;
516
+ float true_reward = node->reward;
517
+
518
+ // TODO(pu): why in muzero-general is - node.value
519
+ min_max_stats.update(true_reward + discount_factor * -node->value());
520
+
521
+ if (node->to_play == to_play)
522
+ bootstrap_value = -true_reward + discount_factor * bootstrap_value;
523
+ else
524
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
525
+ }
526
+ }
527
+ }
528
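In the play-with-bot branch (to_play == -1) the leaf value is folded back along the path as a plain discounted return; below is a standalone sketch of that recursion with hypothetical rewards and discount.

#include <cstdio>
#include <vector>

int main() {
    std::vector<float> rewards = {0.0f, 1.0f, 0.5f};  // per-node rewards along the path (hypothetical)
    float bootstrap_value = 2.0f;                     // leaf value from the network
    float discount = 0.99f;

    // Walk the path backwards, mirroring the to_play == -1 branch of cbackpropagate.
    for (int i = (int)rewards.size() - 1; i >= 0; --i) {
        // In the real code: node->value_sum += bootstrap_value; node->visit_count += 1;
        bootstrap_value = rewards[i] + discount * bootstrap_value;
        std::printf("bootstrap value after node %d: %f\n", i, bootstrap_value);
    }
    return 0;
}

In self-play mode the same recursion additionally flips the sign of the credited value and reward depending on which player the node belongs to, as in the second branch above.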
+
529
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch)
530
+ {
531
+ /*
532
+ Overview:
533
+ Expand the nodes along the search path and update the infos.
534
+ Arguments:
535
+ - current_latent_state_index: The index of latent state of the leaf node in the search path.
536
+ - discount_factor: the discount factor of reward.
537
+ - value_prefixs: the value prefixs of nodes along the search path.
538
+ - values: the values to propagate along the search path.
539
+ - policies: the policy logits of nodes along the search path.
540
+ - min_max_stats: a tool used to min-max normalize the q value.
541
+ - results: the search results.
542
+ - to_play_batch: the batch of which player is playing on this node.
543
+ */
544
+ for (int i = 0; i < results.num; ++i)
545
+ {
546
+ results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]);
547
+ cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
548
+ }
549
+ }
550
+
551
+ int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
552
+ {
553
+ /*
554
+ Overview:
555
+ Select the child node of the roots according to ucb scores.
556
+ Arguments:
557
+ - root: the roots to select the child node.
558
+ - min_max_stats: a tool used to min-max normalize the score.
559
+ - pb_c_base: constants c2 in muzero.
560
+ - pb_c_init: constants c1 in muzero.
561
+ - disount_factor: the discount factor of reward.
562
+ - mean_q: the mean q value of the parent node.
563
+ - players: the number of players.
564
+ Outputs:
565
+ - action: the action to select.
566
+ */
567
+ float max_score = FLOAT_MIN;
568
+ const float epsilon = 0.000001;
569
+ std::vector<int> max_index_lst;
570
+ for (auto a : root->legal_actions)
571
+ {
572
+
573
+ CNode *child = root->get_child(a);
574
+ float temp_score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);
575
+
576
+ if (max_score < temp_score)
577
+ {
578
+ max_score = temp_score;
579
+
580
+ max_index_lst.clear();
581
+ max_index_lst.push_back(a);
582
+ }
583
+ else if (temp_score >= max_score - epsilon)
584
+ {
585
+ max_index_lst.push_back(a);
586
+ }
587
+ }
588
+
589
+ int action = 0;
590
+ if (max_index_lst.size() > 0)
591
+ {
592
+ int rand_index = rand() % max_index_lst.size();
593
+ action = max_index_lst[rand_index];
594
+ }
595
+ return action;
596
+ }
597
+
598
+ float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
599
+ {
600
+ /*
601
+ Overview:
602
+ Compute the ucb score of the child.
603
+ Arguments:
604
+ - child: the child node to compute ucb score.
605
+ - min_max_stats: a tool used to min-max normalize the score.
606
+ - mean_q: the mean q value of the parent node.
607
+ - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
608
+ - pb_c_base: constant c2 in muzero.
609
+ - pb_c_init: constant c1 in muzero.
610
+ - discount_factor: the discount factor of reward.
611
+ - players: the number of players.
612
+ Outputs:
613
+ - ucb_value: the ucb score of the child.
614
+ */
615
+ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0;
616
+ pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
617
+ pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));
618
+
619
+ prior_score = pb_c * child->prior;
620
+ if (child->visit_count == 0)
621
+ {
622
+ value_score = parent_mean_q;
623
+ }
624
+ else
625
+ {
626
+ float true_reward = child->reward;
627
+ if (players == 1)
628
+ value_score = true_reward + discount_factor * child->value();
629
+ else if (players == 2)
630
+ value_score = true_reward + discount_factor * (-child->value());
631
+ }
632
+
633
+ value_score = min_max_stats.normalize(value_score);
634
+
635
+ if (value_score < 0)
636
+ value_score = 0;
637
+ if (value_score > 1)
638
+ value_score = 1;
639
+
640
+ float ucb_value = prior_score + value_score;
641
+ return ucb_value;
642
+ }
643
+
644
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
645
+ {
646
+ /*
647
+ Overview:
648
+ Search node path from the roots.
649
+ Arguments:
650
+ - roots: the roots that search from.
651
+ - pb_c_base: constant c2 in muzero.
652
+ - pb_c_init: constant c1 in muzero.
653
+ - discount_factor: the discount factor of reward.
654
+ - min_max_stats: a tool used to min-max normalize the score.
655
+ - results: the search results.
656
+ - virtual_to_play_batch: the batch of which player is playing on this node.
657
+ */
658
+ // set seed
659
+ get_time_and_set_rand_seed();
660
+
661
+ int last_action = -1;
662
+ float parent_q = 0.0;
663
+ results.search_lens = std::vector<int>();
664
+
665
+ int players = 0;
666
+ int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // -1 (play-with-bot mode) or 1/2 (self-play mode)
667
+ if (largest_element == -1)
668
+ players = 1;
669
+ else
670
+ players = 2;
671
+
672
+ for (int i = 0; i < results.num; ++i)
673
+ {
674
+ CNode *node = &(roots->roots[i]);
675
+ int is_root = 1;
676
+ int search_len = 0;
677
+ results.search_paths[i].push_back(node);
678
+
679
+ while (node->expanded())
680
+ {
681
+ float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
682
+ is_root = 0;
683
+ parent_q = mean_q;
684
+
685
+ int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players);
686
+ if (players > 1)
687
+ {
688
+ assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
689
+ if (virtual_to_play_batch[i] == 1)
690
+ virtual_to_play_batch[i] = 2;
691
+ else
692
+ virtual_to_play_batch[i] = 1;
693
+ }
694
+
695
+ node->best_action = action;
696
+ // next
697
+ node = node->get_child(action);
698
+ last_action = action;
699
+ results.search_paths[i].push_back(node);
700
+ search_len += 1;
701
+ }
702
+
703
+ CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];
704
+
705
+ results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
706
+ results.latent_state_index_in_batch.push_back(parent->batch_index);
707
+
708
+ results.last_actions.push_back(last_action);
709
+ results.search_lens.push_back(search_len);
710
+ results.nodes.push_back(node);
711
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
712
+ }
713
+ }
714
+
715
+ }
LightZero/lzero/mcts/ctree/ctree_muzero/lib/cnode.h ADDED
@@ -0,0 +1,91 @@
1
+ // C++11
2
+
3
+ #ifndef CNODE_H
4
+ #define CNODE_H
5
+
6
+ #include "./../common_lib/cminimax.h"
7
+ #include <math.h>
8
+ #include <vector>
9
+ #include <stack>
10
+ #include <stdlib.h>
11
+ #include <time.h>
12
+ #include <cmath>
13
+ #include <sys/timeb.h>
14
+ #include <time.h>
15
+ #include <map>
16
+
17
+ const int DEBUG_MODE = 0;
18
+
19
+ namespace tree {
20
+
21
+ class CNode {
22
+ public:
23
+ int visit_count, to_play, current_latent_state_index, batch_index, best_action;
24
+ float reward, prior, value_sum;
25
+ std::vector<int> children_index;
26
+ std::map<int, CNode> children;
27
+
28
+ std::vector<int> legal_actions;
29
+
30
+ CNode();
31
+ CNode(float prior, std::vector<int> &legal_actions);
32
+ ~CNode();
33
+
34
+ void expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector<float> &policy_logits);
35
+ void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
36
+ float compute_mean_q(int isRoot, float parent_q, float discount_factor);
37
+ void print_out();
38
+
39
+ int expanded();
40
+
41
+ float value();
42
+
43
+ std::vector<int> get_trajectory();
44
+ std::vector<int> get_children_distribution();
45
+ CNode* get_child(int action);
46
+ };
47
+
48
+ class CRoots{
49
+ public:
50
+ int root_num;
51
+ std::vector<CNode> roots;
52
+ std::vector<std::vector<int> > legal_actions_list;
53
+
54
+ CRoots();
55
+ CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list);
56
+ ~CRoots();
57
+
58
+ void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
59
+ void prepare_no_noise(const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
60
+ void clear();
61
+ std::vector<std::vector<int> > get_trajectories();
62
+ std::vector<std::vector<int> > get_distributions();
63
+ std::vector<float> get_values();
64
+
65
+ };
66
+
67
+ class CSearchResults{
68
+ public:
69
+ int num;
70
+ std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens;
71
+ std::vector<int> virtual_to_play_batchs;
72
+ std::vector<CNode*> nodes;
73
+ std::vector<std::vector<CNode*> > search_paths;
74
+
75
+ CSearchResults();
76
+ CSearchResults(int num);
77
+ ~CSearchResults();
78
+
79
+ };
80
+
81
+
82
+ //*********************************************************
83
+ void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
84
+ void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
85
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch);
86
+ int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players);
87
+ float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players);
88
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
89
+ }
90
+
91
+ #endif
LightZero/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp ADDED
@@ -0,0 +1,1189 @@
1
+ // C++11
2
+
3
+ #include <iostream>
4
+ #include "cnode.h"
5
+ #include <algorithm>
6
+ #include <map>
7
+ #include <random>
8
+ #include <chrono>
9
+ #include <iostream>
10
+ #include <vector>
11
+ #include <stack>
12
+ #include <math.h>
13
+
14
+ #include <stdlib.h>
15
+ #include <time.h>
16
+ #include <cmath>
17
+ #include <sys/timeb.h>
18
+ #include <time.h>
19
+ #include <cassert>
20
+
21
+ #ifdef _WIN32
22
+ #include "..\..\common_lib\utils.cpp"
23
+ #else
24
+ #include "../../common_lib/utils.cpp"
25
+ #endif
26
+
27
+
28
+
29
+ template <class T>
30
+ size_t hash_combine(std::size_t &seed, const T &val)
31
+ {
32
+ /*
33
+ Overview:
34
+ Combines a hash value with a new value using a bitwise XOR and bit shifts (boost-style hash combining).
35
+ This function is used to create a hash value for multiple values.
36
+ Arguments:
37
+ - seed: the current hash value to be combined with.
38
+ - val: the new value to be hashed and combined with the seed.
39
+ */
40
+ std::hash<T> hasher; // Create a hash object for the new value.
41
+ seed ^= hasher(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2); // Combine the new hash value with the seed.
42
+ return seed;
43
+ }
44
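A standalone usage sketch of the boost-style combiner above, mirroring how CAction hashes the string form of each action dimension and folds them together; the two dimension values are hypothetical.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <string>

// Same combiner as in the file above.
template <class T>
size_t hash_combine(std::size_t &seed, const T &val) {
    std::hash<T> hasher;
    seed ^= hasher(val) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    return seed;
}

int main() {
    std::size_t seed = std::hash<std::string>()(std::to_string(0.5f));  // first dimension
    hash_combine(seed, std::to_string(-0.25f));                          // fold in the second
    std::printf("combined hash = %zu\n", seed);
    return 0;
}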
+
45
+ // Sort by the value of second in descending order.
46
+ bool cmp(std::pair<int, double> x, std::pair<int, double> y)
47
+ {
48
+ return x.second > y.second;
49
+ }
50
+
51
+ namespace tree
52
+ {
53
+ //*********************************************************
54
+
55
+ CAction::CAction()
56
+ {
57
+ /*
58
+ Overview:
59
+ Initialization of CAction. Default constructor.
60
+ */
61
+ this->is_root_action = 0;
62
+ }
63
+
64
+ CAction::CAction(std::vector<float> value, int is_root_action)
65
+ {
66
+ /*
67
+ Overview:
68
+ Initialization of CAction with value and is_root_action. Parameterized constructor.
69
+ Arguments:
70
+ - value: a multi-dimensional action.
71
+ - is_root_action: whether value is a root node.
72
+ */
73
+ this->value = value;
74
+ this->is_root_action = is_root_action;
75
+ }
76
+
77
+ CAction::~CAction() {} // Destructor.
78
+
79
+ std::vector<size_t> CAction::get_hash(void)
80
+ {
81
+ /*
82
+ Overview:
83
+ get a hash value for each dimension in the multi-dimensional action.
84
+ */
85
+ std::vector<size_t> hash;
86
+ for (int i = 0; i < this->value.size(); ++i)
87
+ {
88
+ std::size_t hash_i = std::hash<std::string>()(std::to_string(this->value[i]));
89
+ hash.push_back(hash_i);
90
+ }
91
+ return hash;
92
+ }
93
+ size_t CAction::get_combined_hash(void)
94
+ {
95
+ /*
96
+ Overview:
97
+ get the final combined hash value from the hash values of each dimension of the multi-dimensional action.
98
+ */
99
+ std::vector<size_t> hash = this->get_hash();
100
+ size_t combined_hash = hash[0];
101
+
102
+ if (hash.size() >= 1)
103
+ {
104
+ for (int i = 1; i < hash.size(); ++i)
105
+ {
106
+ combined_hash = hash_combine(combined_hash, hash[i]);
107
+ }
108
+ }
109
+
110
+ return combined_hash;
111
+ }
112
+
113
+ //*********************************************************
114
+
115
+ CSearchResults::CSearchResults()
116
+ {
117
+ /*
118
+ Overview:
119
+ Initialization of CSearchResults, the default result number is set to 0.
120
+ */
121
+ this->num = 0;
122
+ }
123
+
124
+ CSearchResults::CSearchResults(int num)
125
+ {
126
+ /*
127
+ Overview:
128
+ Initialization of CSearchResults with result number.
129
+ */
130
+ this->num = num;
131
+ for (int i = 0; i < num; ++i)
132
+ {
133
+ this->search_paths.push_back(std::vector<CNode *>());
134
+ }
135
+ }
136
+
137
+ CSearchResults::~CSearchResults() {}
138
+
139
+ //*********************************************************
140
+
141
+ CNode::CNode()
142
+ {
143
+ /*
144
+ Overview:
145
+ Initialization of CNode.
146
+ */
147
+ this->prior = 0;
148
+ this->action_space_size = 9;
149
+ this->num_of_sampled_actions = 20;
150
+ this->continuous_action_space = false;
151
+
152
+ this->is_reset = 0;
153
+ this->visit_count = 0;
154
+ this->value_sum = 0;
155
+ CAction best_action;
156
+ this->best_action = best_action;
157
+
158
+ this->to_play = 0;
159
+ this->value_prefix = 0.0;
160
+ this->parent_value_prefix = 0.0;
161
+ }
162
+
163
+ CNode::CNode(float prior, std::vector<CAction> &legal_actions, int action_space_size, int num_of_sampled_actions, bool continuous_action_space)
164
+ {
165
+ /*
166
+ Overview:
167
+ Initialization of CNode with prior, legal actions, action_space_size, num_of_sampled_actions, continuous_action_space.
168
+ Arguments:
169
+ - prior: the prior value of this node.
170
+ - legal_actions: a vector of legal actions of this node.
171
+ - action_space_size: the size of action space of the current env.
172
+ - num_of_sampled_actions: the number of sampled actions, i.e. K in the Sampled MuZero paper.
173
+ - continuous_action_space: whether the action space is continuous in the current env.
174
+ */
175
+ this->prior = prior;
176
+ this->legal_actions = legal_actions;
177
+
178
+ this->action_space_size = action_space_size;
179
+ this->num_of_sampled_actions = num_of_sampled_actions;
180
+ this->continuous_action_space = continuous_action_space;
181
+ this->is_reset = 0;
182
+ this->visit_count = 0;
183
+ this->value_sum = 0;
184
+ this->to_play = 0;
185
+ this->value_prefix = 0.0;
186
+ this->parent_value_prefix = 0.0;
187
+ this->current_latent_state_index = -1;
188
+ this->batch_index = -1;
189
+ }
190
+
191
+ CNode::~CNode() {}
192
+
193
+
194
+ void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits)
195
+ {
196
+ /*
197
+ Overview:
198
+ Expand the child nodes of the current node.
199
+ Arguments:
200
+ - to_play: which player to play the game in the current node.
201
+ - current_latent_state_index: the x/first index of hidden state vector of the current node, i.e. the search depth.
202
+ - batch_index: the y/second index of hidden state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``.
203
+ - value_prefix: the value prefix of the current node.
204
+ - policy_logits: the logit of the child nodes.
205
+ */
206
+ this->to_play = to_play;
207
+ this->current_latent_state_index = current_latent_state_index;
208
+ this->batch_index = batch_index;
209
+ this->value_prefix = value_prefix;
210
+ int action_num = policy_logits.size();
211
+
212
+ #ifdef _WIN32
213
+ // Create a dynamic array (variable-length arrays are not supported on MSVC).
214
+ float* policy = new float[action_num];
215
+ #else
216
+ float policy[action_num];
217
+ #endif
218
+
219
+ std::vector<int> all_actions;
220
+ for (int i = 0; i < action_num; ++i)
221
+ {
222
+ all_actions.push_back(i);
223
+ }
224
+ std::vector<std::vector<float> > sampled_actions_after_tanh;
225
+ std::vector<float> sampled_actions_log_probs_after_tanh;
226
+
227
+ std::vector<int> sampled_actions;
228
+ std::vector<float> sampled_actions_log_probs;
229
+ std::vector<float> sampled_actions_probs;
230
+ std::vector<float> probs;
231
+
232
+ /*
233
+ Overview:
234
+ When the current env has a continuous action space, sample K actions from the continuous Gaussian distribution policy.
+ When the current env has a discrete action space, sample K actions from the discrete categorical distribution policy.
236
+
237
+ */
238
+ if (this->continuous_action_space == true)
239
+ {
240
+ // continuous action space for sampled algo..
241
+ this->action_space_size = policy_logits.size() / 2;
242
+ std::vector<float> mu;
243
+ std::vector<float> sigma;
244
+ for (int i = 0; i < this->action_space_size; ++i)
245
+ {
246
+ mu.push_back(policy_logits[i]);
247
+ sigma.push_back(policy_logits[this->action_space_size + i]);
248
+ }
249
+
250
+ // The number of nanoseconds elapsed since the epoch (00:00:00 UTC on January 1, 1970); the unsigned type will truncate this value.
251
+ unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
252
+
253
+ // SAC-like tanh squashing; please refer to the paper https://arxiv.org/abs/1801.01290.
254
+ std::vector<std::vector<float> > sampled_actions_before_tanh;
255
+
256
+ float sampled_action_one_dim_before_tanh;
257
+ std::vector<float> sampled_actions_log_probs_before_tanh;
258
+
259
+ std::default_random_engine generator(seed);
260
+ for (int i = 0; i < this->num_of_sampled_actions; ++i)
261
+ {
262
+ float sampled_action_prob_before_tanh = 1;
263
+ // TODO(pu): why here
264
+ std::vector<float> sampled_action_before_tanh;
265
+ std::vector<float> sampled_action_after_tanh;
266
+ std::vector<float> y;
267
+
268
+ for (int j = 0; j < this->action_space_size; ++j)
269
+ {
270
+ std::normal_distribution<float> distribution(mu[j], sigma[j]);
271
+ sampled_action_one_dim_before_tanh = distribution(generator);
272
+ // refer to python normal log_prob method
273
+ sampled_action_prob_before_tanh *= exp(-pow((sampled_action_one_dim_before_tanh - mu[j]), 2) / (2 * pow(sigma[j], 2)) - log(sigma[j]) - log(sqrt(2 * M_PI)));
274
+ sampled_action_before_tanh.push_back(sampled_action_one_dim_before_tanh);
275
+ sampled_action_after_tanh.push_back(tanh(sampled_action_one_dim_before_tanh));
276
+ y.push_back(1 - pow(tanh(sampled_action_one_dim_before_tanh), 2) + 1e-6);
277
+ }
278
+ sampled_actions_before_tanh.push_back(sampled_action_before_tanh);
279
+ sampled_actions_after_tanh.push_back(sampled_action_after_tanh);
280
+ sampled_actions_log_probs_before_tanh.push_back(log(sampled_action_prob_before_tanh));
281
+ float y_sum = std::accumulate(y.begin(), y.end(), 0.);
282
+ sampled_actions_log_probs_after_tanh.push_back(log(sampled_action_prob_before_tanh) - log(y_sum));
283
+ }
284
+ }
285
+ else
286
+ {
287
+ // discrete action space for sampled algo..
288
+
289
+ //========================================================
290
+ // python code
291
+ //========================================================
292
+ // if self.legal_actions is not None:
293
+ // # first use the self.legal_actions to exclude the illegal actions
294
+ // policy_tmp = [0. for _ in range(self.action_space_size)]
295
+ // for index, legal_action in enumerate(self.legal_actions):
296
+ // policy_tmp[legal_action] = policy_logits[index]
297
+ // policy_logits = policy_tmp
298
+ // # then empty the self.legal_actions
299
+ // self.legal_actions = []
300
+ // then empty the self.legal_actions
301
+ // prob = torch.softmax(torch.tensor(policy_logits), dim=-1)
302
+ // sampled_actions = torch.multinomial(prob, self.num_of_sampled_actions, replacement=False)
303
+
304
+ //========================================================
305
+ // TODO(pu): legal actions
306
+ //========================================================
307
+ // std::vector<float> policy_tmp;
308
+ // for (int i = 0; i < this->action_space_size; ++i)
309
+ // {
310
+ // policy_tmp.push_back(0.);
311
+ // }
312
+ // for (int i = 0; i < this->legal_actions.size(); ++i)
313
+ // {
314
+ // policy_tmp[this->legal_actions[i].value] = policy_logits[i];
315
+ // }
316
+ // for (int i = 0; i < this->action_space_size; ++i)
317
+ // {
318
+ // policy_logits[i] = policy_tmp[i];
319
+ // }
320
+ // std::cout << "position 3" << std::endl;
321
+
322
+ // python code: legal_actions = []
323
+ std::vector<CAction> legal_actions;
324
+
325
+ // python code: probs = softmax(policy_logits)
326
+ float logits_exp_sum = 0;
327
+ for (int i = 0; i < policy_logits.size(); ++i)
328
+ {
329
+ logits_exp_sum += exp(policy_logits[i]);
330
+ }
331
+ for (int i = 0; i < policy_logits.size(); ++i)
332
+ {
333
+ probs.push_back(exp(policy_logits[i]) / (logits_exp_sum + 1e-6));
334
+ }
335
+
336
+ unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
337
+
338
+ // cout << "sampled_action[0]:" << sampled_action[0] <<endl;
339
+
340
+ // std::vector<int> sampled_actions;
341
+ // std::vector<float> sampled_actions_log_probs;
342
+ // std::vector<float> sampled_actions_probs;
343
+ std::default_random_engine generator(seed);
344
+
345
+ // Sampling with replacement.
346
+ // for (int i = 0; i < num_of_sampled_actions; ++i)
347
+ // {
348
+ // float sampled_action_prob = 1;
349
+ // int sampled_action;
350
+
351
+ // std::discrete_distribution<float> distribution(probs.begin(), probs.end());
352
+
353
+ // // for (float x:distribution.probabilities()) std::cout << x << " ";
354
+ // sampled_action = distribution(generator);
355
+ // // std::cout << "sampled_action: " << sampled_action << std::endl;
356
+
357
+ // sampled_actions.push_back(sampled_action);
358
+ // sampled_actions_probs.push_back(probs[sampled_action]);
359
+ // std::cout << "sampled_actions_probs" << '[' << i << ']' << sampled_actions_probs[i] << std::endl;
360
+
361
+ // sampled_actions_log_probs.push_back(log(probs[sampled_action]));
362
+ // std::cout << "sampled_actions_log_probs" << '[' << i << ']' << sampled_actions_log_probs[i] << std::endl;
363
+ // }
364
+
365
+ // The legal_actions of each node should be a fixed discrete set, so sampling without replacement is used.
+ // std::cout << "position uniform_distribution init" << std::endl;
+ std::uniform_real_distribution<double> uniform_distribution(0.0, 1.0); // uniform distribution
368
+ // std::cout << "position uniform_distribution done" << std::endl;
369
+ std::vector<double> disturbed_probs;
370
+ std::vector<std::pair<int, double> > disc_action_with_probs;
371
+
372
+ // Use the reciprocal of the probability value as the exponent and a random number sampled from a uniform distribution as the base:
373
+ // Equivalent to adding a uniform random disturbance to the original probability value.
374
+ for (auto prob : probs)
375
+ {
376
+ disturbed_probs.push_back(std::pow(uniform_distribution(generator), 1. / prob));
377
+ }
378
+
379
+ // Sort from large to small according to the probability value after the disturbance:
380
+ // After sorting, the first vector is the index, and the second vector is the probability value after perturbation sorted from large to small.
381
+ for (size_t iter = 0; iter < disturbed_probs.size(); iter++)
382
+ {
383
+
384
+ #ifdef __GNUC__
385
+ // Use push_back for GCC
386
+ disc_action_with_probs.push_back(std::make_pair(iter, disturbed_probs[iter]));
387
+ #else
388
+ // Use emplace_back for other compilers
389
+ disc_action_with_probs.emplace_back(std::make_pair(iter, disturbed_probs[iter]));
390
+ #endif
391
+ }
392
+
393
+ std::sort(disc_action_with_probs.begin(), disc_action_with_probs.end(), cmp);
394
+
395
+ // Take the first ``num_of_sampled_actions`` actions.
396
+ for (int k = 0; k < num_of_sampled_actions; ++k)
397
+ {
398
+ sampled_actions.push_back(disc_action_with_probs[k].first);
399
+ // disc_action_with_probs[k].second is disturbed_probs
400
+ // sampled_actions_probs.push_back(disc_action_with_probs[k].second);
401
+ sampled_actions_probs.push_back(probs[disc_action_with_probs[k].first]);
402
+
403
+ // TODO(pu): logging
404
+ // std::cout << "sampled_actions[k]: " << sampled_actions[k] << std::endl;
405
+ // std::cout << "sampled_actions_probs[k]: " << sampled_actions_probs[k] << std::endl;
406
+ }
407
+
408
+ // TODO(pu): fixed k, only for debugging
409
+ // Take the first ``num_of_sampled_actions`` actions: k=0,1,...,K-1
410
+ // for (int k = 0; k < num_of_sampled_actions; ++k)
411
+ // {
412
+ // sampled_actions.push_back(k);
413
+ // // disc_action_with_probs[k].second is disturbed_probs
414
+ // // sampled_actions_probs.push_back(disc_action_with_probs[k].second);
415
+ // sampled_actions_probs.push_back(probs[k]);
416
+ // }
417
+
418
+ disturbed_probs.clear(); // Empty the collection to prepare for the next sampling.
419
+ disc_action_with_probs.clear(); // Empty the collection to prepare for the next sampling.
420
+ }
421
+
422
+ float prior;
423
+ for (int i = 0; i < this->num_of_sampled_actions; ++i)
424
+ {
425
+
426
+ if (this->continuous_action_space == true)
427
+ {
428
+ CAction action = CAction(sampled_actions_after_tanh[i], 0);
429
+ std::vector<CAction> legal_actions;
430
+ this->children[action.get_combined_hash()] = CNode(sampled_actions_log_probs_after_tanh[i], legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space); // only for muzero/efficient zero, not support alphazero
431
+ this->legal_actions.push_back(action);
432
+ }
433
+ else
434
+ {
435
+ std::vector<float> sampled_action_tmp;
436
+ for (size_t iter = 0; iter < 1; iter++)
437
+ {
438
+ sampled_action_tmp.push_back(float(sampled_actions[i]));
439
+ }
440
+ CAction action = CAction(sampled_action_tmp, 0);
441
+ std::vector<CAction> legal_actions;
442
+ this->children[action.get_combined_hash()] = CNode(sampled_actions_probs[i], legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space); // only for muzero/efficient zero, not support alphazero
443
+ this->legal_actions.push_back(action);
444
+ }
445
+ }
446
+
447
+ #ifdef _WIN32
448
+ // Free the array memory.
449
+ delete[] policy;
450
+ #else
451
+ #endif
452
+ }
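// ---------------------------------------------------------------------
// A minimal standalone sketch of the sampling-without-replacement trick
// used in the discrete branch of expand() above: draw u ~ Uniform(0,1)
// for each action, score it as u^(1/p), and keep the K largest scores.
// This is the standard weighted-reservoir / Gumbel-top-k style trick;
// the probabilities and K below are made-up example values.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

int main()
{
    std::vector<double> probs = {0.5, 0.2, 0.15, 0.1, 0.05};  // softmax output (example)
    const int K = 3;
    std::mt19937 gen(42);
    std::uniform_real_distribution<double> uni(0.0, 1.0);

    // Perturb each probability: a larger p tends to produce a larger key.
    std::vector<std::pair<int, double> > scored;
    for (size_t a = 0; a < probs.size(); ++a)
        scored.push_back(std::make_pair(static_cast<int>(a), std::pow(uni(gen), 1.0 / probs[a])));

    // Sort keys from large to small and take the first K distinct actions.
    std::sort(scored.begin(), scored.end(),
              [](const std::pair<int, double> &x, const std::pair<int, double> &y)
              { return x.second > y.second; });

    for (int k = 0; k < K; ++k)
        std::cout << "sampled action " << scored[k].first
                  << " (prob " << probs[scored[k].first] << ")" << std::endl;
    return 0;
}
// ---------------------------------------------------------------------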
453
+
454
+ void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
455
+ {
456
+ /*
457
+ Overview:
458
+ Add a noise to the prior of the child nodes.
459
+ Arguments:
460
+ - exploration_fraction: the fraction to add noise.
461
+ - noises: the vector of noises added to each child node.
462
+ */
463
+ float noise, prior;
464
+ for (int i = 0; i < this->num_of_sampled_actions; ++i)
465
+ {
466
+
467
+ noise = noises[i];
468
+ CNode *child = this->get_child(this->legal_actions[i]);
469
+ prior = child->prior;
470
+ if (this->continuous_action_space == true)
471
+ {
472
+ // if prior is log_prob
473
+ child->prior = log(exp(prior) * (1 - exploration_fraction) + noise * exploration_fraction + 1e-6);
474
+ }
475
+ else
476
+ {
477
+ // if prior is prob
478
+ child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
479
+ }
480
+ }
481
+ }
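// ---------------------------------------------------------------------
// A minimal sketch of the two noise-mixing rules above (illustrative
// numbers only): when the stored prior is a probability it is mixed
// linearly with the noise; when it is a log-probability the mixing is
// done in probability space and the log is taken again.

#include <cmath>
#include <iostream>

int main()
{
    float frac = 0.25f, noise = 0.3f;
    float prob_prior = 0.6f;           // discrete case: prior stored as a probability
    float log_prior = std::log(0.6f);  // continuous case: prior stored as a log-probability
    float mixed_prob = prob_prior * (1 - frac) + noise * frac;
    float mixed_log = std::log(std::exp(log_prior) * (1 - frac) + noise * frac + 1e-6f);
    std::cout << mixed_prob << " " << mixed_log << std::endl;
    return 0;
}
// ---------------------------------------------------------------------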
482
+
483
+ float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
484
+ {
485
+ /*
486
+ Overview:
487
+ Compute the mean q value of the current node.
488
+ Arguments:
489
+ - isRoot: whether the current node is a root node.
490
+ - parent_q: the q value of the parent node.
491
+ - discount_factor: the discount_factor of reward.
492
+ */
493
+ float total_unsigned_q = 0.0;
494
+ int total_visits = 0;
495
+ float parent_value_prefix = this->value_prefix;
496
+ for (auto a : this->legal_actions)
497
+ {
498
+ CNode *child = this->get_child(a);
499
+ if (child->visit_count > 0)
500
+ {
501
+ float true_reward = child->value_prefix - parent_value_prefix;
502
+ if (this->is_reset == 1)
503
+ {
504
+ true_reward = child->value_prefix;
505
+ }
506
+ float qsa = true_reward + discount_factor * child->value();
507
+ total_unsigned_q += qsa;
508
+ total_visits += 1;
509
+ }
510
+ }
511
+
512
+ float mean_q = 0.0;
513
+ if (isRoot && total_visits > 0)
514
+ {
515
+ mean_q = (total_unsigned_q) / (total_visits);
516
+ }
517
+ else
518
+ {
519
+ mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
520
+ }
521
+ return mean_q;
522
+ }
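// ---------------------------------------------------------------------
// A minimal sketch of the mean-Q estimate above: average r + gamma * V
// over the visited children only, and mix in the parent's Q when the
// node is not a root. All numbers are illustrative.

#include <iostream>
#include <vector>

int main()
{
    std::vector<float> child_rewards = {0.2f, -0.1f};  // visited children only
    std::vector<float> child_values = {1.0f, 0.5f};
    float gamma = 0.997f, parent_q = 0.4f;
    bool is_root = false;
    float total = 0.0f;
    for (size_t i = 0; i < child_rewards.size(); ++i)
        total += child_rewards[i] + gamma * child_values[i];
    float mean_q = is_root ? total / child_rewards.size()
                           : (parent_q + total) / (child_rewards.size() + 1);
    std::cout << "mean_q = " << mean_q << std::endl;
    return 0;
}
// ---------------------------------------------------------------------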
523
+
524
+ void CNode::print_out()
525
+ {
526
+ return;
527
+ }
528
+
529
+ int CNode::expanded()
530
+ {
531
+ /*
532
+ Overview:
533
+ Return whether the current node is expanded.
534
+ */
535
+ return this->children.size() > 0;
536
+ }
537
+
538
+ float CNode::value()
539
+ {
540
+ /*
541
+ Overview:
542
+ Return the estimated value of the current node, i.e. value_sum / visit_count.
543
+ */
544
+ float true_value = 0.0;
545
+ if (this->visit_count == 0)
546
+ {
547
+ return true_value;
548
+ }
549
+ else
550
+ {
551
+ true_value = this->value_sum / this->visit_count;
552
+ return true_value;
553
+ }
554
+ }
555
+
556
+ std::vector<std::vector<float> > CNode::get_trajectory()
557
+ {
558
+ /*
559
+ Overview:
560
+ Find the current best trajectory starting from the current node.
561
+ Outputs:
562
+ - traj: a vector of node index, which is the current best trajectory from this node.
563
+ */
564
+ std::vector<CAction> traj;
565
+
566
+ CNode *node = this;
567
+ CAction best_action = node->best_action;
568
+ while (best_action.is_root_action != 1)
569
+ {
570
+ traj.push_back(best_action);
571
+ node = node->get_child(best_action);
572
+ best_action = node->best_action;
573
+ }
574
+
575
+ std::vector<std::vector<float> > traj_return;
576
+ for (int i = 0; i < traj.size(); ++i)
577
+ {
578
+ traj_return.push_back(traj[i].value);
579
+ }
580
+ return traj_return;
581
+ }
582
+
583
+ std::vector<int> CNode::get_children_distribution()
584
+ {
585
+ /*
586
+ Overview:
587
+ Get the distribution of child nodes in the format of visit_count.
588
+ Outputs:
589
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
590
+ */
591
+ std::vector<int> distribution;
592
+ if (this->expanded())
593
+ {
594
+ for (auto a : this->legal_actions)
595
+ {
596
+ CNode *child = this->get_child(a);
597
+ distribution.push_back(child->visit_count);
598
+ }
599
+ }
600
+ return distribution;
601
+ }
602
+
603
+ CNode *CNode::get_child(CAction action)
604
+ {
605
+ /*
606
+ Overview:
607
+ Get the child node corresponding to the input action.
608
+ Arguments:
609
+ - action: the action to get child.
610
+ */
611
+ return &(this->children[action.get_combined_hash()]);
612
+ // TODO(pu): no hash
613
+ // return &(this->children[action]);
614
+ // return &(this->children[action.value[0]]);
615
+ }
616
+
617
+ //*********************************************************
618
+
619
+ CRoots::CRoots()
620
+ {
621
+ this->root_num = 0;
622
+ this->num_of_sampled_actions = 20;
623
+ }
624
+
625
+ CRoots::CRoots(int root_num, std::vector<std::vector<float> > legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space)
626
+ {
627
+ /*
628
+ Overview:
629
+ Initialization of CNode with root_num, legal_actions_list, action_space_size, num_of_sampled_actions, continuous_action_space.
630
+ Arguments:
631
+ - root_num: the number of the current root.
632
+ - legal_action_list: the vector of the legal action of this root.
633
+ - action_space_size: the size of action space of the current env.
634
+ - num_of_sampled_actions: the number of sampled actions, i.e. K in the Sampled MuZero paper.
+ - continuous_action_space: whether the action space of the current env is continuous.
636
+ */
637
+ this->root_num = root_num;
638
+ this->legal_actions_list = legal_actions_list;
639
+ this->continuous_action_space = continuous_action_space;
640
+
641
+ // sampled related core code
642
+ this->num_of_sampled_actions = num_of_sampled_actions;
643
+ this->action_space_size = action_space_size;
644
+
645
+ for (int i = 0; i < this->root_num; ++i)
646
+ {
647
+ if (this->continuous_action_space == true and this->legal_actions_list[0][0] == -1)
648
+ {
649
+ // continuous action space
650
+ std::vector<CAction> legal_actions;
651
+ this->roots.push_back(CNode(0, legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space));
652
+ }
653
+ else if (this->continuous_action_space == false or this->legal_actions_list[0][0] == -1)
654
+ {
655
+ // sampled
656
+ // discrete action space without action mask
657
+ std::vector<CAction> legal_actions;
658
+ this->roots.push_back(CNode(0, legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space));
659
+ }
660
+
661
+ else
662
+ {
663
+ // TODO(pu): discrete action space
664
+ std::vector<CAction> c_legal_actions;
665
+ for (int i = 0; i < this->legal_actions_list.size(); ++i)
666
+ {
667
+ CAction c_legal_action = CAction(legal_actions_list[i], 0);
668
+ c_legal_actions.push_back(c_legal_action);
669
+ }
670
+ this->roots.push_back(CNode(0, c_legal_actions, this->action_space_size, this->num_of_sampled_actions, this->continuous_action_space));
671
+ }
672
+ }
673
+ }
674
+
675
+ CRoots::~CRoots() {}
676
+
677
+ void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
678
+ {
679
+ /*
680
+ Overview:
681
+ Expand the roots and add noises.
682
+ Arguments:
683
+ - root_noise_weight: the exploration fraction of roots
684
+ - noises: the vector of noise add to the roots.
685
+ - value_prefixs: the vector of value prefixs of each root.
686
+ - policies: the vector of policy logits of each root.
687
+ - to_play_batch: the vector of the player side of each root.
688
+ */
689
+
690
+ // sampled related core code
691
+ for (int i = 0; i < this->root_num; ++i)
692
+ {
693
+ this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]);
694
+ this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
695
+ this->roots[i].visit_count += 1;
696
+ }
697
+ }
698
+
699
+ void CRoots::prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
700
+ {
701
+ /*
702
+ Overview:
703
+ Expand the roots without noise.
704
+ Arguments:
705
+ - value_prefixs: the vector of value prefixs of each root.
706
+ - policies: the vector of policy logits of each root.
707
+ - to_play_batch: the vector of the player side of each root.
708
+ */
709
+ for (int i = 0; i < this->root_num; ++i)
710
+ {
711
+ this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]);
712
+
713
+ this->roots[i].visit_count += 1;
714
+ }
715
+ }
716
+
717
+ void CRoots::clear()
718
+ {
719
+ this->roots.clear();
720
+ }
721
+
722
+ std::vector<std::vector<std::vector<float> > > CRoots::get_trajectories()
723
+ {
724
+ /*
725
+ Overview:
726
+ Find the current best trajectory starting from each root.
727
+ Outputs:
728
+ - traj: a vector of node index, which is the current best trajectory from each root.
729
+ */
730
+ std::vector<std::vector<std::vector<float> > > trajs;
731
+ trajs.reserve(this->root_num);
732
+
733
+ for (int i = 0; i < this->root_num; ++i)
734
+ {
735
+ trajs.push_back(this->roots[i].get_trajectory());
736
+ }
737
+ return trajs;
738
+ }
739
+
740
+ std::vector<std::vector<int> > CRoots::get_distributions()
741
+ {
742
+ /*
743
+ Overview:
744
+ Get the children distribution of each root.
745
+ Outputs:
746
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
747
+ */
748
+ std::vector<std::vector<int> > distributions;
749
+ distributions.reserve(this->root_num);
750
+
751
+ for (int i = 0; i < this->root_num; ++i)
752
+ {
753
+ distributions.push_back(this->roots[i].get_children_distribution());
754
+ }
755
+ return distributions;
756
+ }
757
+
758
+ // sampled related core code
759
+ std::vector<std::vector<std::vector<float> > > CRoots::get_sampled_actions()
760
+ {
761
+ /*
762
+ Overview:
763
+ Get the sampled_actions of each root.
764
+ Outputs:
765
+ - python_sampled_actions: a vector of sampled_actions for each root, e.g. if the size of the original action space is 6 and K=3,
+ python_sampled_actions = [[1,3,0], [2,4,0], [5,4,1]].
767
+ */
768
+ std::vector<std::vector<CAction> > sampled_actions;
769
+ std::vector<std::vector<std::vector<float> > > python_sampled_actions;
770
+
771
+ // sampled_actions.reserve(this->root_num);
772
+
773
+ for (int i = 0; i < this->root_num; ++i)
774
+ {
775
+ std::vector<CAction> sampled_action;
776
+ sampled_action = this->roots[i].legal_actions;
777
+ std::vector<std::vector<float> > python_sampled_action;
778
+
779
+ for (int j = 0; j < this->roots[i].legal_actions.size(); ++j)
780
+ {
781
+ python_sampled_action.push_back(sampled_action[j].value);
782
+ }
783
+ python_sampled_actions.push_back(python_sampled_action);
784
+ }
785
+
786
+ return python_sampled_actions;
787
+ }
788
+
789
+ std::vector<float> CRoots::get_values()
790
+ {
791
+ /*
792
+ Overview:
793
+ Return the estimated value of each root.
794
+ */
795
+ std::vector<float> values;
796
+ for (int i = 0; i < this->root_num; ++i)
797
+ {
798
+ values.push_back(this->roots[i].value());
799
+ }
800
+ return values;
801
+ }
802
+
803
+ //*********************************************************
804
+ //
805
+ void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
806
+ {
807
+ /*
808
+ Overview:
809
+ Update the q value of the root and its child nodes.
810
+ Arguments:
811
+ - root: the root that update q value from.
812
+ - min_max_stats: a tool used to min-max normalize the q value.
813
+ - discount_factor: the discount factor of reward.
814
+ - players: the number of players.
815
+ */
816
+ std::stack<CNode *> node_stack;
817
+ node_stack.push(root);
818
+ float parent_value_prefix = 0.0;
819
+ int is_reset = 0;
820
+ while (node_stack.size() > 0)
821
+ {
822
+ CNode *node = node_stack.top();
823
+ node_stack.pop();
824
+
825
+ if (node != root)
826
+ {
827
+ // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
828
+ // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
829
+ // true_reward = node.value_prefix - (- parent_value_prefix)
830
+ float true_reward = node->value_prefix - node->parent_value_prefix;
831
+
832
+ if (is_reset == 1)
833
+ {
834
+ true_reward = node->value_prefix;
835
+ }
836
+ float qsa;
837
+ if (players == 1)
838
+ qsa = true_reward + discount_factor * node->value();
839
+ else if (players == 2)
840
+ // TODO(pu): why only the last reward multiply the discount_factor?
841
+ qsa = true_reward + discount_factor * (-1) * node->value();
842
+
843
+ min_max_stats.update(qsa);
844
+ }
845
+
846
+ for (auto a : node->legal_actions)
847
+ {
848
+ CNode *child = node->get_child(a);
849
+ if (child->expanded())
850
+ {
851
+ child->parent_value_prefix = node->value_prefix;
852
+ node_stack.push(child);
853
+ }
854
+ }
855
+
856
+ is_reset = node->is_reset;
857
+ }
858
+ }
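// ---------------------------------------------------------------------
// A minimal sketch of how the per-step reward is recovered from value
// prefixes above: r = value_prefix(child) - value_prefix(parent), unless
// the parent's prefix was reset, in which case the child's prefix is
// already the step reward. Numbers are illustrative.

#include <iostream>

int main()
{
    float parent_value_prefix = 1.5f, child_value_prefix = 2.3f;
    bool is_reset = false;
    float true_reward = is_reset ? child_value_prefix
                                 : child_value_prefix - parent_value_prefix;
    std::cout << "true_reward = " << true_reward << std::endl;  // prints 0.8
    return 0;
}
// ---------------------------------------------------------------------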
859
+
860
+ void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
861
+ {
862
+ /*
863
+ Overview:
864
+ Update the value sum and visit count of nodes along the search path.
865
+ Arguments:
866
+ - search_path: a vector of nodes on the search path.
867
+ - min_max_stats: a tool used to min-max normalize the q value.
868
+ - to_play: which player to play the game in the current node.
869
+ - value: the value to propagate along the search path.
870
+ - discount_factor: the discount factor of reward.
871
+ */
872
+ assert(to_play == -1 || to_play == 1 || to_play == 2);
873
+ if (to_play == -1)
874
+ {
875
+ // for play-with-bot-mode
876
+ float bootstrap_value = value;
877
+ int path_len = search_path.size();
878
+ for (int i = path_len - 1; i >= 0; --i)
879
+ {
880
+ CNode *node = search_path[i];
881
+ node->value_sum += bootstrap_value;
882
+ node->visit_count += 1;
883
+
884
+ float parent_value_prefix = 0.0;
885
+ int is_reset = 0;
886
+ if (i >= 1)
887
+ {
888
+ CNode *parent = search_path[i - 1];
889
+ parent_value_prefix = parent->value_prefix;
890
+ is_reset = parent->is_reset;
891
+ }
892
+
893
+ float true_reward = node->value_prefix - parent_value_prefix;
894
+ min_max_stats.update(true_reward + discount_factor * node->value());
895
+
896
+ if (is_reset == 1)
897
+ {
898
+ // parent is reset.
899
+ true_reward = node->value_prefix;
900
+ }
901
+
902
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
903
+ }
904
+ }
905
+ else
906
+ {
907
+ // for self-play-mode
908
+ float bootstrap_value = value;
909
+ int path_len = search_path.size();
910
+ for (int i = path_len - 1; i >= 0; --i)
911
+ {
912
+ CNode *node = search_path[i];
913
+ if (node->to_play == to_play)
914
+ node->value_sum += bootstrap_value;
915
+ else
916
+ node->value_sum += -bootstrap_value;
917
+ node->visit_count += 1;
918
+
919
+ float parent_value_prefix = 0.0;
920
+ int is_reset = 0;
921
+ if (i >= 1)
922
+ {
923
+ CNode *parent = search_path[i - 1];
924
+ parent_value_prefix = parent->value_prefix;
925
+ is_reset = parent->is_reset;
926
+ }
927
+
928
+ // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
929
+ // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
930
+ float true_reward = node->value_prefix - parent_value_prefix;
931
+
932
+ min_max_stats.update(true_reward + discount_factor * node->value());
933
+
934
+ if (is_reset == 1)
935
+ {
936
+ // parent is reset.
937
+ true_reward = node->value_prefix;
938
+ }
939
+ if (node->to_play == to_play)
940
+ bootstrap_value = -true_reward + discount_factor * bootstrap_value;
941
+ else
942
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
943
+ }
944
+ }
945
+ }
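// ---------------------------------------------------------------------
// A minimal sketch of the discounted bootstrap recursion applied above
// (play-with-bot mode): walking the search path backwards, the target
// value is G_i = r_i + discount * G_{i+1}, seeded with the leaf value.
// Rewards and the leaf value below are illustrative.

#include <iostream>
#include <vector>

int main()
{
    std::vector<float> rewards = {0.0f, 1.0f, 0.5f};  // rewards along the search path
    float discount = 0.997f;
    float bootstrap_value = 2.0f;  // value estimate of the leaf node
    for (int i = static_cast<int>(rewards.size()) - 1; i >= 0; --i)
        bootstrap_value = rewards[i] + discount * bootstrap_value;
    std::cout << "root value target = " << bootstrap_value << std::endl;
    return 0;
}
// ---------------------------------------------------------------------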
946
+
947
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch)
948
+ {
949
+ /*
950
+ Overview:
951
+ Expand the nodes along the search path and update the infos.
952
+ Arguments:
953
+ - current_latent_state_index: The index of latent state of the leaf node in the search path.
954
+ - discount_factor: the discount factor of reward.
955
+ - value_prefixs: the value prefixs of nodes along the search path.
956
+ - values: the values to propagate along the search path.
957
+ - policies: the policy logits of nodes along the search path.
958
+ - min_max_stats: a tool used to min-max normalize the q value.
959
+ - results: the search results.
960
+ - is_reset_list: the vector of is_reset flags of nodes along the search path, where is_reset indicates whether the parent value prefix needs to be reset.
961
+ - to_play_batch: the batch of which player is playing on this node.
962
+ */
963
+ for (int i = 0; i < results.num; ++i)
964
+ {
965
+ results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]);
966
+ // reset
967
+ results.nodes[i]->is_reset = is_reset_list[i];
968
+
969
+ cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor);
970
+ }
971
+ }
972
+
973
+ CAction cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, bool continuous_action_space)
974
+ {
975
+ /*
976
+ Overview:
977
+ Select the child node of the root according to the UCB scores.
+ Arguments:
+ - root: the root to select a child node from.
980
+ - min_max_stats: a tool used to min-max normalize the score.
981
+ - pb_c_base: the constant c2 in MuZero.
+ - pb_c_init: the constant c1 in MuZero.
+ - discount_factor: the discount factor of reward.
984
+ - mean_q: the mean q value of the parent node.
985
+ - players: the number of players.
986
+ - continuous_action_space: whether the action space of the current env is continuous.
987
+ Outputs:
988
+ - action: the action to select.
989
+ */
990
+ // sampled related core code
991
+ // TODO(pu): Progressive widening (See https://hal.archives-ouvertes.fr/hal-00542673v2/document)
992
+ float max_score = FLOAT_MIN;
993
+ const float epsilon = 0.000001;
994
+ std::vector<CAction> max_index_lst;
995
+ for (auto a : root->legal_actions)
996
+ {
997
+
998
+ CNode *child = root->get_child(a);
999
+ // sampled related core code
1000
+ float temp_score = cucb_score(root, child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players, continuous_action_space);
1001
+
1002
+ if (max_score < temp_score)
1003
+ {
1004
+ max_score = temp_score;
1005
+
1006
+ max_index_lst.clear();
1007
+ max_index_lst.push_back(a);
1008
+ }
1009
+ else if (temp_score >= max_score - epsilon)
1010
+ {
1011
+ max_index_lst.push_back(a);
1012
+ }
1013
+ }
1014
+
1015
+ // python code: int action = 0;
1016
+ CAction action;
1017
+ if (max_index_lst.size() > 0)
1018
+ {
1019
+ int rand_index = rand() % max_index_lst.size();
1020
+ action = max_index_lst[rand_index];
1021
+ }
1022
+ return action;
1023
+ }
1024
+
1025
+ // sampled related core code
1026
+ float cucb_score(CNode *parent, CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players, bool continuous_action_space)
1027
+ {
1028
+ /*
1029
+ Overview:
1030
+ Compute the ucb score of the child.
1031
+ Arguments:
1032
+ - child: the child node to compute ucb score.
1033
+ - min_max_stats: a tool used to min-max normalize the score.
1034
+ - parent_mean_q: the mean q value of the parent node.
1035
+ - is_reset: whether the value prefix needs to be reset.
1036
+ - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
1037
+ - parent_value_prefix: the value prefix of parent node.
1038
+ - pb_c_base: the constant c2 in MuZero.
+ - pb_c_init: the constant c1 in MuZero.
+ - discount_factor: the discount factor of reward.
+ - players: the number of players.
+ - continuous_action_space: whether the action space of the current env is continuous.
1043
+ Outputs:
1044
+ - ucb_value: the ucb score of the child.
1045
+ */
1046
+ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0;
1047
+ pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
1048
+ pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));
1049
+
1050
+ // prior_score = pb_c * child->prior;
1051
+
1052
+ // sampled related core code
1053
+ // TODO(pu): empirical distribution
1054
+ std::string empirical_distribution_type = "density";
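// NOTE: std::string::compare returns 0 when the strings are equal, so the
// "density" branch below is entered only when empirical_distribution_type is
// NOT "density"; with the value set above, the "uniform" branch is the one
// that actually runs.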
1055
+ if (empirical_distribution_type.compare("density"))
1056
+ {
1057
+ if (continuous_action_space == true)
1058
+ {
1059
+ float empirical_prob_sum = 0;
1060
+ for (int i = 0; i < parent->children.size(); ++i)
1061
+ {
1062
+ empirical_prob_sum += exp(parent->get_child(parent->legal_actions[i])->prior);
1063
+ }
1064
+ prior_score = pb_c * exp(child->prior) / (empirical_prob_sum + 1e-6);
1065
+ }
1066
+ else
1067
+ {
1068
+ float empirical_prob_sum = 0;
1069
+ for (int i = 0; i < parent->children.size(); ++i)
1070
+ {
1071
+ empirical_prob_sum += parent->get_child(parent->legal_actions[i])->prior;
1072
+ }
1073
+ prior_score = pb_c * child->prior / (empirical_prob_sum + 1e-6);
1074
+ }
1075
+ }
1076
+ else if (empirical_distribution_type.compare("uniform"))
1077
+ {
1078
+ prior_score = pb_c * 1 / parent->children.size();
1079
+ }
1080
+ // sampled related core code
1081
+ if (child->visit_count == 0)
1082
+ {
1083
+ value_score = parent_mean_q;
1084
+ }
1085
+ else
1086
+ {
1087
+ float true_reward = child->value_prefix - parent_value_prefix;
1088
+ if (is_reset == 1)
1089
+ {
1090
+ true_reward = child->value_prefix;
1091
+ }
1092
+
1093
+ if (players == 1)
1094
+ value_score = true_reward + discount_factor * child->value();
1095
+ else if (players == 2)
1096
+ value_score = true_reward + discount_factor * (-child->value());
1097
+ }
1098
+
1099
+ value_score = min_max_stats.normalize(value_score);
1100
+
1101
+ if (value_score < 0)
1102
+ value_score = 0;
1103
+ if (value_score > 1)
1104
+ value_score = 1;
1105
+
1106
+ float ucb_value = prior_score + value_score;
1107
+ return ucb_value;
1108
+ }
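// ---------------------------------------------------------------------
// A minimal sketch of the pUCT exploration term computed above:
// pb_c = (log((N + c2 + 1) / c2) + c1) * sqrt(N) / (n(a) + 1),
// where N is the parent's visit count and n(a) the child's visit count.
// The constants below are the typical MuZero values and purely illustrative.

#include <cmath>
#include <iostream>

int main()
{
    float pb_c_base = 19652.0f, pb_c_init = 1.25f;  // c2 and c1
    float parent_visits = 50.0f, child_visits = 3.0f, child_prior = 0.2f;
    float pb_c = std::log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init;
    pb_c *= std::sqrt(parent_visits) / (child_visits + 1);
    std::cout << "prior_score = " << pb_c * child_prior << std::endl;
    return 0;
}
// ---------------------------------------------------------------------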
1109
+
1110
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, bool continuous_action_space)
1111
+ {
1112
+ /*
1113
+ Overview:
1114
+ Search node path from the roots.
1115
+ Arguments:
1116
+ - roots: the roots to search from.
+ - pb_c_base: the constant c2 in MuZero.
+ - pb_c_init: the constant c1 in MuZero.
+ - discount_factor: the discount factor of reward.
+ - min_max_stats: a tool used to min-max normalize the score.
+ - results: the search results.
+ - virtual_to_play_batch: the batch of which player is playing on this node.
+ - continuous_action_space: whether the action space of the current env is continuous.
1124
+ */
1125
+ // set seed
1126
+ get_time_and_set_rand_seed();
1127
+
1128
+ std::vector<float> null_value;
1129
+ for (int i = 0; i < 1; ++i)
1130
+ {
1131
+ null_value.push_back(i + 0.1);
1132
+ }
1133
+ // CAction last_action = CAction(null_value, 1);
1134
+ std::vector<float> last_action;
1135
+ float parent_q = 0.0;
1136
+ results.search_lens = std::vector<int>();
1137
+
1138
+ int players = 0;
1139
+ int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // -1 in play-with-bot-mode, 2 in self-play-mode
1140
+ if (largest_element == -1)
1141
+ players = 1;
1142
+ else
1143
+ players = 2;
1144
+
1145
+ for (int i = 0; i < results.num; ++i)
1146
+ {
1147
+ CNode *node = &(roots->roots[i]);
1148
+ int is_root = 1;
1149
+ int search_len = 0;
1150
+ results.search_paths[i].push_back(node);
1151
+
1152
+ while (node->expanded())
1153
+ {
1154
+ float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
1155
+ is_root = 0;
1156
+ parent_q = mean_q;
1157
+
1158
+ CAction action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, continuous_action_space);
1159
+ if (players > 1)
1160
+ {
1161
+ assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
1162
+ if (virtual_to_play_batch[i] == 1)
1163
+ virtual_to_play_batch[i] = 2;
1164
+ else
1165
+ virtual_to_play_batch[i] = 1;
1166
+ }
1167
+
1168
+ node->best_action = action; // CAction
1169
+ // next
1170
+ node = node->get_child(action);
1171
+ last_action = action.value;
1172
+
1173
+ results.search_paths[i].push_back(node);
1174
+ search_len += 1;
1175
+ }
1176
+
1177
+ CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];
1178
+
1179
+ results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
1180
+ results.latent_state_index_in_batch.push_back(parent->batch_index);
1181
+
1182
+ results.last_actions.push_back(last_action);
1183
+ results.search_lens.push_back(search_len);
1184
+ results.nodes.push_back(node);
1185
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
1186
+ }
1187
+ }
1188
+
1189
+ }
LightZero/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.h ADDED
@@ -0,0 +1,123 @@
1
+ // C++11
2
+
3
+ #ifndef CNODE_H
4
+ #define CNODE_H
5
+
6
+ #include "../../common_lib/cminimax.h"
7
+ #include <math.h>
8
+ #include <vector>
9
+ #include <stack>
10
+ #include <stdlib.h>
11
+ #include <time.h>
12
+ #include <cmath>
13
+ #include <sys/timeb.h>
14
+ #include <time.h>
15
+ #include <map>
16
+
17
+ const int DEBUG_MODE = 0;
18
+
19
+ namespace tree
20
+ {
21
+ // sampled related core code
22
+ class CAction
23
+ {
24
+ public:
25
+ std::vector<float> value;
26
+ std::vector<size_t> hash;
27
+ int is_root_action;
28
+
29
+ CAction();
30
+ CAction(std::vector<float> value, int is_root_action);
31
+ ~CAction();
32
+
33
+ std::vector<size_t> get_hash(void);
34
+ std::size_t get_combined_hash(void);
35
+ };
36
+
37
+ class CNode
38
+ {
39
+ public:
40
+ int visit_count, to_play, current_latent_state_index, batch_index, is_reset, action_space_size;
41
+ // sampled related core code
42
+ CAction best_action;
43
+ int num_of_sampled_actions;
44
+ float value_prefix, prior, value_sum;
45
+ float parent_value_prefix;
46
+ bool continuous_action_space;
47
+ std::vector<int> children_index;
48
+ std::map<size_t, CNode> children;
49
+
50
+ std::vector<CAction> legal_actions;
51
+
52
+ CNode();
53
+ // sampled related core code
54
+ CNode(float prior, std::vector<CAction> &legal_actions, int action_space_size, int num_of_sampled_actions, bool continuous_action_space);
55
+ ~CNode();
56
+
57
+ void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector<float> &policy_logits);
58
+ void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
59
+ float compute_mean_q(int isRoot, float parent_q, float discount_factor);
60
+ void print_out();
61
+
62
+ int expanded();
63
+
64
+ float value();
65
+
66
+ // sampled related core code
67
+ std::vector<std::vector<float> > get_trajectory();
68
+ std::vector<int> get_children_distribution();
69
+ CNode *get_child(CAction action);
70
+ };
71
+
72
+ class CRoots
73
+ {
74
+ public:
75
+ int root_num;
76
+ int num_of_sampled_actions;
77
+ int action_space_size;
78
+ std::vector<CNode> roots;
79
+ std::vector<std::vector<float> > legal_actions_list;
80
+ bool continuous_action_space;
81
+
82
+ CRoots();
83
+ CRoots(int root_num, std::vector<std::vector<float> > legal_actions_list, int action_space_size, int num_of_sampled_actions, bool continuous_action_space);
84
+ ~CRoots();
85
+
86
+ void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
87
+ void prepare_no_noise(const std::vector<float> &value_prefixs, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
88
+ void clear();
89
+ // sampled related core code
90
+ std::vector<std::vector<std::vector<float> > > get_trajectories();
91
+ std::vector<std::vector<std::vector<float> > > get_sampled_actions();
92
+
93
+ std::vector<std::vector<int> > get_distributions();
94
+
95
+ std::vector<float> get_values();
96
+ };
97
+
98
+ class CSearchResults
99
+ {
100
+ public:
101
+ int num;
102
+ std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, search_lens;
103
+ std::vector<int> virtual_to_play_batchs;
104
+ std::vector<std::vector<float> > last_actions;
105
+
106
+ std::vector<CNode *> nodes;
107
+ std::vector<std::vector<CNode *> > search_paths;
108
+
109
+ CSearchResults();
110
+ CSearchResults(int num);
111
+ ~CSearchResults();
112
+ };
113
+
114
+ //*********************************************************
115
+ void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
116
+ void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
117
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> is_reset_list, std::vector<int> &to_play_batch);
118
+ CAction cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, bool continuous_action_space);
119
+ float cucb_score(CNode *parent, CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players, bool continuous_action_space);
120
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch, bool continuous_action_space);
121
+ }
122
+
123
+ #endif
LightZero/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp ADDED
@@ -0,0 +1,787 @@
1
+ // C++11
2
+
3
+ #include <iostream>
4
+ #include "cnode.h"
5
+ #include <algorithm>
6
+ #include <map>
7
+ #include <cassert>
8
+ #include <numeric>
9
+ #include <iostream>
10
+ #include <vector>
11
+ #include <map>
12
+ #include <random>
13
+ #include <algorithm>
14
+ #include <iterator>
15
+
16
+ #ifdef _WIN32
17
+ #include "..\..\common_lib\utils.cpp"
18
+ #else
19
+ #include "../../common_lib/utils.cpp"
20
+ #endif
21
+
22
+
23
+ namespace tree
24
+ {
25
+
26
+ CSearchResults::CSearchResults()
27
+ {
28
+ /*
29
+ Overview:
30
+ Initialization of CSearchResults, the default result number is set to 0.
31
+ */
32
+ this->num = 0;
33
+ }
34
+
35
+ CSearchResults::CSearchResults(int num)
36
+ {
37
+ /*
38
+ Overview:
39
+ Initialization of CSearchResults with result number.
40
+ */
41
+ this->num = num;
42
+ for (int i = 0; i < num; ++i)
43
+ {
44
+ this->search_paths.push_back(std::vector<CNode *>());
45
+ }
46
+ }
47
+
48
+ CSearchResults::~CSearchResults() {}
49
+
50
+ //*********************************************************
51
+
52
+ CNode::CNode()
53
+ {
54
+ /*
55
+ Overview:
56
+ Initialization of CNode.
57
+ */
58
+ this->prior = 0;
59
+
61
+ this->visit_count = 0;
62
+ this->value_sum = 0;
63
+ this->best_action = -1;
64
+ this->to_play = 0;
65
+ this->reward = 0.0;
66
+ this->is_chance = false;
67
+ this->chance_space_size = 2;
68
+
69
+ }
70
+
71
+ CNode::CNode(float prior, std::vector<int> &legal_actions, bool is_chance, int chance_space_size)
72
+ {
73
+ /*
74
+ Overview:
75
+ Initialization of CNode with prior value and legal actions.
76
+ Arguments:
77
+ - prior: the prior value of this node.
78
+ - legal_actions: a vector of legal actions of this node.
79
+ */
80
+ this->prior = prior;
81
+ this->legal_actions = legal_actions;
82
+
83
+ this->visit_count = 0;
84
+ this->value_sum = 0;
85
+ this->best_action = -1;
86
+ this->to_play = 0;
87
+ this->current_latent_state_index = -1;
88
+ this->batch_index = -1;
89
+ this->is_chance = is_chance;
90
+ this->chance_space_size = chance_space_size;
91
+ }
92
+
93
+ CNode::~CNode() {}
94
+
95
+ void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector<float> &policy_logits, bool child_is_chance)
96
+ {
97
+ /*
98
+ Overview:
99
+ Expand the child nodes of the current node.
100
+ Arguments:
101
+ - to_play: which player to play the game in the current node.
102
+ - current_latent_state_index: the index (search depth) of the latent state of the leaf node in the search path.
+ - batch_index: the index of the leaf node in the batch of root nodes.
104
+ - reward: the reward of the current node.
105
+ - policy_logits: the logit of the child nodes.
106
+ */
107
+ this->to_play = to_play;
108
+ this->current_latent_state_index = current_latent_state_index;
109
+ this->batch_index = batch_index;
110
+ this->reward = reward;
111
+
112
+
113
+ // assert((this->is_chance != child_is_chance) && "is_chance and child_is_chance should be different");
114
+
115
+ if(this->is_chance == true){
116
+ child_is_chance = false;
117
+ this->reward = 0.0;
118
+ }
119
+ else{
120
+ child_is_chance = true;
121
+ }
122
+
123
+ int action_num = policy_logits.size();
124
+ if (this->legal_actions.size() == 0)
125
+ {
126
+ for (int i = 0; i < action_num; ++i)
127
+ {
128
+ this->legal_actions.push_back(i);
129
+ }
130
+ }
131
+
132
+ float temp_policy;
133
+ float policy_sum = 0.0;
134
+
135
+ #ifdef _WIN32
136
+ // Create a dynamic array (variable-length arrays are not supported on MSVC).
137
+ float* policy = new float[action_num];
138
+ #else
139
+ float policy[action_num];
140
+ #endif
141
+
142
+ float policy_max = FLOAT_MIN;
143
+ for (auto a : this->legal_actions)
144
+ {
145
+ if (policy_max < policy_logits[a])
146
+ {
147
+ policy_max = policy_logits[a];
148
+ }
149
+ }
150
+
151
+ for (auto a : this->legal_actions)
152
+ {
153
+ temp_policy = exp(policy_logits[a] - policy_max);
154
+ policy_sum += temp_policy;
155
+ policy[a] = temp_policy;
156
+ }
157
+
158
+ float prior;
159
+ for (auto a : this->legal_actions)
160
+ {
161
+ prior = policy[a] / policy_sum;
162
+ std::vector<int> tmp_empty;
163
+ this->children[a] = CNode(prior, tmp_empty, child_is_chance, this->chance_space_size); // only for muzero/efficient zero, not support alphazero
164
+ // this->children[a] = CNode(prior, tmp_empty, is_chance = child_is_chance); // only for muzero/efficient zero, not support alphazero
165
+ }
166
+
167
+ #ifdef _WIN32
168
+ // Free the array memory.
169
+ delete[] policy;
170
+ #else
171
+ #endif
172
+ }
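// ---------------------------------------------------------------------
// A minimal sketch of the numerically stable softmax used above: the
// maximum logit is subtracted before exponentiation so that exp() cannot
// overflow, then the exponentials are normalized. Logits are illustrative.

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main()
{
    std::vector<float> logits = {2.0f, 0.5f, -1.0f};
    float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> priors(logits.size());
    float sum = 0.0f;
    for (size_t a = 0; a < logits.size(); ++a)
    {
        priors[a] = std::exp(logits[a] - max_logit);
        sum += priors[a];
    }
    for (size_t a = 0; a < logits.size(); ++a)
        priors[a] /= sum;
    for (float p : priors)
        std::cout << p << " ";  // probabilities summing to 1
    std::cout << std::endl;
    return 0;
}
// ---------------------------------------------------------------------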
173
+
174
+ void CNode::add_exploration_noise(float exploration_fraction, const std::vector<float> &noises)
175
+ {
176
+ /*
177
+ Overview:
178
+ Add a noise to the prior of the child nodes.
179
+ Arguments:
180
+ - exploration_fraction: the fraction to add noise.
181
+ - noises: the vector of noises added to each child node.
182
+ */
183
+ float noise, prior;
184
+ for (int i = 0; i < this->legal_actions.size(); ++i)
185
+ {
186
+ noise = noises[i];
187
+ CNode *child = this->get_child(this->legal_actions[i]);
188
+
189
+ prior = child->prior;
190
+ child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction;
191
+ }
192
+ }
193
+
194
+ float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor)
195
+ {
196
+ /*
197
+ Overview:
198
+ Compute the mean q value of the current node.
199
+ Arguments:
200
+ - isRoot: whether the current node is a root node.
201
+ - parent_q: the q value of the parent node.
202
+ - discount_factor: the discount_factor of reward.
203
+ */
204
+ float total_unsigned_q = 0.0;
205
+ int total_visits = 0;
206
+ for (auto a : this->legal_actions)
207
+ {
208
+ CNode *child = this->get_child(a);
209
+ if (child->visit_count > 0)
210
+ {
211
+ float true_reward = child->reward;
212
+ float qsa = true_reward + discount_factor * child->value();
213
+ total_unsigned_q += qsa;
214
+ total_visits += 1;
215
+ }
216
+ }
217
+
218
+ float mean_q = 0.0;
219
+ if (isRoot && total_visits > 0)
220
+ {
221
+ mean_q = (total_unsigned_q) / (total_visits);
222
+ }
223
+ else
224
+ {
225
+ mean_q = (parent_q + total_unsigned_q) / (total_visits + 1);
226
+ }
227
+ return mean_q;
228
+ }
229
+
230
+ void CNode::print_out()
231
+ {
232
+ return;
233
+ }
234
+
235
+ int CNode::expanded()
236
+ {
237
+ /*
238
+ Overview:
239
+ Return whether the current node is expanded.
240
+ */
241
+ return this->children.size() > 0;
242
+ }
243
+
244
+ float CNode::value()
245
+ {
246
+ /*
247
+ Overview:
248
+ Return the estimated value of the current node, i.e. value_sum / visit_count.
249
+ */
250
+ float true_value = 0.0;
251
+ if (this->visit_count == 0)
252
+ {
253
+ return true_value;
254
+ }
255
+ else
256
+ {
257
+ true_value = this->value_sum / this->visit_count;
258
+ return true_value;
259
+ }
260
+ }
261
+
262
+ std::vector<int> CNode::get_trajectory()
263
+ {
264
+ /*
265
+ Overview:
266
+ Find the current best trajectory starting from the current node.
267
+ Outputs:
268
+ - traj: a vector of node index, which is the current best trajectory from this node.
269
+ */
270
+ std::vector<int> traj;
271
+
272
+ CNode *node = this;
273
+ int best_action = node->best_action;
274
+ while (best_action >= 0)
275
+ {
276
+ traj.push_back(best_action);
277
+
278
+ node = node->get_child(best_action);
279
+ best_action = node->best_action;
280
+ }
281
+ return traj;
282
+ }
283
+
284
+ std::vector<int> CNode::get_children_distribution()
285
+ {
286
+ /*
287
+ Overview:
288
+ Get the distribution of child nodes in the format of visit_count.
289
+ Outputs:
290
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
291
+ */
292
+ std::vector<int> distribution;
293
+ if (this->expanded())
294
+ {
295
+ for (auto a : this->legal_actions)
296
+ {
297
+ CNode *child = this->get_child(a);
298
+ distribution.push_back(child->visit_count);
299
+ }
300
+ }
301
+ return distribution;
302
+ }
303
+
304
+ CNode *CNode::get_child(int action)
305
+ {
306
+ /*
307
+ Overview:
308
+ Get the child node corresponding to the input action.
309
+ Arguments:
310
+ - action: the action to get child.
311
+ */
312
+ return &(this->children[action]);
313
+ }
314
+
315
+ //*********************************************************
316
+
317
+ CRoots::CRoots()
318
+ {
319
+ /*
320
+ Overview:
321
+ The initialization of CRoots.
322
+ */
323
+ this->root_num = 0;
324
+ }
325
+
326
+ CRoots::CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list, int chance_space_size=2)
327
+ {
328
+ /*
329
+ Overview:
330
+ The initialization of CRoots with root num and legal action lists.
331
+ Arguments:
332
+ - root_num: the number of the current root.
333
+ - legal_action_list: the vector of the legal action of this root.
334
+ */
335
+ this->root_num = root_num;
336
+ this->legal_actions_list = legal_actions_list;
337
+
338
+ for (int i = 0; i < root_num; ++i)
339
+ {
340
+ this->roots.push_back(CNode(0, this->legal_actions_list[i], false, chance_space_size));
341
+ // this->roots.push_back(CNode(0, this->legal_actions_list[i], false));
342
+
343
+ }
344
+ }
345
+
346
+ CRoots::~CRoots() {}
347
+
348
+ void CRoots::prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
349
+ {
350
+ /*
351
+ Overview:
352
+ Expand the roots and add noises.
353
+ Arguments:
354
+ - root_noise_weight: the exploration fraction of roots
355
+ - noises: the vector of noise add to the roots.
356
+ - rewards: the vector of rewards of each root.
357
+ - policies: the vector of policy logits of each root.
358
+ - to_play_batch: the vector of the player side of each root.
359
+ */
360
+ for (int i = 0; i < this->root_num; ++i)
361
+ {
362
+ this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i], true);
363
+ this->roots[i].add_exploration_noise(root_noise_weight, noises[i]);
364
+
365
+ this->roots[i].visit_count += 1;
366
+ }
367
+ }
368
+
369
+ void CRoots::prepare_no_noise(const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch)
370
+ {
371
+ /*
372
+ Overview:
373
+ Expand the roots without noise.
374
+ Arguments:
375
+ - rewards: the vector of rewards of each root.
376
+ - policies: the vector of policy logits of each root.
377
+ - to_play_batch: the vector of the player side of each root.
378
+ */
379
+ for (int i = 0; i < this->root_num; ++i)
380
+ {
381
+ this->roots[i].expand(to_play_batch[i], 0, i, rewards[i], policies[i], true);
382
+
383
+ this->roots[i].visit_count += 1;
384
+ }
385
+ }
386
+
387
+ void CRoots::clear()
388
+ {
389
+ /*
390
+ Overview:
391
+ Clear the roots vector.
392
+ */
393
+ this->roots.clear();
394
+ }
395
+
396
+ std::vector<std::vector<int> > CRoots::get_trajectories()
397
+ {
398
+ /*
399
+ Overview:
400
+ Find the current best trajectory starting from each root.
401
+ Outputs:
402
+ - traj: a vector of node index, which is the current best trajectory from each root.
403
+ */
404
+ std::vector<std::vector<int> > trajs;
405
+ trajs.reserve(this->root_num);
406
+
407
+ for (int i = 0; i < this->root_num; ++i)
408
+ {
409
+ trajs.push_back(this->roots[i].get_trajectory());
410
+ }
411
+ return trajs;
412
+ }
413
+
414
+ std::vector<std::vector<int> > CRoots::get_distributions()
415
+ {
416
+ /*
417
+ Overview:
418
+ Get the children distribution of each root.
419
+ Outputs:
420
+ - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]).
421
+ */
422
+ std::vector<std::vector<int> > distributions;
423
+ distributions.reserve(this->root_num);
424
+
425
+ for (int i = 0; i < this->root_num; ++i)
426
+ {
427
+ distributions.push_back(this->roots[i].get_children_distribution());
428
+ }
429
+ return distributions;
430
+ }
431
+
432
+ std::vector<float> CRoots::get_values()
433
+ {
434
+ /*
435
+ Overview:
436
+ Return the real value of each root.
437
+ */
438
+ std::vector<float> values;
439
+ for (int i = 0; i < this->root_num; ++i)
440
+ {
441
+ values.push_back(this->roots[i].value());
442
+ }
443
+ return values;
444
+ }
445
+
446
+ //*********************************************************
447
+ //
448
+ void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players)
449
+ {
450
+ /*
451
+ Overview:
452
+ Update the q value of the root and its child nodes.
453
+ Arguments:
454
+ - root: the root to update the q value from.
455
+ - min_max_stats: a tool used to min-max normalize the q value.
456
+ - discount_factor: the discount factor of reward.
457
+ - players: the number of players.
458
+ */
459
+ std::stack<CNode *> node_stack;
460
+ node_stack.push(root);
461
+ while (node_stack.size() > 0)
462
+ {
463
+ CNode *node = node_stack.top();
464
+ node_stack.pop();
465
+
466
+ if (node != root)
467
+ {
468
+ // # NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
469
+ // # but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
470
+ // # true_reward = node.value_prefix - (- parent_value_prefix)
471
+ // float true_reward = node->value_prefix - node->parent_value_prefix;
472
+ float true_reward = node->reward;
473
+
474
+ float qsa = 0.0;
475
+ if (players == 1)
476
+ qsa = true_reward + discount_factor * node->value();
477
+ else if (players == 2)
478
+ // TODO(pu):
479
+ qsa = true_reward + discount_factor * (-1) * node->value();
480
+
481
+ min_max_stats.update(qsa);
482
+ }
483
+
484
+ for (auto a : node->legal_actions)
485
+ {
486
+ CNode *child = node->get_child(a);
487
+ if (child->expanded())
488
+ {
489
+ node_stack.push(child);
490
+ }
491
+ }
492
+ }
493
+ }
494
+
495
+ void cbackpropagate(std::vector<CNode *> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor)
496
+ {
497
+ /*
498
+ Overview:
499
+ Update the value sum and visit count of nodes along the search path.
500
+ Arguments:
501
+ - search_path: a vector of nodes on the search path.
502
+ - min_max_stats: a tool used to min-max normalize the q value.
503
+ - to_play: which player to play the game in the current node.
504
+ - value: the value to propagate along the search path.
505
+ - discount_factor: the discount factor of reward.
506
+ */
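+ // Implementation note: walking the search path from the leaf back to the root, the propagated
+ // return follows the bootstrapped backup
+ //     G_k = reward_k + discount_factor * G_{k+1}
+ // In play-with-bot-mode (to_play == -1) the sign never flips; in self-play-mode the value is
+ // added to nodes owned by `to_play` and subtracted from the opponent's nodes, and the reward
+ // sign alternates accordingly.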
507
+ assert(to_play == -1 || to_play == 1 || to_play == 2);
508
+ if (to_play == -1)
509
+ {
510
+ // for play-with-bot-mode
511
+ float bootstrap_value = value;
512
+ int path_len = search_path.size();
513
+ for (int i = path_len - 1; i >= 0; --i)
514
+ {
515
+ CNode *node = search_path[i];
516
+ node->value_sum += bootstrap_value;
517
+ node->visit_count += 1;
518
+
519
+ float true_reward = node->reward;
520
+
521
+ min_max_stats.update(true_reward + discount_factor * node->value());
522
+
523
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
524
+ // std::cout << "to_play: " << to_play << std::endl;
525
+
526
+ }
527
+ }
528
+ else
529
+ {
530
+ // for self-play-mode
531
+ float bootstrap_value = value;
532
+ int path_len = search_path.size();
533
+ for (int i = path_len - 1; i >= 0; --i)
534
+ {
535
+ CNode *node = search_path[i];
536
+ if (node->to_play == to_play)
537
+ node->value_sum += bootstrap_value;
538
+ else
539
+ node->value_sum += -bootstrap_value;
540
+ node->visit_count += 1;
541
+
542
+ // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node,
543
+ // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node.
544
+ // float true_reward = node->value_prefix - parent_value_prefix;
545
+ float true_reward = node->reward;
546
+
547
+ // TODO(pu): why in muzero-general is - node.value
548
+ min_max_stats.update(true_reward + discount_factor * -node->value());
549
+
550
+ if (node->to_play == to_play)
551
+ bootstrap_value = -true_reward + discount_factor * bootstrap_value;
552
+ else
553
+ bootstrap_value = true_reward + discount_factor * bootstrap_value;
554
+ }
555
+ }
556
+ }
557
+
558
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &value_prefixs, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch, std::vector<bool> &is_chance_list, std::vector<int> &leaf_idx_list)
559
+ {
560
+ /*
561
+ Overview:
562
+ Expand the leaf nodes and backpropagate the values along the corresponding search paths.
563
+ Arguments:
564
+ - current_latent_state_index: The index of latent state of the leaf node in the search path.
565
+ - discount_factor: the discount factor of reward.
566
+ - value_prefixs: the value prefixes of the leaf nodes.
567
+ - values: the values to propagate along the search path.
568
+ - policies: the policy logits of nodes along the search path.
569
+ - min_max_stats: a tool used to min-max normalize the q value.
570
+ - results: the search results.
571
+ - to_play_batch: the batch of which player is playing on this node.
572
+ */
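+ // Note on the two extra arguments (not listed above): `is_chance_list[i]` tells expand() whether
+ // leaf i is a chance node, and `leaf_idx_list` selects which search results to process; when it
+ // is passed in empty it is filled with 0..results.num-1 so that every leaf is expanded and
+ // backpropagated.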
573
+
574
+ if (leaf_idx_list.empty()) {
575
+ leaf_idx_list.resize(results.num);
576
+ for (int i = 0; i < results.num; ++i) {
577
+ leaf_idx_list[i] = i;
578
+ }
579
+ }
580
+
581
584
+ for (int leaf_order = 0; leaf_order < leaf_idx_list.size(); ++leaf_order)
585
+ {
586
+ int i = leaf_idx_list[leaf_order];
587
+ results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[leaf_order], policies[leaf_order], is_chance_list[i]);
588
+ cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[leaf_order], discount_factor);
589
+ }
590
+
591
+ }
592
+
593
+ int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players)
594
+ {
595
+ /*
596
+ Overview:
597
+ Select a child of the given node: by UCB score for decision nodes, or by sampling an outcome from the prior for chance nodes.
598
+ Arguments:
599
+ - root: the node whose child is to be selected.
600
+ - min_max_stats: a tool used to min-max normalize the score.
601
+ - pb_c_base: constant c2 in MuZero.
602
+ - pb_c_init: constant c1 in MuZero.
603
+ - discount_factor: the discount factor of reward.
604
+ - mean_q: the mean q value of the parent node.
605
+ - players: the number of players.
606
+ Outputs:
607
+ - action: the selected action (or the sampled outcome for a chance node).
608
+ */
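+ // Decision nodes are handled with the pUCT rule further below; chance nodes are handled first,
+ // by sampling an outcome from the prior probabilities stored on their children.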
609
+ if (root->is_chance) {
610
+ // std::cout << "root->is_chance: True " << std::endl;
611
+
612
+ // If the node is a chance node, we sample from the prior outcome distribution.
613
+ std::vector<int> outcomes;
614
+ std::vector<double> probs;
615
+
616
+ for (const auto& kv : root->children) {
617
+ outcomes.push_back(kv.first);
618
+ probs.push_back(kv.second.prior); // 'prior' holds the outcome probability stored on each child CNode
619
+ }
620
+
621
+ std::random_device rd;
622
+ std::mt19937 gen(rd());
623
+ std::discrete_distribution<> dist(probs.begin(), probs.end());
624
+
625
+ int outcome = outcomes[dist(gen)];
626
+ // std::cout << "Outcome: " << outcome << std::endl;
627
+
628
+ return outcome;
629
+ }
630
+
631
+ // std::cout << "root->is_chance: False " << std::endl;
632
+
633
+ float max_score = FLOAT_MIN;
634
+ const float epsilon = 0.000001;
635
+ std::vector<int> max_index_lst;
636
+ for (auto a : root->legal_actions)
637
+ {
638
+
639
+ CNode *child = root->get_child(a);
640
+ float temp_score = cucb_score(child, min_max_stats, mean_q, root->visit_count - 1, pb_c_base, pb_c_init, discount_factor, players);
641
+
642
+ if (max_score < temp_score)
643
+ {
644
+ max_score = temp_score;
645
+
646
+ max_index_lst.clear();
647
+ max_index_lst.push_back(a);
648
+ }
649
+ else if (temp_score >= max_score - epsilon)
650
+ {
651
+ max_index_lst.push_back(a);
652
+ }
653
+ }
654
+
655
+ int action = 0;
656
+ if (max_index_lst.size() > 0)
657
+ {
658
+ int rand_index = rand() % max_index_lst.size();
659
+ action = max_index_lst[rand_index];
660
+ }
661
+ return action;
662
+ }
663
+
664
+ float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players)
665
+ {
666
+ /*
667
+ Overview:
668
+ Compute the ucb score of the child.
669
+ Arguments:
670
+ - child: the child node to compute ucb score.
671
+ - min_max_stats: a tool used to min-max normalize the score.
672
+ - parent_mean_q: the mean q value of the parent node.
673
+ - total_children_visit_counts: the total visit counts of the child nodes of the parent node.
674
+ - pb_c_base: constant c2 in MuZero.
675
+ - pb_c_init: constant c1 in MuZero.
676
+ - discount_factor: the discount factor of reward.
677
+ - players: the number of players.
678
+ Outputs:
679
+ - ucb_value: the ucb score of the child.
680
+ */
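+ // The score computed below is the standard MuZero pUCT term:
+ //     pb_c  = log((N(s) + pb_c_base + 1) / pb_c_base) + pb_c_init
+ //     pb_c *= sqrt(N(s)) / (N(s, a) + 1)
+ //     score = pb_c * prior(a) + clip(normalize(Q(s, a)), 0, 1)
+ // where N(s) = total_children_visit_counts, N(s, a) = child->visit_count, and Q(s, a) falls
+ // back to parent_mean_q for unvisited children.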
681
+ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0;
682
+ pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
683
+ pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));
684
+
685
+ prior_score = pb_c * child->prior;
686
+ if (child->visit_count == 0)
687
+ {
688
+ value_score = parent_mean_q;
689
+ }
690
+ else
691
+ {
692
+ float true_reward = child->reward;
693
+ if (players == 1)
694
+ value_score = true_reward + discount_factor * child->value();
695
+ else if (players == 2)
696
+ value_score = true_reward + discount_factor * (-child->value());
697
+ }
698
+
699
+ value_score = min_max_stats.normalize(value_score);
700
+
701
+ if (value_score < 0)
702
+ value_score = 0;
703
+ if (value_score > 1)
704
+ value_score = 1;
705
+
706
+ float ucb_value = prior_score + value_score;
707
+ return ucb_value;
708
+ }
709
+
710
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch)
711
+ {
712
+ /*
713
+ Overview:
714
+ Search node path from the roots.
715
+ Arguments:
716
+ - roots: the roots that search from.
717
+ - pb_c_base: constant c2 in MuZero.
718
+ - pb_c_init: constant c1 in MuZero.
719
+ - discount_factor: the discount factor of reward.
720
+ - min_max_stats: a tool used to min-max normalize the score.
721
+ - results: the search results.
722
+ - virtual_to_play_batch: the batch of which player is playing on this node.
723
+ */
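+ // For each root, the loop below repeatedly applies cselect_child() until it reaches an
+ // unexpanded node. It then records, per root, the parent's latent-state location
+ // (latent_state_index_in_search_path / latent_state_index_in_batch), the last action taken,
+ // the search depth, and whether the leaf is a chance node, so that the caller can run model
+ // inference on the leaves and finish the simulation with cbatch_backpropagate().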
724
+ // set seed
725
+ get_time_and_set_rand_seed();
726
+
727
+ int last_action = -1;
728
+ float parent_q = 0.0;
729
+ results.search_lens = std::vector<int>();
730
+
731
+ int players = 0;
732
+ int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // -1 (play-with-bot-mode) or 1/2 (self-play-mode)
733
+ if (largest_element == -1)
734
+ players = 1;
735
+ else
736
+ players = 2;
737
+
738
+ for (int i = 0; i < results.num; ++i)
739
+ {
740
+ CNode *node = &(roots->roots[i]);
741
+ int is_root = 1;
742
+ int search_len = 0;
743
+ results.search_paths[i].push_back(node);
744
+
745
+ // std::cout << "root->is_chance: " <<node->is_chance<< std::endl;
746
+ // node->is_chance=false;
747
+
748
+ while (node->expanded())
749
+ {
750
+ float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor);
751
+ is_root = 0;
752
+ parent_q = mean_q;
753
+ // std::cout << "node->is_chance: " <<node->is_chance<< std::endl;
754
+
755
+ int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players);
756
+ if (players > 1)
757
+ {
758
+ assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2);
759
+ if (virtual_to_play_batch[i] == 1)
760
+ virtual_to_play_batch[i] = 2;
761
+ else
762
+ virtual_to_play_batch[i] = 1;
763
+ }
764
+
765
+ node->best_action = action;
766
+ // next
767
+ node = node->get_child(action);
768
+ last_action = action;
769
+ results.search_paths[i].push_back(node);
770
+ search_len += 1;
771
+ }
772
+
773
+ CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2];
774
+
775
+ results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index);
776
+ results.latent_state_index_in_batch.push_back(parent->batch_index);
777
+
778
+ results.last_actions.push_back(last_action);
779
+ results.search_lens.push_back(search_len);
780
+ results.nodes.push_back(node);
781
+ results.leaf_node_is_chance.push_back(node->is_chance);
782
+ results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]);
783
+
784
+ }
785
+ }
786
+
787
+ }
LightZero/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.h ADDED
@@ -0,0 +1,95 @@
1
+ // C++11
2
+
3
+ #ifndef CNODE_H
4
+ #define CNODE_H
5
+
6
+ #include "./../common_lib/cminimax.h"
7
+ #include <math.h>
8
+ #include <vector>
9
+ #include <stack>
10
+ #include <stdlib.h>
11
+ #include <time.h>
12
+ #include <cmath>
13
+ #include <sys/timeb.h>
15
+ #include <map>
16
+
17
+ const int DEBUG_MODE = 0;
18
+
19
+ namespace tree {
20
+
21
+ class CNode {
22
+ public:
23
+ int visit_count, to_play, current_latent_state_index, batch_index, best_action;
24
+ float reward, prior, value_sum;
25
+ bool is_chance;
26
+ int chance_space_size;
27
+ std::vector<int> children_index;
28
+ std::map<int, CNode> children;
29
+
30
+ std::vector<int> legal_actions;
31
+
32
+ CNode();
33
+ CNode(float prior, std::vector<int> &legal_actions, bool is_chance = false, int chance_space_size = 2);
34
+ ~CNode();
35
+
36
+ void expand(int to_play, int current_latent_state_index, int batch_index, float reward, const std::vector<float> &policy_logits, bool is_chance);
37
+ void add_exploration_noise(float exploration_fraction, const std::vector<float> &noises);
38
+ float compute_mean_q(int isRoot, float parent_q, float discount_factor);
39
+ void print_out();
40
+
41
+ int expanded();
42
+
43
+ float value();
44
+
45
+ std::vector<int> get_trajectory();
46
+ std::vector<int> get_children_distribution();
47
+ CNode* get_child(int action);
48
+ };
49
+
50
+ class CRoots{
51
+ public:
52
+ int root_num;
53
+ std::vector<CNode> roots;
54
+ std::vector<std::vector<int> > legal_actions_list;
55
+ int chance_space_size;
56
+
57
+ CRoots();
58
+ CRoots(int root_num, std::vector<std::vector<int> > &legal_actions_list, int chance_space_size);
59
+ ~CRoots();
60
+
61
+ void prepare(float root_noise_weight, const std::vector<std::vector<float> > &noises, const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
62
+ void prepare_no_noise(const std::vector<float> &rewards, const std::vector<std::vector<float> > &policies, std::vector<int> &to_play_batch);
63
+ void clear();
64
+ std::vector<std::vector<int> > get_trajectories();
65
+ std::vector<std::vector<int> > get_distributions();
66
+ std::vector<float> get_values();
67
+
68
+ };
69
+
70
+ class CSearchResults{
71
+ public:
72
+ int num;
73
+ std::vector<int> latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens;
74
+ std::vector<int> virtual_to_play_batchs;
75
+ std::vector<CNode*> nodes;
76
+ std::vector<bool> leaf_node_is_chance;
77
+ std::vector<std::vector<CNode*> > search_paths;
78
+
79
+ CSearchResults();
80
+ CSearchResults(int num);
81
+ ~CSearchResults();
82
+
83
+ };
84
+
85
+
86
+ //*********************************************************
87
+ void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players);
88
+ void cbackpropagate(std::vector<CNode*> &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor);
89
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector<float> &rewards, const std::vector<float> &values, const std::vector<std::vector<float> > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &to_play_batch, std::vector<int> & is_chance_list, std::vector<int> &leaf_idx_list);
90
+ int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players);
91
+ float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, float total_children_visit_counts, float pb_c_base, float pb_c_init, float discount_factor, int players);
92
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector<int> &virtual_to_play_batch);
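+
+ // Rough usage sketch (assumed typical call order; the concrete batching and model calls live on
+ // the Python side of LightZero and are not part of this header):
+ //
+ //     CRoots roots(batch_size, legal_actions_list, chance_space_size);
+ //     roots.prepare(root_noise_weight, noises, rewards, policies, to_play_batch);
+ //     for (int sim = 0; sim < num_simulations; ++sim) {
+ //         CSearchResults results(batch_size);
+ //         cbatch_traverse(&roots, pb_c_base, pb_c_init, discount_factor,
+ //                         min_max_stats_lst, results, virtual_to_play_batch);
+ //         // ...run the learned dynamics/prediction model on the selected leaves...
+ //         cbatch_backpropagate(current_latent_state_index, discount_factor, rewards, values,
+ //                              policies, min_max_stats_lst, results, to_play_batch,
+ //                              is_chance_list, leaf_idx_list);
+ //     }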
93
+ }
94
+
95
+ #endif