先坤 commited on
Commit
db26c81
·
1 Parent(s): 49b4f2c

add greedrl

Browse files
.gitignore ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .idea
2
+ *.tar.gz
3
+ logs
4
+ **/__pycache__
5
+ data
6
+ *.log
7
+ *.pkl
8
+ *.pt
9
+ **/build/
10
+ **/dist/
11
+ **/*.egg-info
12
+ .DS_Store
13
+ .nfs*
14
+ *.so
15
+ *.dylib
16
+ *.iml
17
+ target
18
+ **/nohup.out
19
+ *.pth
20
+ **/.flattened-pom.xml
CMakeLists.txt ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cmake_minimum_required(VERSION 2.8.12)
2
+ project(greedrl_C_ LANGUAGES CXX)
3
+
4
+ set(CMAKE_CXX_STANDARD 14)
5
+
6
+ find_package(PythonInterp REQUIRED)
7
+ execute_process(COMMAND "python" "-c"
8
+ "
9
+ import os
10
+ import torch
11
+ from distutils import sysconfig as s
12
+ print(s.get_python_inc(plat_specific=True))
13
+ print(s.get_config_var('EXT_SUFFIX'))
14
+ print(os.path.dirname(torch.__file__))
15
+ "
16
+ RESULT_VARIABLE _PYTHON_SUCCESS
17
+ OUTPUT_VARIABLE _PYTHON_VALUES
18
+ ERROR_VARIABLE _PYTHON_ERROR_VALUE)
19
+
20
+ if(NOT _PYTHON_SUCCESS MATCHES 0)
21
+ message("_PYTHON_SUCCESS: ${_PYTHON_SUCCESS}")
22
+ message("_PYTHON_VALUES: ${_PYTHON_VALUES}")
23
+ message("_PYTHON_ERROR_VALUE: ${_PYTHON_ERROR_VALUE}")
24
+ message(FATAL_ERROR "get python config error!")
25
+ endif()
26
+
27
+ string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES})
28
+ list(GET _PYTHON_VALUES 0 PYTHON_INCLUDE_DIR)
29
+ list(GET _PYTHON_VALUES 1 PYTHON_EXT_SUFFIX)
30
+ list(GET _PYTHON_VALUES 2 TORCH_HOME)
31
+
32
+ include_directories(
33
+ ${PYTHON_INCLUDE_DIR}
34
+ ${TORCH_HOME}/include
35
+ ${TORCH_HOME}/include/TH
36
+ ${TORCH_HOME}/include/THC
37
+ ${TORCH_HOME}/include/torch/csrc/api/include
38
+ )
39
+
40
+ string(LENGTH "${CMAKE_SOURCE_DIR}/" SOURCE_PATH_LENGTH)
41
+ add_compile_options(-DSOURCE_PATH_LENGTH=${SOURCE_PATH_LENGTH})
42
+ add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0 -fvisibility=hidden -fopenmp)
43
+
44
+ if(${CMAKE_CXX_COMPILER_ID} STREQUAL "Clang")
45
+ add_link_options(-undefined dynamic_lookup)
46
+ endif()
47
+
48
+ file(GLOB_RECURSE CSRC_CPP csrc/*.cpp)
49
+
50
+ add_library(greedrl_c MODULE ${CSRC_CPP})
51
+ set_target_properties(greedrl_c PROPERTIES PREFIX "")
52
+ set_target_properties(greedrl_c PROPERTIES SUFFIX "${PYTHON_EXT_SUFFIX}")
53
+ target_compile_options(greedrl_c PRIVATE -Wno-sign-conversion -O3)
54
+ target_link_libraries(greedrl_c c10 torch torch_cpu torch_python)
55
+ target_link_directories(greedrl_c PRIVATE ${TORCH_HOME}/lib)
56
+
57
+ find_package(CUDA)
58
+ if(CUDA_FOUND)
59
+ enable_language(CUDA)
60
+ file(GLOB_RECURSE CSRC_CU csrc/*.cu)
61
+ add_library(greedrl_cu OBJECT ${CSRC_CU})
62
+ target_compile_options(greedrl_cu PRIVATE -keep -Xptxas -v --expt-relaxed-constexpr --expt-extended-lambda -O3)
63
+ set_target_properties(greedrl_cu PROPERTIES POSITION_INDEPENDENT_CODE ON CUDA_ARCHITECTURES "70;75;80")
64
+ add_compile_definitions(CUDA_FOUND)
65
+ include_directories(${CUDA_INCLUDE_DIRS})
66
+ target_link_libraries(greedrl_c torch_cuda greedrl_cu)
67
+ target_link_directories(greedrl_c PRIVATE ${TORCH_HOME}/lib)
68
+ endif()
README.md CHANGED
@@ -1,3 +1,628 @@
1
  ---
2
  license: apache-2.0
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: apache-2.0
3
+ pipeline_tag: reinforcement-learning
4
+ tags:
5
+ - Deep Reinforcement Learning
6
+ - Combinatorial Optimization
7
+ - Reinforcement Learning
8
+ - Vehicle Routing Problem
9
  ---
10
+
11
+ ![](./images/GREEDRL-Logo-Original-640.png)
12
+
13
+
14
+ # 🤠GreedRL
15
+
16
+ ## Overview
17
+
18
+ - 🤠GreedRL is a fast and general framework for **Combinatorial Optimization Problems (COPs)**, based on **Deep Reinforcement Learning (DRL)**.
19
+
20
+ - 🤠GreedRL achieves **1200 times faster and 3% improved performance** than [Google OR-Tools](https://developers.google.com/optimization) for large-scale (>=1000 nodes) CVRPs.
21
+
22
+ ## 🏆Award
23
+
24
+ [INFORMS 2021 Franz Edelman Award finalists](https://www.informs.org/Resource-Center/Video-Library/Edelman-Competition-Videos/2021-Edelman-Competition-Videos/2021-Edelman-Finalist-Alibaba) for Achievement in Operations Research and the Management Sciences (recognized for our work on Cainiao Network VRP algorithm).
25
+
26
+
27
+ ## Main features
28
+
29
+ * **GENERAL**
30
+
31
+ 🤠GreedRL makes **a high level of abstraction for COPs**, which can solve various types of problems, such as TSP, CVRP, VRPTW, PDPTW, SDVRP, DPDP, Order Batching, etc.
32
+
33
+ * **HIGH-PERFORMANCE**
34
+
35
+ 🤠GreedRL have improved the DRL environment (Env) simulation speed by **CUDA and C++ implementations**.
36
+
37
+ * **USER-FRIENDLY**
38
+
39
+ 🤠GreedRL framework provides **user-friendly ability for COPs modeling**, where users only need to declare constraints, objectives and variables of COPs. For more examples, please refer to [COPs Modeling examples](https://huggingface.co/Cainiao-AI/GreedRL/blob/main/README.md#cops-modeling-examples).
40
+
41
+ ## Editions
42
+
43
+ We provide an open source Community Edition and an Enterprise Edition of our 🤠GreedRL for users.
44
+
45
+ - **The Community Edition** is now released and available to [download](https://huggingface.co/Cainiao-AI/GreedRL).
46
+ - **The Enterprise Edition** has a high-performance implementation that achives a faster computing speed, especially when solving larg-scale COPs. For more informations, please contact <a href="mailto:[email protected]">us</a>.
47
+
48
+
49
+ ## Architecture
50
+ ![](./images/GREEDRL-Framwork_en.png)
51
+
52
+ ## COPs Modeling examples
53
+
54
+
55
+ ### Capacitated Vehicle Routing Problem (CVRP)
56
+ <details>
57
+ <summary>CVRP</summary>
58
+
59
+ ```python
60
+ from greedrl.feature import *
61
+ from greedrl.variable import *
62
+ from greedrl.function import *
63
+ from greedrl import Problem, Solution, Solver
64
+ from greedrl import runner
65
+
66
+ features = [continuous_feature('task_demand'),
67
+ continuous_feature('worker_weight_limit'),
68
+ continuous_feature('distance_matrix'),
69
+ variable_feature('distance_this_to_task'),
70
+ variable_feature('distance_task_to_end')]
71
+
72
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
73
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
74
+ feature_variable('task_weight'),
75
+ worker_variable('worker_weight_limit'),
76
+ worker_used_resource('worker_used_weight', task_require='task_weight'),
77
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
78
+ edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
79
+ edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]
80
+
81
+
82
+ class Constraint:
83
+
84
+ def do_task(self):
85
+ return self.task_demand_this
86
+
87
+ def mask_task(self):
88
+ # 已经完成的任务
89
+ mask = self.task_demand_now <= 0
90
+ # 车辆容量限制
91
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
92
+ mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
93
+ return mask
94
+
95
+ def finished(self):
96
+ return torch.all(self.task_demand_now <= 0, 1)
97
+
98
+
99
+ class Objective:
100
+
101
+ def step_worker_end(self):
102
+ return self.distance_last_to_this
103
+
104
+ def step_task(self):
105
+ return self.distance_last_to_this
106
+ ```
107
+
108
+ </details>
109
+
110
+ ### Pickup and Delivery Problem with Time Windows (PDPTW)
111
+ <details>
112
+ <summary>PDPTW</summary>
113
+
114
+ ```python
115
+ from greedrl.model import runner
116
+ from greedrl.feature import *
117
+ from greedrl.variable import *
118
+ from greedrl.function import *
119
+ from greedrl import Problem, Solution, Solver
120
+
121
+ features = [local_category('task_group'),
122
+ global_category('task_priority', 2),
123
+ variable_feature('distance_this_to_task'),
124
+ variable_feature('distance_task_to_end')]
125
+
126
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
127
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
128
+ feature_variable('task_weight'),
129
+ feature_variable('task_group'),
130
+ feature_variable('task_priority'),
131
+ feature_variable('task_due_time2', feature='task_due_time'),
132
+ task_variable('task_due_time'),
133
+ task_variable('task_service_time'),
134
+ task_variable('task_due_time_penalty'),
135
+ worker_variable('worker_basic_cost'),
136
+ worker_variable('worker_distance_cost'),
137
+ worker_variable('worker_due_time'),
138
+ worker_variable('worker_weight_limit'),
139
+ worker_used_resource('worker_used_weight', task_require='task_weight'),
140
+ worker_used_resource('worker_used_time', 'distance_matrix', 'task_service_time', 'task_ready_time',
141
+ 'worker_ready_time'),
142
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
143
+ edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
144
+ edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]
145
+
146
+
147
+ class Constraint:
148
+
149
+ def do_task(self):
150
+ return self.task_demand_this
151
+
152
+ def mask_worker_end(self):
153
+ return task_group_split(self.task_group, self.task_demand_now <= 0)
154
+
155
+ def mask_task(self):
156
+ mask = self.task_demand_now <= 0
157
+ mask |= task_group_priority(self.task_group, self.task_priority, mask)
158
+
159
+ worker_used_time = self.worker_used_time[:, None] + self.distance_this_to_task
160
+ mask |= (worker_used_time > self.task_due_time2) & (self.task_priority == 0)
161
+
162
+ # 容量约束
163
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
164
+ mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
165
+ return mask
166
+
167
+ def finished(self):
168
+ return torch.all(self.task_demand_now <= 0, 1)
169
+
170
+
171
+ class Objective:
172
+
173
+ def step_worker_start(self):
174
+ return self.worker_basic_cost
175
+
176
+ def step_worker_end(self):
177
+ feasible = self.worker_used_time <= self.worker_due_time
178
+ return self.distance_last_to_this * self.worker_distance_cost, feasible
179
+
180
+ def step_task(self):
181
+ worker_used_time = self.worker_used_time - self.task_service_time
182
+ feasible = worker_used_time <= self.task_due_time
183
+ feasible &= worker_used_time <= self.worker_due_time
184
+ cost = self.distance_last_to_this * self.worker_distance_cost
185
+ return torch.where(feasible, cost, cost + self.task_due_time_penalty), feasible
186
+ ```
187
+
188
+ </details>
189
+
190
+
191
+ ### VRP with Time Windows (VRPTW)
192
+ <details>
193
+ <summary>VRPTW</summary>
194
+
195
+ ```python
196
+ from greedrl import Problem, Solution, Solver
197
+ from greedrl.feature import *
198
+ from greedrl.variable import *
199
+ from greedrl.function import *
200
+ from greedrl.model import runner
201
+ from greedrl.myenv import VrptwEnv
202
+
203
+ features = [continuous_feature('worker_weight_limit'),
204
+ continuous_feature('worker_ready_time'),
205
+ continuous_feature('worker_due_time'),
206
+ continuous_feature('worker_basic_cost'),
207
+ continuous_feature('worker_distance_cost'),
208
+ continuous_feature('task_demand'),
209
+ continuous_feature('task_weight'),
210
+ continuous_feature('task_ready_time'),
211
+ continuous_feature('task_due_time'),
212
+ continuous_feature('task_service_time'),
213
+ continuous_feature('distance_matrix')]
214
+
215
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
216
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
217
+ feature_variable('task_weight'),
218
+ feature_variable('task_due_time'),
219
+ feature_variable('task_ready_time'),
220
+ feature_variable('task_service_time'),
221
+ worker_variable('worker_weight_limit'),
222
+ worker_variable('worker_due_time'),
223
+ worker_variable('worker_basic_cost'),
224
+ worker_variable('worker_distance_cost'),
225
+ worker_used_resource('worker_used_weight', task_require='task_weight'),
226
+ worker_used_resource('worker_used_time', 'distance_matrix', 'task_service_time', 'task_ready_time',
227
+ 'worker_ready_time'),
228
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
229
+ edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
230
+ edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]
231
+
232
+
233
+ class Constraint:
234
+
235
+ def do_task(self):
236
+ return self.task_demand_this
237
+
238
+ def mask_task(self):
239
+ # 已经完成的任务
240
+ mask = self.task_demand_now <= 0
241
+ # 车辆容量限制
242
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
243
+ mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
244
+
245
+ worker_used_time = self.worker_used_time[:, None] + self.distance_this_to_task
246
+ mask |= worker_used_time > self.task_due_time
247
+
248
+ worker_used_time = torch.max(worker_used_time, self.task_ready_time)
249
+ worker_used_time += self.task_service_time
250
+ worker_used_time += self.distance_task_to_end
251
+ mask |= worker_used_time > self.worker_due_time[:, None]
252
+
253
+ return mask
254
+
255
+ def finished(self):
256
+ return torch.all(self.task_demand_now <= 0, 1)
257
+
258
+
259
+ class Objective:
260
+
261
+ def step_worker_start(self):
262
+ return self.worker_basic_cost
263
+
264
+ def step_worker_end(self):
265
+ return self.distance_last_to_this * self.worker_distance_cost
266
+
267
+ def step_task(self):
268
+ return self.distance_last_to_this * self.worker_distance_cost
269
+ ```
270
+
271
+ </details>
272
+
273
+ ### Travelling Salesman Problem (TSP)
274
+ <details>
275
+ <summary>TSP</summary>
276
+
277
+ ```python
278
+ from greedrl.feature import *
279
+ from greedrl.variable import *
280
+ from greedrl import Problem
281
+ from greedrl import runner
282
+
283
+ features = [continuous_feature('task_location'),
284
+ variable_feature('distance_this_to_task'),
285
+ variable_feature('distance_task_to_end')]
286
+
287
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
288
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
289
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
290
+ edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
291
+ edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True),
292
+ edge_variable('distance_last_to_loop', feature='distance_matrix', last_to_loop=True)]
293
+
294
+
295
+ class Constraint:
296
+
297
+ def do_task(self):
298
+ return self.task_demand_this
299
+
300
+ def mask_task(self):
301
+ mask = self.task_demand_now <= 0
302
+ return mask
303
+
304
+ def mask_worker_end(self):
305
+ return torch.any(self.task_demand_now > 0, 1)
306
+
307
+ def finished(self):
308
+ return torch.all(self.task_demand_now <= 0, 1)
309
+
310
+
311
+ class Objective:
312
+
313
+ def step_worker_end(self):
314
+ return self.distance_last_to_loop
315
+
316
+ def step_task(self):
317
+ return self.distance_last_to_this
318
+ ```
319
+
320
+ </details>
321
+
322
+ ### Split Delivery Vehicle Routing Problem (SDVRP)
323
+ <details>
324
+ <summary>SDVRP</summary>
325
+
326
+ ```python
327
+ from greedrl.feature import *
328
+ from greedrl.variable import *
329
+ from greedrl import Problem
330
+ from greedrl import runner
331
+
332
+ features = [continuous_feature('task_demand'),
333
+ continuous_feature('worker_weight_limit'),
334
+ continuous_feature('distance_matrix'),
335
+ variable_feature('distance_this_to_task'),
336
+ variable_feature('distance_task_to_end')]
337
+
338
+ variables = [task_demand_now('task_demand'),
339
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
340
+ feature_variable('task_weight'),
341
+ task_variable('task_weight_this', feature='task_weight'),
342
+ worker_variable('worker_weight_limit'),
343
+ worker_used_resource('worker_used_weight', task_require='task_weight'),
344
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True)]
345
+
346
+
347
+ class Constraint:
348
+
349
+ def do_task(self):
350
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
351
+ return torch.min(self.task_demand_this, worker_weight_limit // self.task_weight_this)
352
+
353
+ def mask_task(self):
354
+ mask = self.task_demand <= 0
355
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
356
+ mask |= self.task_weight > worker_weight_limit[:, None]
357
+ return mask
358
+
359
+ def finished(self):
360
+ return torch.all(self.task_demand <= 0, 1)
361
+
362
+
363
+ class Objective:
364
+
365
+ def step_worker_end(self):
366
+ return self.distance_last_to_this
367
+
368
+ def step_task(self):
369
+ return self.distance_last_to_this
370
+ ```
371
+
372
+ </details>
373
+
374
+ ### Realistic Business Scenario
375
+ <details>
376
+ <summary>real-time Dynamic Pickup and Delivery Problem (DPDP)</summary>
377
+
378
+ ```python
379
+ from greedrl.feature import *
380
+ from greedrl.variable import *
381
+ from greedrl.function import *
382
+ from greedrl import Problem
383
+ from greedrl import runner
384
+
385
+ features = [local_category('task_order'),
386
+ global_category('task_type', 2),
387
+ global_category('task_new_order', 2),
388
+ variable_feature('time_this_to_task'),
389
+ continuous_feature('x_time_matrix'),
390
+ continuous_feature('task_due_time_x'),
391
+ continuous_feature('worker_task_mask')]
392
+
393
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
394
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
395
+ task_variable('task_pickup_this', feature='task_pickup'),
396
+ task_variable('task_due_time_this', feature='task_due_time'),
397
+ feature_variable('task_order', feature='task_order'),
398
+ feature_variable('task_type', feature='task_type'),
399
+ feature_variable('task_new_pickup', feature='task_new_pickup'),
400
+ feature_variable('worker_task_mask', feature='worker_task_mask'),
401
+ worker_count_now('worker_count_now', feature='worker_count'),
402
+ worker_variable('worker_min_old_task_this', feature='worker_min_old_task'),
403
+ worker_variable('worker_max_new_order_this', feature='worker_max_new_order'),
404
+ worker_variable('worker_task_mask_this', feature='worker_task_mask'),
405
+ worker_used_resource('worker_used_old_task', task_require='task_old'),
406
+ worker_used_resource('worker_used_new_order', task_require='task_new_pickup'),
407
+ worker_used_resource('worker_used_time', edge_require='time_matrix'),
408
+ edge_variable('time_this_to_task', feature='x_time_matrix', this_to_task=True)]
409
+
410
+
411
+ class Constraint:
412
+
413
+ def do_task(self):
414
+ return self.task_demand_this
415
+
416
+ def mask_worker_start(self):
417
+ mask = self.worker_count_now <= 0
418
+
419
+ finished = self.task_demand_now <= 0
420
+ worker_task_mask = self.worker_task_mask | finished[:, None, :]
421
+ mask |= torch.all(worker_task_mask, 2)
422
+
423
+ return mask
424
+
425
+ def mask_worker_end(self):
426
+ mask = self.worker_used_old_task < self.worker_min_old_task_this
427
+ mask |= task_group_split(self.task_order, self.task_demand_now <= 0)
428
+ return mask
429
+
430
+ def mask_task(self):
431
+ mask = self.task_demand_now <= 0
432
+
433
+ mask |= task_group_priority(self.task_order, self.task_type, mask)
434
+
435
+ worker_max_new_order = self.worker_max_new_order_this - self.worker_used_new_order
436
+ mask |= self.task_new_pickup > worker_max_new_order[:, None]
437
+
438
+ mask |= self.worker_task_mask_this
439
+
440
+ return mask
441
+
442
+ def finished(self):
443
+ worker_mask = self.worker_count_now <= 0
444
+ task_mask = self.task_demand_now <= 0
445
+ worker_task_mask = worker_mask[:, :, None] | task_mask[:, None, :]
446
+
447
+ worker_task_mask |= self.worker_task_mask
448
+ batch_size = worker_task_mask.size(0)
449
+ worker_task_mask = worker_task_mask.view(batch_size, -1)
450
+ return worker_task_mask.all(1)
451
+
452
+
453
+ class Objective:
454
+
455
+ def step_task(self):
456
+ over_time = (self.worker_used_time - self.task_due_time_this).clamp(min=0)
457
+ pickup_time = self.worker_used_time * self.task_pickup_this
458
+ return self.worker_used_time + over_time + pickup_time
459
+
460
+ def step_finish(self):
461
+ return self.task_demand_now.sum(1) * 1000
462
+ ```
463
+
464
+ </details>
465
+
466
+ ### Order Batching Problem
467
+ <details>
468
+ <summary>Batching</summary>
469
+
470
+ ```python
471
+ from greedrl import Problem, Solver
472
+ from greedrl.feature import *
473
+ from greedrl.variable import *
474
+ from greedrl import runner
475
+
476
+
477
+ features = [local_feature('task_area'),
478
+ local_feature('task_roadway'),
479
+ local_feature('task_area_group'),
480
+ sparse_local_feature('task_item_id', 'task_item_num'),
481
+ sparse_local_feature('task_item_owner_id', 'task_item_num'),
482
+ variable_feature('worker_task_item'),
483
+ variable_feature('worker_used_roadway'),
484
+ variable_feature('worker_used_area')]
485
+
486
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
487
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
488
+ feature_variable('task_item_id'),
489
+ feature_variable('task_item_num'),
490
+ feature_variable('task_item_owner_id'),
491
+ feature_variable('task_area'),
492
+ feature_variable('task_area_group'),
493
+ feature_variable('task_load'),
494
+ feature_variable('task_group'),
495
+ worker_variable('worker_load_limit'),
496
+ worker_variable('worker_area_limit'),
497
+ worker_variable('worker_area_group_limit'),
498
+ worker_task_item('worker_task_item', item_id='task_item_id', item_num='task_item_num'),
499
+ worker_task_item('worker_task_item_owner', item_id='task_item_owner_id', item_num='task_item_num'),
500
+ worker_used_resource('worker_used_load', task_require='task_load'),
501
+ worker_used_resource('worker_used_area', task_require='task_area'),
502
+ worker_used_resource('worker_used_roadway', task_require='task_roadway'),
503
+ worker_used_resource('worker_used_area_group', task_require='task_area_group')]
504
+
505
+
506
+ class Constraint:
507
+
508
+ def do_task(self):
509
+ return self.task_demand_this
510
+
511
+ def mask_worker_end(self):
512
+ return self.worker_used_load < self.worker_load_limit
513
+
514
+ def mask_task(self):
515
+ # completed tasks
516
+ mask = self.task_demand_now <= 0
517
+ # mask |= task_group_priority(self.task_group, self.task_out_stock_time, mask)
518
+
519
+ NT = self.task_item_id.size(1)
520
+ worker_task_item = self.worker_task_item[:, None, :]
521
+ worker_task_item = worker_task_item.expand(-1, NT, -1)
522
+ task_item_in_worker = worker_task_item.gather(2, self.task_item_id.long())
523
+ task_item_in_worker = (task_item_in_worker > 0) & (self.task_item_num > 0)
524
+
525
+ worker_task_item_owner = self.worker_task_item_owner[:, None, :]
526
+ worker_task_item_owner = worker_task_item_owner.expand(-1, NT, -1)
527
+ task_item_owner_in_worker = worker_task_item_owner.gather(2, self.task_item_owner_id.long())
528
+ task_item_owner_in_worker = (task_item_owner_in_worker > 0) & (self.task_item_num > 0)
529
+
530
+ #
531
+ mask |= torch.any(task_item_in_worker & ~task_item_owner_in_worker, 2)
532
+
533
+ worker_load_limit = self.worker_load_limit - self.worker_used_load
534
+ mask |= (self.task_load > worker_load_limit[:, None])
535
+
536
+ task_area = self.task_area + self.worker_used_area[:, None, :]
537
+ task_area_num = task_area.clamp(0, 1).sum(2, dtype=torch.int32)
538
+ mask |= (task_area_num > self.worker_area_limit[:, None])
539
+
540
+ tak_area_group = self.task_area_group + self.worker_used_area_group[:, None, :]
541
+ tak_area_group_num = tak_area_group.clamp(0, 1).sum(2, dtype=torch.int32)
542
+ mask |= (tak_area_group_num > self.worker_area_group_limit[:, None])
543
+
544
+ return mask
545
+
546
+ def finished(self):
547
+ return torch.all(self.task_demand_now <= 0, 1)
548
+
549
+
550
+ class Objective:
551
+
552
+ def step_worker_end(self):
553
+ area_num = self.worker_used_area.clamp(0, 1).sum(1)
554
+ roadway_num = self.worker_used_roadway.clamp(0, 1).sum(1)
555
+ item_num = self.worker_task_item.clamp(0, 1).sum(1)
556
+ penalty = (self.worker_load_limit - self.worker_used_load) * 10
557
+ return area_num * 100 + roadway_num * 10 + item_num + penalty
558
+ ```
559
+
560
+ </details>
561
+
562
+
563
+ #
564
+ #
565
+ # Getting started
566
+
567
+ ## Description
568
+ We are delighted to release 🤠GreedRL Community Edition, as well as example of training and testing scripts for the standard Capacitated VRP (CVRP), you can download it and get started.
569
+
570
+ ## Test environment
571
+ 🤠GreedRL Community Edition has been tested on Ubuntu 18.04 with GCC compiler v7.5.0 and CUDA version 11.4, and a [Miniconda](https://docs.conda.io/en/latest/miniconda.html#system-requirements) distribution with Python 3.8. We recommend using a similar configuration to avoid any possiblem compilation issue.
572
+
573
+ ## Installation
574
+ First, clone the repository.
575
+ ```aidl
576
+ $ git clone https://huggingface.co/Cainiao-AI/GreedRL
577
+ ```
578
+ Then, create and activate a python environment using conda, and install required packages.
579
+ ```aidl
580
+ $ conda create -n python38 python==3.8
581
+ $ source activate python38
582
+ $ pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113
583
+ ```
584
+ Finally, compile and add the resulting library `greedrl` to the `PYTHONPATH`
585
+ ```aidl
586
+ $ python setup.py build
587
+ $ export PYTHONPATH={your_current_path}/build/lib.linux-x86_64-cpython-38/:$PYTHONPATH
588
+ ```
589
+
590
+ ## CVRP Training
591
+
592
+ 1. Training data
593
+
594
+ We use generated data for the training phase, the customers and depot locations are randomly generated in the unit square [0,1] X [0,1]. For CVRP, we assume that the demand of each node is a discrete number in {1,...,9}, chosen uniformly at random, and each vehicle has a default capacity of 50.
595
+
596
+
597
+ 2. Start training
598
+ ```python
599
+ $ cd examples/cvrp
600
+ $ python train.py --model_filename cvrp_100.pt --problem_size 100
601
+ ```
602
+
603
+ ## CVRP Testing
604
+
605
+ After training process, you'll get a trained model, like `cvrp_100.pt`, that you can use for test.
606
+
607
+ ```python
608
+ $ cd examples/cvrp
609
+ $ python solve.py --device cpu --model_name cvrp_100.pt --problem_size 100
610
+ ```
611
+
612
+ # Support
613
+ We look forward you to downloading it, using it, and opening discussion if you encounter any problems or have ideas on building an even better experience.
614
+ For commercial enquiries, please contact <a href="mailto:[email protected]">us</a>.
615
+
616
+ # Citation
617
+ ```
618
+ @article{hu2022alibaba,
619
+ title={Alibaba vehicle routing algorithms enable rapid pick and delivery},
620
+ author={Hu, Haoyuan and Zhang, Ying and Wei, Jiangwen and Zhan, Yang and Zhang, Xinhui and Huang, Shaojian and Ma, Guangrui and Deng, Yuming and Jiang, Siwei},
621
+ journal={INFORMS Journal on Applied Analytics},
622
+ volume={52},
623
+ number={1},
624
+ pages={27--41},
625
+ year={2022},
626
+ publisher={INFORMS}
627
+ }
628
+ ```
csrc/common.h ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ #include <cfloat>
3
+ #include <climits>
4
+ #include <cstdint>
5
+ #include <limits>
6
+ #include <chrono>
7
+ #include <stdexcept>
8
+ #include <torch/extension.h>
9
+
10
+ #define ASSERT(c) assert(c)
11
+ #define ALIGN(v, n) ((v + n - 1) / n * n)
12
+ #define INF std::numeric_limits<float>::infinity()
13
+ #define __FILENAME__ (__FILE__+ SOURCE_PATH_LENGTH)
14
+
15
+ #define GRL_ERROR(format, args...) \
16
+ greedrl_error(__FILENAME__, __LINE__, format, ##args); \
17
+
18
+
19
+ #define GRL_CHECK(flag, format, args...) \
20
+ greedrl_check(__FILENAME__, __LINE__, flag, format, ##args); \
21
+
22
+
23
+ #define MALLOC(ptr, T, size) \
24
+ ptr = (T*) malloc(sizeof(T) * (size)); \
25
+ GRL_CHECK(ptr != nullptr, "out of memory!"); \
26
+
27
+
28
+ #define GALLOC(ptr, T, size) \
29
+ GRL_CHECK((size) > 0, "malloc 0 bytes"); \
30
+ T* const ptr = (T*) malloc(sizeof(T) * (size)); \
31
+ GRL_CHECK(ptr != nullptr, "out of memory!"); \
32
+ AllocGuard ptr##_##alloc##_##guard(ptr); \
33
+
34
+
35
+ #define REALLOC(ptr, T, size) \
36
+ GRL_CHECK((size) > 0, "malloc 0 bytes"); \
37
+ ptr = (T*) realloc(ptr, sizeof(T) * (size)); \
38
+ GRL_CHECK(ptr != nullptr, "out of memory!"); \
39
+
40
+
41
+ #define GRL_CHECK_TENSOR(tensor, device, allow_sub_contiguous, allow_null, ...) \
42
+ greedrl_check_tensor(__FILENAME__, __LINE__, tensor, #tensor, device, \
43
+ allow_sub_contiguous, allow_null, {__VA_ARGS__}); \
44
+
45
+
46
+ const int GRL_WORKER_START = 0;
47
+ const int GRL_WORKER_END = 1;
48
+ const int GRL_TASK = 2;
49
+ const int GRL_FINISH = 3;
50
+
51
+ const int MAX_BATCH_SIZE = 100000;
52
+ const int MAX_TASK_COUNT = 5120;
53
+ const int MAX_SHARED_MEM = 48128;
54
+
55
+ using String = std::string;
56
+ using Device = torch::Device;
57
+ using Tensor = torch::Tensor;
58
+ using TensorMap = std::map<String, Tensor>;
59
+ using TensorList = std::vector<Tensor>;
60
+
61
+
62
+ inline void greedrl_error(const char* const file, const int64_t line,
63
+ const char* const format, ...)
64
+ {
65
+ const int N = 2048;
66
+ static char buf[N];
67
+
68
+ va_list args;
69
+ va_start(args, format);
70
+ int n = vsnprintf(buf, N, format, args);
71
+ va_end(args);
72
+
73
+ if(n < N)
74
+ {
75
+ snprintf(buf+n, N-n, " at %s:%ld", file, line);
76
+ }
77
+
78
+ throw std::runtime_error(buf);
79
+ }
80
+
81
+ inline void greedrl_check(const char* const file, const int64_t line,
82
+ const bool flag, const char* const format, ...)
83
+ {
84
+ if(flag)
85
+ {
86
+ return;
87
+ }
88
+
89
+ const int N = 2048;
90
+ static char buf[N];
91
+
92
+ va_list args;
93
+ va_start(args, format);
94
+ int n = vsnprintf(buf, N, format, args);
95
+ va_end(args);
96
+
97
+ if(n < N)
98
+ {
99
+ snprintf(buf+n, N-n, " at %s:%ld", file, line);
100
+ }
101
+
102
+ throw std::runtime_error(buf);
103
+ }
104
+
105
+ // contiguous except the 1st dimension
106
+ inline bool is_sub_contiguous(const Tensor& tensor)
107
+ {
108
+ int dim = tensor.dim();
109
+ if(dim==1) return true;
110
+
111
+ auto sizes = tensor.sizes();
112
+ auto strides = tensor.strides();
113
+
114
+ if(strides[dim-1] != 1) return false;
115
+
116
+ int s = 1;
117
+ for(int i=dim-2; i>0; i--)
118
+ {
119
+ s *= sizes[i+1];
120
+ if(strides[i] != s) return false;
121
+ }
122
+
123
+ return true;
124
+
125
+ };
126
+
127
+ inline void greedrl_check_tensor(const char* const file,
128
+ const int line,
129
+ const Tensor& tensor,
130
+ const String& name,
131
+ const Device& device,
132
+ bool allow_sub_contiguous,
133
+ bool allow_null,
134
+ std::initializer_list<int> sizes)
135
+ {
136
+ greedrl_check(file, line, tensor.numel() < 1000 * 1000 * 1000, "tensor size too large");
137
+
138
+ auto device2 = tensor.device();
139
+ greedrl_check(file, line, device2==device,
140
+ "'%s' device is %s, but expect %s",
141
+ name.c_str(), device2.str().c_str(), device.str().c_str());
142
+
143
+ bool is_contiguous = allow_sub_contiguous ? is_sub_contiguous(tensor) : tensor.is_contiguous();
144
+ greedrl_check(file, line, is_contiguous, "'%s' is not contiguous", name.c_str());
145
+
146
+ if(allow_null && tensor.data_ptr() == nullptr) return;
147
+
148
+ if(tensor.dim() != sizes.size())
149
+ {
150
+ greedrl_error(file, line, "'%s' dim is %d, but expect %d", name.c_str(), (int)tensor.dim(), (int)sizes.size());
151
+ }
152
+ int i=0;
153
+ for(auto s:sizes)
154
+ {
155
+ greedrl_check(file, line, tensor.size(i)==s, "'%s' size(%d) is %d, but expect %d", name.c_str(), i, (int)tensor.size(i), s);
156
+ i++;
157
+ }
158
+ }
159
+
160
+
161
+ #ifdef CUDA_FOUND
162
+
163
+ #include <cuda_runtime_api.h>
164
+
165
+ #define GRL_CHECK_CUDA(error)\
166
+ greedrl_check_cuda(error, __FILENAME__, __LINE__);
167
+
168
+ inline void greedrl_check_cuda(const cudaError_t& error,
169
+ const char* file, const int64_t line)
170
+ {
171
+ if(error==cudaSuccess)
172
+ {
173
+ return;
174
+ }
175
+
176
+ const int N = 2048;
177
+ static char buf[N];
178
+ snprintf(buf, N, "%s, at %s:%ld", cudaGetErrorString(error), file, line);
179
+ throw std::runtime_error(buf);
180
+ }
181
+
182
+ cudaDeviceProp& cuda_get_device_prop(int i);
183
+
184
+ #endif
csrc/pybind.cpp ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <pybind11/pybind11.h>
2
+ #include "task_group_split.h"
3
+ #include "task_group_priority.h"
4
+
5
+ namespace py = pybind11;
6
+
7
+ PYBIND11_MODULE(greedrl_c, m) {
8
+ m.def("task_group_split", &task_group_split);
9
+ m.def("task_group_priority", &task_group_priority);
10
+ }
11
+
csrc/task_group_priority.cpp ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "task_group_priority.h"
2
+
3
+ void task_group_priority_cpu(
4
+ int* group, int* priority, bool* value, bool* output,
5
+ int batch_size, int task_num, int group_num)
6
+ {
7
+ auto temp = torch::make_unique<int[]>(group_num);
8
+ for(int b=0; b<batch_size; b++)
9
+ {
10
+ for(int i=0; i<group_num; i++){
11
+ temp[i] = std::numeric_limits<int>::max();
12
+ }
13
+
14
+ for(int i=0; i<task_num; i++){
15
+ if(value[i]){
16
+ continue;
17
+ }
18
+ int g = group[i];
19
+ int p = priority[i];
20
+ if(p < temp[g]){
21
+ temp[g] = p;
22
+ }
23
+ }
24
+
25
+ for(int i=0; i<task_num; i++){
26
+ int g = group[i];
27
+ output[i] = priority[i]!=temp[g];
28
+ }
29
+
30
+ group += task_num;
31
+ priority += task_num;
32
+ value += task_num;
33
+ output += task_num;
34
+ }
35
+ };
36
+
37
+ auto task_group_priority(
38
+ const torch::Tensor& group,
39
+ const torch::Tensor& priority,
40
+ const torch::Tensor& value) -> torch::Tensor
41
+ {
42
+ auto device = group.device();
43
+
44
+ const int batch_size = group.size(0);
45
+ const int task_num = group.size(1);
46
+ const int group_num = group.max().item<int>() + 1;
47
+
48
+ const int _group_num = group.min().item<int>();
49
+
50
+ GRL_CHECK(group_num <= task_num && _group_num >= 0, "group value error");
51
+
52
+ GRL_CHECK_TENSOR(group, device, false, false, batch_size, task_num);
53
+ GRL_CHECK_TENSOR(priority, device, false, false, batch_size, task_num);
54
+ GRL_CHECK_TENSOR(value, device, false, false, batch_size, task_num);
55
+
56
+ auto output = torch::zeros({batch_size, task_num}, torch::dtype(torch::kBool).device(device));
57
+
58
+ switch(device.type())
59
+ {
60
+ case torch::kCPU:
61
+ task_group_priority_cpu(group.data_ptr<int>(), priority.data_ptr<int>(), value.data_ptr<bool>(),
62
+ output.data_ptr<bool>(), batch_size, task_num, group_num);
63
+ break;
64
+ #ifdef CUDA_FOUND
65
+ case torch::kCUDA:
66
+ task_group_priority_cuda(group.data_ptr<int>(), priority.data_ptr<int>(), value.data_ptr<bool>(),
67
+ output.data_ptr<bool>(), batch_size, task_num, group_num, device.index());
68
+ break;
69
+ #endif
70
+ default:
71
+ GRL_ERROR("unsupported device: %s", device.str().c_str());
72
+ }
73
+
74
+ return output;
75
+ };
csrc/task_group_priority.cu ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "task_group_priority.h"
2
+
3
+ __global__ void task_group_priority_kernel(
4
+ int* group, int* priority, bool* value, bool* output,
5
+ int batch_size, int task_num, int group_num)
6
+ {
7
+ group += blockIdx.x * task_num;
8
+ priority += blockIdx.x * task_num;
9
+ value += blockIdx.x * task_num;
10
+ output += blockIdx.x * task_num;
11
+
12
+ extern __shared__ int temp[];
13
+
14
+ for(int i=threadIdx.x; i<group_num; i+=blockDim.x)
15
+ {
16
+ temp[i] = std::numeric_limits<int>::max();
17
+ }
18
+
19
+ __syncthreads();
20
+
21
+ for(int i=threadIdx.x; i<task_num; i+=blockDim.x){
22
+ if(value[i]){
23
+ continue;
24
+ }
25
+ int g = group[i];
26
+ int p = priority[i];
27
+ atomicMin(&temp[g], p);
28
+ }
29
+
30
+ __syncthreads();
31
+
32
+ for(int i=threadIdx.x; i<task_num; i+=blockDim.x){
33
+ int g = group[i];
34
+ output[i] = priority[i]!=temp[g];
35
+ }
36
+ };
37
+
38
+ template<typename _Tg, typename _Tp>
39
+ __global__ void cuda_do_task_group_priority(
40
+ const torch::PackedTensorAccessor<_Tg,2,torch::RestrictPtrTraits> group,
41
+ const torch::PackedTensorAccessor<_Tp,2,torch::RestrictPtrTraits> priority,
42
+ const torch::PackedTensorAccessor<bool,2,torch::RestrictPtrTraits> value,
43
+ torch::PackedTensorAccessor<bool,2,torch::RestrictPtrTraits> result,
44
+ const _Tg NG)
45
+ {
46
+ const int NP = group.size(0);
47
+ const int NT = group.size(1);
48
+ const int p = blockIdx.x * blockDim.x + threadIdx.x;
49
+ if(p < NP)
50
+ {
51
+ extern __shared__ char _temp[];
52
+ auto temp = reinterpret_cast<_Tp*>(_temp);
53
+ temp += (threadIdx.x * NG);
54
+ for(_Tg g=0; g<NG; g++){
55
+ temp[g] = std::numeric_limits<_Tp>::max();
56
+ }
57
+
58
+ for(int t=0; t<NT; t++){
59
+ if(value[p][t]){
60
+ continue;
61
+ }
62
+ _Tg g = group[p][t];
63
+ _Tp _p = priority[p][t];
64
+ if(_p < temp[g]){
65
+ temp[g] = _p;
66
+ }
67
+ }
68
+
69
+ for(int t=0; t<NT; t++){
70
+ _Tg g = group[p][t];
71
+ if(priority[p][t]==temp[g]){
72
+ result[p][t] = false;
73
+ }
74
+ }
75
+ }
76
+ };
77
+
78
+
79
+
80
+ void task_group_priority_cuda(
81
+ int* group, int* priority, bool* value, bool* output,
82
+ const int batch_size, const int task_num, const int group_num, const int device)
83
+ {
84
+ const int shared_mem = group_num * sizeof(int);
85
+
86
+ GRL_CHECK_CUDA(cudaSetDevice(device));
87
+
88
+ task_group_priority_kernel<<<batch_size, 256, shared_mem>>>(
89
+ group, priority, value, output, batch_size, task_num, group_num);
90
+
91
+ GRL_CHECK_CUDA(cudaGetLastError());
92
+ };
93
+
csrc/task_group_priority.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "./common.h"
4
+
5
+ /**
6
+ * tasks are divided into groups,
7
+ * tasks in a group are visited by it's priority.
8
+ * the min priority value of unvisited tasks in a group is computed,
9
+ * output is false, if the task's priority equal the computed min priority, otherwise output is true
10
+ *
11
+ * group: task's group, shape is (batch_size, task_num)
12
+ * priority: task's priority, shape is (batch_size, task_num)
13
+ * value: task is visited or not, shape is (batch_size, task_num)
14
+ *
15
+ * output: the result, shape is (batch_size, task_num)
16
+ */
17
+ auto task_group_priority(
18
+ const torch::Tensor& group,
19
+ const torch::Tensor& priority,
20
+ const torch::Tensor& value) -> torch::Tensor;
21
+
22
+ void task_group_priority_cpu(
23
+ int* group, int* priority, bool* value, bool* ouput,
24
+ int batch_size, int task_num, int group_num);
25
+
26
+ void task_group_priority_cuda(
27
+ int* group, int* priority, bool* value, bool* ouput,
28
+ int batch_size, int task_num, int group_num, int device);
29
+
30
+
csrc/task_group_split.cpp ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "task_group_split.h"
2
+
3
+ void task_group_split_cpu(
4
+ int* group, bool* value, bool* output,
5
+ const int batch_size, const int task_num, const int group_num)
6
+ {
7
+ auto temp = torch::make_unique<bool[]>(group_num);
8
+ for(int b=0; b<batch_size; b++)
9
+ {
10
+ for(int i=0; i<group_num; i++){
11
+ temp[i] = false;
12
+ }
13
+
14
+ for(int i=0; i<task_num; i++){
15
+ if(value[i]){
16
+ int g = group[i];
17
+ temp[g] = true;
18
+ }
19
+ }
20
+
21
+ output[b] = false;
22
+ for(int i=0; i<task_num; i++){
23
+ int g = group[i];
24
+ if(temp[g] && !value[i]){
25
+ output[b] = true;
26
+ break;
27
+ }
28
+ }
29
+
30
+ group += task_num;
31
+ value += task_num;
32
+ }
33
+ };
34
+
35
+
36
+ auto task_group_split(
37
+ const Tensor& group, const Tensor& value) -> Tensor
38
+ {
39
+ auto device = group.device();
40
+ const int batch_size = group.size(0);
41
+ const int task_num = group.size(1);
42
+ const int group_num = group.max().item<int>() + 1;
43
+ const int _group_num = group.min().item<int>();
44
+
45
+ GRL_CHECK(group_num <= task_num && _group_num >= 0, "group value error");
46
+
47
+ GRL_CHECK_TENSOR(group, device, false, false, batch_size, task_num);
48
+ GRL_CHECK_TENSOR(value, device, false, false, batch_size, task_num);
49
+
50
+ auto output = torch::zeros({batch_size}, torch::dtype(torch::kBool).device(device));
51
+
52
+ switch(device.type())
53
+ {
54
+ case torch::kCPU:
55
+ task_group_split_cpu(group.data_ptr<int>(), value.data_ptr<bool>(),
56
+ output.data_ptr<bool>(), batch_size, task_num, group_num);
57
+ break;
58
+ #ifdef CUDA_FOUND
59
+ case torch::kCUDA:
60
+ task_group_split_cuda(group.data_ptr<int>(), value.data_ptr<bool>(),
61
+ output.data_ptr<bool>(), batch_size, task_num, group_num, device.index());
62
+ break;
63
+ #endif
64
+ default:
65
+ GRL_ERROR("unsupported device: %s", device.str().c_str());
66
+ }
67
+
68
+ return output;
69
+ };
csrc/task_group_split.cu ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "task_group_split.h"
2
+
3
+ __global__ void task_group_split_kernel(
4
+ int* group, bool* value, bool* output,
5
+ const int batch_size, const int task_num, const int group_num)
6
+ {
7
+ group += blockIdx.x * task_num;
8
+ value += blockIdx.x * task_num;
9
+ extern __shared__ bool temp[];
10
+
11
+ __shared__ bool split;
12
+ if(threadIdx.x == 0) split = false;
13
+
14
+ for(int i=threadIdx.x; i<group_num; i+=blockDim.x)
15
+ {
16
+ temp[i] = false;
17
+ }
18
+
19
+ __syncthreads();
20
+
21
+ for(int i=threadIdx.x; i<task_num; i+=blockDim.x)
22
+ {
23
+ int g = group[i];
24
+ if(value[i]) temp[g] = true;
25
+ }
26
+
27
+ __syncthreads();
28
+
29
+ for(int i=threadIdx.x; i<task_num; i+=blockDim.x)
30
+ {
31
+ int g = group[i];
32
+ if(temp[g] && !value[i]) split = true;
33
+ }
34
+
35
+ __syncthreads();
36
+
37
+ if(threadIdx.x == 0) output[blockIdx.x] = split;
38
+ };
39
+
40
+ void task_group_split_cuda(
41
+ int* group, bool* value, bool* output,
42
+ const int batch_size, const int task_num, const int group_num, const int device)
43
+ {
44
+ const int shared_mem = group_num * sizeof(bool);
45
+
46
+ GRL_CHECK_CUDA(cudaSetDevice(device));
47
+
48
+ task_group_split_kernel<<<batch_size, 256, shared_mem>>>(
49
+ group, value, output, batch_size, task_num, group_num);
50
+
51
+ GRL_CHECK_CUDA(cudaGetLastError());
52
+ };
53
+
csrc/task_group_split.h ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "./common.h"
4
+
5
+ /**
6
+ * tasks are divided into groups,
7
+ * if tasks in a group are all visited or all not visited,
8
+ * output is is false, otherwise output is true
9
+ *
10
+ * group: task's group, shape is (batch_size, task_num)
11
+ * value: task is visited or not, shape is (batch_size, task_num)
12
+ *
13
+ * output: the result, shape is (batch_size,)
14
+ */
15
+ auto task_group_split(const Tensor& group, const Tensor& value) -> Tensor;
16
+
17
+ void task_group_split_cpu(
18
+ int* group, bool* value, bool* output,
19
+ const int batch_size, const int task_num, const int group_num);
20
+
21
+ void task_group_split_cuda(
22
+ int* group, bool* value, bool* output,
23
+ const int batch_size, const int task_num, const int group_num, const int device);
24
+
examples/batching/batching.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from greedrl import Problem, Solver
4
+ from greedrl.feature import *
5
+ from greedrl.variable import *
6
+
7
+ features = [local_feature('task_area'),
8
+ local_feature('task_roadway'),
9
+ local_feature('task_area_group'),
10
+ sparse_local_feature('task_item_id', 'task_item_num'),
11
+ sparse_local_feature('task_item_owner_id', 'task_item_num'),
12
+ variable_feature('worker_task_item'),
13
+ variable_feature('worker_used_roadway'),
14
+ variable_feature('worker_used_area')]
15
+
16
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
17
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
18
+ feature_variable('task_item_id'),
19
+ feature_variable('task_item_num'),
20
+ feature_variable('task_item_owner_id'),
21
+ feature_variable('task_area'),
22
+ feature_variable('task_area_group'),
23
+ feature_variable('task_load'),
24
+ feature_variable('task_group'),
25
+ worker_variable('worker_load_limit'),
26
+ worker_variable('worker_area_limit'),
27
+ worker_variable('worker_area_group_limit'),
28
+ worker_task_item('worker_task_item', item_id='task_item_id', item_num='task_item_num'),
29
+ worker_task_item('worker_task_item_owner', item_id='task_item_owner_id', item_num='task_item_num'),
30
+ worker_used_resource('worker_used_load', task_require='task_load'),
31
+ worker_used_resource('worker_used_area', task_require='task_area'),
32
+ worker_used_resource('worker_used_roadway', task_require='task_roadway'),
33
+ worker_used_resource('worker_used_area_group', task_require='task_area_group')]
34
+
35
+
36
+ class Constraint:
37
+
38
+ def do_task(self):
39
+ return self.task_demand_this
40
+
41
+ def mask_worker_end(self):
42
+ return self.worker_used_load < self.worker_load_limit
43
+
44
+ def mask_task(self):
45
+ # 已经完成的任务
46
+ mask = self.task_demand_now <= 0
47
+ # mask |= task_group_priority(self.task_group, self.task_out_stock_time, mask)
48
+
49
+ NT = self.task_item_id.size(1)
50
+ worker_task_item = self.worker_task_item[:, None, :]
51
+ worker_task_item = worker_task_item.expand(-1, NT, -1)
52
+ task_item_in_worker = worker_task_item.gather(2, self.task_item_id.long())
53
+ task_item_in_worker = (task_item_in_worker > 0) & (self.task_item_num > 0)
54
+
55
+ worker_task_item_owner = self.worker_task_item_owner[:, None, :]
56
+ worker_task_item_owner = worker_task_item_owner.expand(-1, NT, -1)
57
+ task_item_owner_in_worker = worker_task_item_owner.gather(2, self.task_item_owner_id.long())
58
+ task_item_owner_in_worker = (task_item_owner_in_worker > 0) & (self.task_item_num > 0)
59
+
60
+ # 同一个sku,不同货主,不能在一个拣选单
61
+ mask |= torch.any(task_item_in_worker & ~task_item_owner_in_worker, 2)
62
+
63
+ worker_load_limit = self.worker_load_limit - self.worker_used_load
64
+ mask |= (self.task_load > worker_load_limit[:, None])
65
+
66
+ task_area = self.task_area + self.worker_used_area[:, None, :]
67
+ task_area_num = task_area.clamp(0, 1).sum(2, dtype=torch.int32)
68
+ mask |= (task_area_num > self.worker_area_limit[:, None])
69
+
70
+ tak_area_group = self.task_area_group + self.worker_used_area_group[:, None, :]
71
+ tak_area_group_num = tak_area_group.clamp(0, 1).sum(2, dtype=torch.int32)
72
+ mask |= (tak_area_group_num > self.worker_area_group_limit[:, None])
73
+
74
+ return mask
75
+
76
+ def finished(self):
77
+ return torch.all(self.task_demand_now <= 0, 1)
78
+
79
+
80
+ class Objective:
81
+
82
+ def step_worker_end(self):
83
+ area_num = self.worker_used_area.clamp(0, 1).sum(1)
84
+ roadway_num = self.worker_used_roadway.clamp(0, 1).sum(1)
85
+ item_num = self.worker_task_item.clamp(0, 1).sum(1)
86
+ penalty = (self.worker_load_limit - self.worker_used_load) * 10
87
+ return area_num * 100 + roadway_num * 10 + item_num + penalty
88
+
89
+
90
+ def make_problem_from_json(data):
91
+ if isinstance(data, str):
92
+ data = json.loads(data)
93
+ problem = Problem()
94
+ problem.id = data["id"]
95
+ if 'uuid' in data:
96
+ problem.uuid = data["uuid"]
97
+
98
+ problem.task_item_id = torch.tensor(data["task_item_id"], dtype=torch.int32)
99
+ problem.task_item_owner_id = torch.tensor(data["task_item_owner_id"], dtype=torch.int32)
100
+ problem.task_item_num = torch.tensor(data["task_item_num"], dtype=torch.int32)
101
+ problem.task_area = torch.tensor(data["task_area"], dtype=torch.int32)
102
+ problem.task_roadway = torch.tensor(data["task_roadway"], dtype=torch.int32)
103
+ problem.task_out_stock_time = torch.tensor(data["task_out_stock_time"], dtype=torch.int32)
104
+ problem.task_area_group = torch.tensor(data["task_area_group"], dtype=torch.int32)
105
+
106
+ NT = problem.task_item_id.size(0)
107
+ problem.task_load = torch.ones(NT, dtype=torch.int32)
108
+ problem.task_group = torch.zeros(NT, dtype=torch.int32)
109
+ problem.task_demand = torch.ones(NT, dtype=torch.int32)
110
+
111
+ problem.worker_load_limit = torch.tensor(data["worker_load_limit"], dtype=torch.int32)
112
+ problem.worker_area_limit = torch.tensor(data["worker_area_limit"], dtype=torch.int32)
113
+ problem.worker_area_group_limit = torch.tensor(data["worker_area_group_limit"], dtype=torch.int32)
114
+
115
+ problem.features = features
116
+ problem.variables = variables
117
+ problem.constraint = Constraint
118
+ problem.objective = Objective
119
+
120
+ return problem
121
+
122
+
123
+ def make_problem(batch_count, batch_size=1, task_count=100):
124
+ assert batch_size == 1
125
+
126
+ NT = task_count
127
+ problem_list = []
128
+ for i in range(batch_count):
129
+ problem = Problem()
130
+ problem.id = i
131
+
132
+ device = Solver().device
133
+ p = torch.ones(NT, 1000, dtype=torch.float32, device=device)
134
+ problem.task_item_id = torch.multinomial(p, 10).to(torch.int32).cpu()
135
+ problem.task_item_owner_id = torch.multinomial(p, 10).to(torch.int32).cpu()
136
+ problem.task_item_num = torch.randint(0, 5, (NT, 10), dtype=torch.int32)
137
+ problem.task_area = torch.randint(0, 5, (NT, 10), dtype=torch.int32).clamp(0, 1)
138
+ problem.task_roadway = torch.randint(0, 5, (NT, 200), dtype=torch.int32).clamp(0, 1)
139
+ problem.task_area_group = torch.randint(0, 5, (NT, 10), dtype=torch.int32).clamp(0, 1)
140
+
141
+ problem.task_load = torch.ones(NT, dtype=torch.int32)
142
+ problem.task_group = torch.zeros(NT, dtype=torch.int32)
143
+ problem.task_demand = torch.ones(NT, dtype=torch.int32)
144
+
145
+ problem.worker_load_limit = torch.tensor([20], dtype=torch.int32)
146
+ problem.worker_area_limit = torch.tensor([10], dtype=torch.int32)
147
+ problem.worker_area_group_limit = torch.tensor([10], dtype=torch.int32)
148
+
149
+ problem.features = features
150
+ problem.variables = variables
151
+ problem.constraint = Constraint
152
+ problem.objective = Objective
153
+
154
+ problem_list.append(problem)
155
+
156
+ return problem_list
157
+
158
+
159
+ if __name__ == '__main__':
160
+ import sys
161
+ import os.path as osp
162
+ sys.path.append(osp.join(osp.dirname(__file__), '../'))
163
+ import runner
164
+
165
+ runner.run(make_problem)
examples/cvrp/cvrp.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from greedrl.feature import *
2
+ from greedrl.variable import *
3
+ from greedrl import Problem
4
+
5
+ features = [continuous_feature('task_demand'),
6
+ continuous_feature('worker_weight_limit'),
7
+ continuous_feature('distance_matrix'),
8
+ variable_feature('distance_this_to_task'),
9
+ variable_feature('distance_task_to_end')]
10
+
11
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
12
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
13
+ feature_variable('task_weight'),
14
+ worker_variable('worker_weight_limit'),
15
+ worker_used_resource('worker_used_weight', task_require='task_weight'),
16
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
17
+ edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
18
+ edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]
19
+
20
+
21
+ class Constraint:
22
+
23
+ def do_task(self):
24
+ return self.task_demand_this
25
+
26
+ def mask_task(self):
27
+ # 已经完成的任务
28
+ mask = self.task_demand_now <= 0
29
+ # 车辆容量限制
30
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
31
+ mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
32
+ return mask
33
+
34
+ def finished(self):
35
+ return torch.all(self.task_demand_now <= 0, 1)
36
+
37
+
38
+ class Objective:
39
+
40
+ def step_worker_end(self):
41
+ return self.distance_last_to_this
42
+
43
+ def step_task(self):
44
+ return self.distance_last_to_this
45
+
46
+
47
+ def make_problem(batch_count, batch_size=1, task_count=100):
48
+ assert task_count in (100, 1000, 2000, 5000)
49
+
50
+ weight_limit = 50
51
+ problem_list = []
52
+ for i in range(batch_count):
53
+ problem = Problem(True)
54
+ problem.id = torch.arange(batch_size) + i * batch_size;
55
+
56
+ problem.worker_weight_limit = torch.full((batch_size, 1), weight_limit, dtype=torch.int32)
57
+
58
+ N = task_count
59
+ problem.task_demand = torch.randint(1, 10, (batch_size, N), dtype=torch.int32)
60
+ problem.task_demand_x = problem.task_demand.float() / weight_limit
61
+
62
+ # 一个单位的task_demand的重量
63
+ problem.task_weight = torch.ones(batch_size, N, dtype=torch.int32)
64
+
65
+ loc = torch.rand(batch_size, N + 1, 2, dtype=torch.float32)
66
+ problem.task_location = loc[:, 1:, :]
67
+ problem.worker_location = loc[:, 0:1, :]
68
+
69
+ distance_matrix = torch.norm(loc[:, :, None, :] - loc[:, None, :, :], dim=3)
70
+ problem.distance_matrix = distance_matrix
71
+
72
+ problem.features = features
73
+ problem.variables = variables
74
+ problem.constraint = Constraint
75
+ problem.objective = Objective
76
+
77
+ problem_list.append(problem)
78
+
79
+ return problem_list
80
+
81
+
82
+ if __name__ == '__main__':
83
+ import sys
84
+ import os.path as osp
85
+ sys.path.append(osp.join(osp.dirname(__file__), '../'))
86
+ import runner
87
+
88
+ runner.run(make_problem)
examples/cvrp/orts.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ import torch
4
+ import argparse
5
+ import utils
6
+ import multiprocessing as mp
7
+ from concurrent.futures import ThreadPoolExecutor
8
+
9
+ from ortools.constraint_solver import pywrapcp
10
+ from ortools.constraint_solver import routing_enums_pb2
11
+
12
+
13
+ def solve(problem, i, max_time):
14
+ scale = 100000
15
+ size = problem.task_demand.size(1)
16
+ demand = [0] + problem.task_demand[i].tolist()
17
+ capacity = problem.worker_weight_limit[i].tolist()
18
+ distance = (problem.distance_matrix[i] * scale + 0.5).to(torch.int32).tolist()
19
+
20
+ queue = mp.Queue()
21
+ p = mp.Process(target=do_solve, args=(size, demand, capacity, distance, max_time, queue))
22
+ p.start()
23
+ p.join()
24
+
25
+ return queue.get() / scale, queue.get()
26
+
27
+
28
+ def do_solve(size, demand, capacity, distance, max_time, queue):
29
+ capacity = capacity * size
30
+
31
+ manager = pywrapcp.RoutingIndexManager(size + 1, size, 0)
32
+ routing = pywrapcp.RoutingModel(manager)
33
+
34
+ def distance_callback(from_index, to_index):
35
+ from_node = manager.IndexToNode(from_index)
36
+ to_node = manager.IndexToNode(to_index)
37
+ return distance[from_node][to_node]
38
+
39
+ distance_callback_index = routing.RegisterTransitCallback(distance_callback)
40
+ routing.SetArcCostEvaluatorOfAllVehicles(distance_callback_index)
41
+
42
+ def demand_callback(from_index):
43
+ from_node = manager.IndexToNode(from_index)
44
+ return demand[from_node]
45
+
46
+ demand_callback_index = routing.RegisterUnaryTransitCallback(demand_callback)
47
+ routing.AddDimensionWithVehicleCapacity(demand_callback_index, 0, capacity, True, 'capacity')
48
+
49
+ params = pywrapcp.DefaultRoutingSearchParameters()
50
+ params.first_solution_strategy = (routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC)
51
+ params.local_search_metaheuristic = (routing_enums_pb2.LocalSearchMetaheuristic.GUIDED_LOCAL_SEARCH)
52
+ params.time_limit.seconds = max_time
53
+
54
+ start_time = time.time()
55
+ solution = routing.SolveWithParameters(params)
56
+ spent_time = time.time() - start_time
57
+
58
+ queue.put(solution.ObjectiveValue())
59
+ queue.put(spent_time)
60
+
61
+
62
+ def run_orts(task, max_time):
63
+ problem, i = task
64
+ return solve(problem, i, max_time)
65
+
66
+
67
+ def main(args):
68
+ print("args: {}".format(vars(args)))
69
+ problem_size = args.problem_size
70
+ problem_count = args.problem_count
71
+ batch_size = args.batch_size
72
+
73
+ assert problem_count % batch_size == 0
74
+ batch_count = problem_count // batch_size
75
+ problem_list = utils.make_problem(batch_count, batch_size, problem_size)
76
+
77
+ executor = ThreadPoolExecutor(max_workers=args.threads)
78
+ task_list = [(p, i) for p in problem_list for i in range(batch_size)]
79
+
80
+ total_cost = 0
81
+ total_time = 0
82
+ for cost, elapse in executor.map(run_orts, task_list, [args.max_time] * problem_count):
83
+ total_cost += cost
84
+ total_time += elapse
85
+
86
+ avg_cost = total_cost / problem_count
87
+ avg_time = total_time / problem_count
88
+ print()
89
+ print("-----------------------------------------------------")
90
+ print("avg_cost: {:.4f}".format(avg_cost))
91
+ print("avg_time: {:.6f}s".format(avg_time))
92
+ print("total_count: {}".format(problem_count))
93
+ print("-----------------------------------------------------\n")
94
+ sys.stdout.flush()
95
+
96
+
97
+ if __name__ == '__main__':
98
+ parser = argparse.ArgumentParser()
99
+ parser.add_argument('--threads', default=20, type=int, help='number of threads')
100
+ parser.add_argument('--max_time', default=60, type=int, help='the time limit for the search in seconds')
101
+
102
+ parser.add_argument('--problem_size', default=100, type=int, choices=[100, 1000, 2000, 5000], help='problem size')
103
+ parser.add_argument('--problem_count', default=128, type=int, help='total number of generated problem instances')
104
+ parser.add_argument('--batch_size', default=128, type=int, help='batch size for feedforwarding')
105
+
106
+ args = parser.parse_args()
107
+ main(args)
examples/cvrp/solve.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import torch
5
+ import argparse
6
+ import utils
7
+ from greedrl import Solver
8
+
9
+ torch.set_num_threads(1)
10
+ torch.set_num_interop_threads(1)
11
+
12
+
13
+ def do_solve(args):
14
+ print("args: {}".format(vars(args)))
15
+
16
+ problem_size = args.problem_size
17
+ problem_count = args.problem_count
18
+ batch_size = args.batch_size
19
+ assert problem_count % batch_size == 0
20
+ batch_count = problem_count // batch_size
21
+
22
+ problem_list = utils.make_problem(batch_count, batch_size, problem_size)
23
+
24
+ solver = Solver(device=args.device)
25
+
26
+ model_path = os.path.join('./', args.model_name)
27
+ solver.load_agent(model_path)
28
+
29
+ total_cost = 0
30
+
31
+ if solver.device.type == 'cuda':
32
+ torch.cuda.synchronize()
33
+
34
+ start_time = time.time()
35
+ for problem in problem_list:
36
+ solution = solver.solve(problem, greedy=False, batch_size=batch_size)
37
+ total_cost += solution.cost.sum().item()
38
+
39
+ if solver.device.type == 'cuda':
40
+ torch.cuda.synchronize()
41
+
42
+ total_time = time.time() - start_time
43
+
44
+ avg_cost = total_cost / problem_count
45
+ avg_time = total_time / problem_count
46
+ print()
47
+ print("-----------------------------------------------------")
48
+ print("avg_cost: {:.4f}".format(avg_cost))
49
+ print("avg_time: {:.6f}s".format(avg_time))
50
+ print("total_count: {}".format(problem_count))
51
+ print("-----------------------------------------------------\n")
52
+ sys.stdout.flush()
53
+
54
+
55
+ if __name__ == '__main__':
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument('--device', default='cpu', choices=['cpu', 'cuda'], help="choose a device")
58
+ parser.add_argument('--model_name', default='cvrp_100.pt', choices=['cvrp_100.pt', 'cvrp_1000.pt', 'cvrp_2000.pt', 'cvrp_5000.pt'], help="choose a model")
59
+ parser.add_argument('--problem_size', default=100, type=int, choices=[100, 1000, 2000, 5000], help='problem size')
60
+ parser.add_argument('--problem_count', default=128, type=int, help='total number of generated problem instances')
61
+ parser.add_argument('--batch_size', default=128, type=int, help='batch size for feedforwarding')
62
+
63
+ args = parser.parse_args()
64
+ do_solve(args)
65
+
examples/cvrp/train.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import math
3
+ import argparse
4
+ import torch.distributed as dist
5
+ import torch.multiprocessing as mp
6
+ import utils
7
+ from greedrl import Solver
8
+
9
+
10
+ def do_train(args, rank):
11
+ world_size = args.world_size
12
+ model_filename = args.model_filename
13
+ problem_size = args.problem_size
14
+ batch_size = args.batch_size
15
+
16
+ index = model_filename.rfind('.')
17
+ if world_size > 1:
18
+ stdout_filename = '{}_r{}.log'.format(model_filename[0:index], rank)
19
+ else:
20
+ stdout_filename = '{}.log'.format(model_filename[0:index])
21
+
22
+ stdout = open(stdout_filename, 'a')
23
+ sys.stdout = stdout
24
+ sys.stderr = stdout
25
+
26
+ print("args: {}".format(vars(args)))
27
+ if world_size > 1:
28
+ dist.init_process_group('NCCL', init_method='tcp://127.0.0.1:29500',
29
+ rank=rank, world_size=world_size)
30
+
31
+ problem_batch_size = 8
32
+ batch_count = 0
33
+ if problem_size == 100:
34
+ batch_count = math.ceil(10000 / problem_batch_size)
35
+ elif problem_size == 1000:
36
+ batch_count = math.ceil(200 / problem_batch_size)
37
+ elif problem_size == 2000:
38
+ batch_count = math.ceil(100 / problem_batch_size)
39
+ elif problem_size == 5000:
40
+ batch_count = math.ceil(10 / problem_batch_size)
41
+ else:
42
+ raise Exception("unsupported problem size: {}".format(problem_size))
43
+
44
+ nn_args = {
45
+ 'encode_norm': 'instance',
46
+ 'encode_layers': 6,
47
+ 'decode_rnn': 'LSTM'
48
+ }
49
+
50
+ device = None if world_size == 1 else 'cuda:{}'.format(rank)
51
+ solver = Solver(device, nn_args)
52
+
53
+ train_dataset = utils.Dataset(None, problem_batch_size, problem_size)
54
+ valid_dataset = utils.Dataset(batch_count, problem_batch_size, problem_size)
55
+
56
+ solver.train(model_filename, train_dataset, valid_dataset,
57
+ train_dataset_workers=5,
58
+ batch_size=batch_size,
59
+ memopt=10,
60
+ topk_size=1,
61
+ init_lr=1e-4,
62
+ valid_steps=500,
63
+ warmup_steps=0)
64
+
65
+
66
+ if __name__ == '__main__':
67
+
68
+ parser = argparse.ArgumentParser()
69
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
70
+ parser.add_argument('--model_filename', type=str, help='model file name')
71
+ parser.add_argument('--problem_size', default=100, type=int, choices=[100, 1000, 2000, 5000], help='problem size')
72
+ parser.add_argument('--batch_size', default=128, type=int, help='batch size for training')
73
+
74
+ args = parser.parse_args()
75
+
76
+ processes = []
77
+ for rank in range(args.world_size):
78
+ p = mp.Process(target=do_train, args=(args, rank))
79
+ p.start()
80
+ processes.append(p)
81
+
82
+ for p in processes:
83
+ p.join()
examples/cvrp/utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from greedrl.feature import *
2
+ from cvrp import make_problem as make_cvrp_problem
3
+ from torch.utils.data import Dataset, IterableDataset, DataLoader
4
+
5
+
6
+ def make_problem(batch_count, batch_size, task_count):
7
+ features = [continuous_feature('task_demand_x'),
8
+ continuous_feature('distance_matrix')]
9
+
10
+ problem_list = make_cvrp_problem(batch_count, batch_size, task_count)
11
+ for problem in problem_list:
12
+ problem.features = features
13
+
14
+ return problem_list
15
+
16
+
17
+ class Dataset(IterableDataset):
18
+ def __init__(self, batch_count, batch_size, task_count):
19
+ self._batch_size = batch_size
20
+ self._task_count = task_count
21
+ self._batch_count = batch_count
22
+ self._index = 0
23
+
24
+ def __iter__(self):
25
+ self._index = 0
26
+ return self
27
+
28
+ def __next__(self):
29
+ if self._batch_count is not None \
30
+ and self._index >= self._batch_count:
31
+ raise StopIteration()
32
+
33
+ p = make_problem(1, self._batch_size, self._task_count)[0]
34
+ self._index += 1
35
+ return p
36
+
37
+
38
+ def write_vrplib(filename, name, size, demand, capacity, location):
39
+ with open(filename, 'w') as f:
40
+ f.write('\n'.join([
41
+ "{} : {}".format(k, v)
42
+ for k, v in (
43
+ ('NAME', name),
44
+ ('TYPE', 'CVRP'),
45
+ ('COMMENT', 'NONE'),
46
+ ('DIMENSION', size + 1),
47
+ ('EDGE_WEIGHT_TYPE', 'EUC_2D'),
48
+ ('CAPACITY', capacity)
49
+ )
50
+ ]))
51
+
52
+ f.write('\n')
53
+ f.write('NODE_COORD_SECTION\n')
54
+
55
+ f.write('\n'.join(['{}\t{}\t{}'.format(i + 1, x, y) for i, (x, y) in enumerate(location)]))
56
+
57
+ f.write('\n')
58
+ f.write('DEMAND_SECTION\n')
59
+ f.write('\n'.join(['{}\t{}'.format(i + 1, d) for i, d in enumerate([0] + demand)]))
60
+
61
+ f.write('\n')
62
+ f.write('DEPOT_SECTION\n')
63
+ f.write('1\n')
64
+ f.write('-1\n')
65
+ f.write('EOF\n')
examples/dpdp/dpdp.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from greedrl.feature import *
4
+ from greedrl.variable import *
5
+ from greedrl.function import *
6
+ from greedrl import Problem
7
+
8
+ features = [local_category('task_order'),
9
+ global_category('task_type', 2),
10
+ global_category('task_new_order', 2),
11
+ variable_feature('time_this_to_task'),
12
+ continuous_feature('x_time_matrix'),
13
+ continuous_feature('task_due_time_x'),
14
+ continuous_feature('worker_task_mask')]
15
+
16
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
17
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
18
+ task_variable('task_pickup_this', feature='task_pickup'),
19
+ task_variable('task_due_time_this', feature='task_due_time'),
20
+ feature_variable('task_order', feature='task_order'),
21
+ feature_variable('task_type', feature='task_type'),
22
+ feature_variable('task_new_pickup', feature='task_new_pickup'),
23
+ feature_variable('worker_task_mask', feature='worker_task_mask'),
24
+ worker_count_now('worker_count_now', feature='worker_count'),
25
+ worker_variable('worker_min_old_task_this', feature='worker_min_old_task'),
26
+ worker_variable('worker_max_new_order_this', feature='worker_max_new_order'),
27
+ worker_variable('worker_task_mask_this', feature='worker_task_mask'),
28
+ worker_used_resource('worker_used_old_task', task_require='task_old'),
29
+ worker_used_resource('worker_used_new_order', task_require='task_new_pickup'),
30
+ worker_used_resource('worker_used_time', edge_require='time_matrix'),
31
+ edge_variable('time_this_to_task', feature='x_time_matrix', this_to_task=True)]
32
+
33
+
34
+ class Constraint:
35
+
36
+ def do_task(self):
37
+ return self.task_demand_this
38
+
39
+ def mask_worker_start(self):
40
+ mask = self.worker_count_now <= 0
41
+
42
+ finished = self.task_demand_now <= 0
43
+ worker_task_mask = self.worker_task_mask | finished[:, None, :]
44
+ mask |= torch.all(worker_task_mask, 2)
45
+
46
+ return mask
47
+
48
+ def mask_worker_end(self):
49
+ mask = self.worker_used_old_task < self.worker_min_old_task_this
50
+ mask |= task_group_split(self.task_order, self.task_demand_now <= 0)
51
+ return mask
52
+
53
+ def mask_task(self):
54
+ mask = self.task_demand_now <= 0
55
+
56
+ mask |= task_group_priority(self.task_order, self.task_type, mask)
57
+
58
+ worker_max_new_order = self.worker_max_new_order_this - self.worker_used_new_order
59
+ mask |= self.task_new_pickup > worker_max_new_order[:, None]
60
+
61
+ mask |= self.worker_task_mask_this
62
+
63
+ return mask
64
+
65
+ def finished(self):
66
+ worker_mask = self.worker_count_now <= 0
67
+ task_mask = self.task_demand_now <= 0
68
+ worker_task_mask = worker_mask[:, :, None] | task_mask[:, None, :]
69
+
70
+ worker_task_mask |= self.worker_task_mask
71
+ batch_size = worker_task_mask.size(0)
72
+ worker_task_mask = worker_task_mask.view(batch_size, -1)
73
+ return worker_task_mask.all(1)
74
+
75
+
76
+ class Objective:
77
+
78
+ def step_task(self):
79
+ over_time = (self.worker_used_time - self.task_due_time_this).clamp(min=0)
80
+ pickup_time = self.worker_used_time * self.task_pickup_this
81
+ return self.worker_used_time + over_time + pickup_time
82
+
83
+ def step_finish(self):
84
+ return self.task_demand_now.sum(1) * 1000
85
+
86
+
87
+ def preprocess(problem):
88
+ NW, NT = problem.worker_task_mask.size()
89
+
90
+ worker_task_old = torch.ones(NW, NT, dtype=torch.int32)
91
+ new_task_mask = problem.task_new_order[None, :].expand(NW, NT)
92
+ worker_task_old[new_task_mask] = 0
93
+ worker_task_old[problem.worker_task_mask] = 0
94
+ assert torch.all(worker_task_old.sum(0) <= 1)
95
+ problem.worker_min_old_task = worker_task_old.sum(1)
96
+
97
+ problem.worker_count = torch.ones(NW, dtype=torch.int32)
98
+ problem.task_demand = torch.ones(NT, dtype=torch.int32)
99
+ problem.task_pickup = (problem.task_type == 0).to(torch.int32)
100
+
101
+ task_old = torch.ones(NT, dtype=torch.int32)
102
+ task_old[problem.task_new_order] = 0
103
+ problem.task_old = task_old
104
+
105
+ task_new_pickup = torch.ones(NT, dtype=torch.int32)
106
+ task_new_pickup[problem.task_type >= 1] = 0
107
+ task_new_pickup[~problem.task_new_order] = 0
108
+ problem.task_new_pickup = task_new_pickup
109
+
110
+ problem.task_due_time_x = problem.task_due_time.float() / 900
111
+ problem.x_time_matrix = problem.time_matrix.float() / 900
112
+
113
+ problem.features = features
114
+ problem.variables = variables
115
+ problem.constraint = Constraint
116
+ problem.objective = Objective
117
+
118
+ return problem
119
+
120
+
121
+ def make_problem_from_json(data):
122
+ data = json.loads(data)
123
+
124
+ problem = Problem()
125
+
126
+ problem.id = data['id']
127
+ problem.task_order = torch.tensor(data['task_order'], dtype=torch.int32)
128
+ problem.task_type = torch.tensor(data['task_type'], dtype=torch.int32)
129
+ problem.task_new_order = torch.tensor(data['task_new_order'], dtype=torch.bool)
130
+ problem.task_due_time = torch.tensor(data['task_due_time'], dtype=torch.int32)
131
+
132
+ problem.worker_max_new_order = torch.tensor(data['worker_max_new_order'], dtype=torch.int32)
133
+ problem.worker_task_mask = torch.tensor(data['worker_task_mask'], dtype=torch.bool)
134
+ problem.time_matrix = torch.tensor(data['time_matrix'], dtype=torch.int32)
135
+
136
+ NW, NT = problem.worker_task_mask.size()
137
+
138
+ assert problem.task_order.size() == (NT,), "task_order size error"
139
+ assert problem.task_type.size() == (NT,), "task_type size error"
140
+ assert problem.task_new_order.size() == (NT,), "task_new_order size error"
141
+ assert problem.task_due_time.size() == (NT,), "task_due_time size error"
142
+ assert problem.worker_max_new_order.size() == (NW,), "worker_max_new_order size error"
143
+ assert problem.time_matrix.size() == (NW + NT, NW + NT), "time_matrix size error"
144
+
145
+ return preprocess(problem)
146
+
147
+
148
+ def make_problem(batch_count, batch_size=1, task_count=100):
149
+ assert batch_size == 1
150
+ assert task_count == 100
151
+
152
+ NW = 100
153
+ NT = task_count
154
+ NO = NT // 2 # 订单数, 一个订单有pickup, delivery两个任务
155
+ problem_list = []
156
+ for i in range(batch_count):
157
+ problem = Problem()
158
+
159
+ # user-provided data
160
+ problem.worker_max_new_order = torch.full((NW,), 2, dtype=torch.int32)
161
+
162
+ task_order = torch.arange(NO, dtype=torch.int32)
163
+ problem.task_order = torch.cat([task_order, task_order], 0)
164
+
165
+ task_type = torch.zeros(NO, dtype=torch.int32)
166
+ problem.task_type = torch.cat([task_type, task_type + 1], 0)
167
+
168
+ problem.task_new_order = torch.ones(NT, dtype=torch.bool)
169
+
170
+ task_due_time = torch.randint(1000, 1800, (NO,), dtype=torch.int32)
171
+ problem.task_due_time = torch.cat([task_due_time, task_due_time + 1800], 0)
172
+
173
+ worker_task_mask = torch.rand(NW, NO) < 0.9
174
+ problem.worker_task_mask = torch.cat([worker_task_mask, worker_task_mask], 1)
175
+
176
+ loc = torch.rand(NW + NT, 2, dtype=torch.float32)
177
+ time_matrix = torch.norm(loc[:, None, :] - loc[None, :, :], dim=2) * 1000
178
+ problem.time_matrix = time_matrix.to(torch.int32)
179
+
180
+ problem_list.append(preprocess(problem))
181
+
182
+ return problem_list
183
+
184
+
185
+ if __name__ == '__main__':
186
+ import sys
187
+ import os.path as osp
188
+ sys.path.append(osp.join(osp.dirname(__file__), '../'))
189
+ import runner
190
+
191
+ runner.run(make_problem)
examples/pdptw/pdptw.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from greedrl.feature import *
2
+ from greedrl.variable import *
3
+ from greedrl.function import *
4
+ from greedrl import Problem
5
+
6
+ features = [local_category('task_group'),
7
+ global_category('task_priority', 2),
8
+ variable_feature('distance_this_to_task'),
9
+ variable_feature('distance_task_to_end')]
10
+
11
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
12
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
13
+ feature_variable('task_weight'),
14
+ feature_variable('task_group'),
15
+ feature_variable('task_priority'),
16
+ feature_variable('task_due_time2', feature='task_due_time'),
17
+ task_variable('task_due_time'),
18
+ task_variable('task_service_time'),
19
+ task_variable('task_due_time_penalty'),
20
+ worker_variable('worker_basic_cost'),
21
+ worker_variable('worker_distance_cost'),
22
+ worker_variable('worker_due_time'),
23
+ worker_variable('worker_weight_limit'),
24
+ worker_used_resource('worker_used_weight', task_require='task_weight'),
25
+ worker_used_resource('worker_used_time', 'distance_matrix', 'task_service_time', 'task_ready_time',
26
+ 'worker_ready_time'),
27
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
28
+ edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
29
+ edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]
30
+
31
+
32
+ class Constraint:
33
+
34
+ def do_task(self):
35
+ return self.task_demand_this
36
+
37
+ def mask_worker_end(self):
38
+ return task_group_split(self.task_group, self.task_demand_now <= 0)
39
+
40
+ def mask_task(self):
41
+ mask = self.task_demand_now <= 0
42
+ mask |= task_group_priority(self.task_group, self.task_priority, mask)
43
+
44
+ worker_used_time = self.worker_used_time[:, None] + self.distance_this_to_task
45
+ mask |= (worker_used_time > self.task_due_time2) & (self.task_priority == 0)
46
+
47
+ # 容量约束
48
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
49
+ mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
50
+ return mask
51
+
52
+ def finished(self):
53
+ return torch.all(self.task_demand_now <= 0, 1)
54
+
55
+
56
+ class Objective:
57
+
58
+ def step_worker_start(self):
59
+ return self.worker_basic_cost
60
+
61
+ def step_worker_end(self):
62
+ feasible = self.worker_used_time <= self.worker_due_time
63
+ return self.distance_last_to_this * self.worker_distance_cost, feasible
64
+
65
+ def step_task(self):
66
+ worker_used_time = self.worker_used_time - self.task_service_time
67
+ feasible = worker_used_time <= self.task_due_time
68
+ feasible &= worker_used_time <= self.worker_due_time
69
+ cost = self.distance_last_to_this * self.worker_distance_cost
70
+ return torch.where(feasible, cost, cost + self.task_due_time_penalty), feasible
71
+
72
+
73
+ def make_problem(batch_count, batch_size=1, task_count=100):
74
+ assert batch_size == 1
75
+
76
+ N = task_count // 2 # 订单数, 一个订单有pickup, delivery两个任务
77
+ problem_list = []
78
+ for i in range(batch_count):
79
+ problem = Problem()
80
+ problem.id = i
81
+
82
+ problem.worker_weight_limit = torch.tensor([50], dtype=torch.float32)
83
+ problem.worker_ready_time = torch.tensor([0], dtype=torch.float32)
84
+ problem.worker_due_time = torch.tensor([1000000], dtype=torch.float32)
85
+ problem.worker_basic_cost = torch.tensor([100], dtype=torch.float32)
86
+ problem.worker_distance_cost = torch.tensor([1], dtype=torch.float32)
87
+
88
+ task_demand = torch.randint(1, 10, (N,), dtype=torch.int32)
89
+ problem.task_demand = torch.cat([task_demand, task_demand], 0)
90
+
91
+ task_weight = torch.ones(N, dtype=torch.float32)
92
+ problem.task_weight = torch.cat([task_weight, task_weight * -1], 0)
93
+
94
+ task_group = torch.arange(N, dtype=torch.int32)
95
+ problem.task_group = torch.cat([task_group, task_group], 0)
96
+
97
+ task_priority = torch.zeros(N, dtype=torch.int32)
98
+ problem.task_priority = torch.cat([task_priority, task_priority + 1], 0)
99
+
100
+ task_ready_time = torch.zeros(N, dtype=torch.float32)
101
+ problem.task_ready_time = torch.cat([task_ready_time, task_ready_time], 0)
102
+
103
+ task_due_time = torch.randint(10000, 100000, (N,), dtype=torch.float32)
104
+ problem.task_due_time = torch.cat([task_due_time, task_due_time * 2], 0)
105
+
106
+ task_service_time = torch.zeros(N, dtype=torch.float32)
107
+ problem.task_service_time = torch.cat([task_service_time, task_service_time])
108
+
109
+ task_due_time_penalty = torch.ones(N, dtype=torch.float32)
110
+ problem.task_due_time_penalty = torch.cat([task_due_time_penalty, task_due_time_penalty])
111
+
112
+ loc = torch.rand(N + 1, 2, dtype=torch.float32)
113
+ distance_matrix = torch.norm(loc[:, None, :] - loc[None, :, :], dim=2) * 1000
114
+ distance_matrix = distance_matrix.to(torch.float32)
115
+ index = torch.cat([torch.zeros(N + 1, dtype=torch.int64), torch.arange(N, dtype=torch.int64) + 1])
116
+ index1 = index[:, None]
117
+ index2 = index[None, :]
118
+ problem.distance_matrix = distance_matrix[index1, index2]
119
+
120
+ problem.features = features
121
+ problem.variables = variables
122
+ problem.constraint = Constraint
123
+ problem.objective = Objective
124
+
125
+ problem_list.append(problem)
126
+
127
+ return problem_list
128
+
129
+
130
+ if __name__ == '__main__':
131
+ import sys
132
+ import os.path as osp
133
+ sys.path.append(osp.join(osp.dirname(__file__), '../'))
134
+ import runner
135
+
136
+ runner.run(make_problem)
examples/runner.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import random
3
+ import argparse
4
+ import torch
5
+
6
+ from greedrl import Problem, Solution, Solver
7
+
8
+
9
+ def run(make_problem, mask_task_ratio=0.1):
10
+ random.seed(123)
11
+ torch.manual_seed(123)
12
+ problem_list = make_problem(1)
13
+
14
+ parser = argparse.ArgumentParser(description="")
15
+ parser.add_argument('--device', default=None, type=str)
16
+ parser.add_argument('--batch_size', default=32, type=int)
17
+ parser.add_argument('--agent_file', default=None, type=str)
18
+ parser.add_argument('--valid_steps', default=5, type=int)
19
+ parser.add_argument('--max_steps', default=10000000, type=int)
20
+
21
+ args, _ = parser.parse_known_args()
22
+ for k, v in args.__dict__.items():
23
+ print("arg: {} = {}".format(k, v))
24
+
25
+ # rl train
26
+ solver = Solver(device=args.device)
27
+ solver.train(args.agent_file, problem_list, problem_list,
28
+ batch_size=args.batch_size, valid_steps=args.valid_steps, max_steps=args.max_steps)
29
+ # predict
30
+ solver = Solver(device=args.device)
31
+ if args.agent_file is not None:
32
+ solver.load_agent(args.agent_file)
33
+
34
+ print("solve ...")
35
+ start = time.time()
36
+ for problem in problem_list:
37
+ solver.solve(problem, batch_size=args.batch_size)
38
+ print("time: {}s".format(time.time() - start))
examples/sdvrp/sdvrp.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from greedrl.feature import *
2
+ from greedrl.variable import *
3
+ from greedrl import Problem
4
+
5
+ features = [continuous_feature('task_demand'),
6
+ continuous_feature('worker_weight_limit'),
7
+ continuous_feature('distance_matrix'),
8
+ variable_feature('distance_this_to_task'),
9
+ variable_feature('distance_task_to_end')]
10
+
11
+ variables = [task_demand_now('task_demand'),
12
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
13
+ feature_variable('task_weight'),
14
+ task_variable('task_weight_this', feature='task_weight'),
15
+ worker_variable('worker_weight_limit'),
16
+ worker_used_resource('worker_used_weight', task_require='task_weight'),
17
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True)]
18
+
19
+
20
+ class Constraint:
21
+
22
+ def do_task(self):
23
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
24
+ return torch.min(self.task_demand_this, worker_weight_limit // self.task_weight_this)
25
+
26
+ def mask_task(self):
27
+ # 已经完成的任务
28
+ mask = self.task_demand <= 0
29
+ # 车辆容量限制
30
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
31
+ # 至少要能装下一个单位的demand
32
+ mask |= self.task_weight > worker_weight_limit[:, None]
33
+ return mask
34
+
35
+ def finished(self):
36
+ return torch.all(self.task_demand <= 0, 1)
37
+
38
+
39
+ class Objective:
40
+
41
+ def step_worker_end(self):
42
+ return self.distance_last_to_this
43
+
44
+ def step_task(self):
45
+ return self.distance_last_to_this
46
+
47
+
48
+ def make_problem(batch_count, batch_size=1, task_count=100):
49
+ assert batch_size == 1
50
+
51
+ NT = task_count
52
+ problem_list = []
53
+ for i in range(batch_count):
54
+ problem = Problem()
55
+ problem.id = i
56
+
57
+ problem.worker_weight_limit = [50]
58
+
59
+ problem.task_demand = torch.randint(1, 10, (NT,), dtype=torch.int64)
60
+
61
+ # 一个单位的task_demand的重量
62
+ problem.task_weight = torch.ones(NT, dtype=torch.int64)
63
+
64
+ loc = torch.rand(NT + 1, 2, dtype=torch.float32)
65
+ distance_matrix = torch.norm(loc[:, None, :] - loc[None, :, :], dim=2) * 1000
66
+ problem.distance_matrix = distance_matrix.to(torch.int64)
67
+
68
+ problem.variables = variables
69
+ problem.constraint = Constraint
70
+ problem.objective = Objective
71
+
72
+ problem_list.append(problem)
73
+
74
+ return problem_list
75
+
76
+
77
+ if __name__ == '__main__':
78
+ import sys
79
+ import os.path as osp
80
+ sys.path.append(osp.join(osp.dirname(__file__), '../'))
81
+ import runner
82
+
83
+ runner.run(make_problem)
examples/tsp/tsp.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from greedrl.feature import *
2
+ from greedrl.variable import *
3
+ from greedrl import Problem
4
+
5
+ features = [continuous_feature('task_location'),
6
+ variable_feature('distance_this_to_task'),
7
+ variable_feature('distance_task_to_end')]
8
+
9
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
10
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
11
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
12
+ edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
13
+ edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True),
14
+ edge_variable('distance_last_to_loop', feature='distance_matrix', last_to_loop=True)]
15
+
16
+
17
+ class Constraint:
18
+
19
+ def do_task(self):
20
+ return self.task_demand_this
21
+
22
+ def mask_task(self):
23
+ # 已经完成的任务
24
+ mask = self.task_demand_now <= 0
25
+ return mask
26
+
27
+ def mask_worker_end(self):
28
+ return torch.any(self.task_demand_now > 0, 1)
29
+
30
+ def finished(self):
31
+ return torch.all(self.task_demand_now <= 0, 1)
32
+
33
+
34
+ class Objective:
35
+
36
+ def step_worker_end(self):
37
+ return self.distance_last_to_loop
38
+
39
+ def step_task(self):
40
+ return self.distance_last_to_this
41
+
42
+
43
+ def make_problem(batch_count, batch_size=1, task_count=100):
44
+ NP = batch_size
45
+ NT = task_count
46
+ problem_list = []
47
+ for i in range(batch_count):
48
+ problem = Problem(True)
49
+
50
+ problem.task_demand = torch.ones(NP, NT, dtype=torch.int32)
51
+
52
+ loc = torch.rand(NP, NT + 1, 2, dtype=torch.float32)
53
+ problem.distance_matrix = torch.norm(loc[:, :, None, :] - loc[:, None, :, :], dim=3)
54
+ problem.distance_matrix[0, :] = 0
55
+ problem.distance_matrix[:, 0] = 0
56
+
57
+ problem.task_location = loc[:, 1:]
58
+
59
+ problem.features = features
60
+ problem.variables = variables
61
+ problem.constraint = Constraint
62
+ problem.objective = Objective
63
+
64
+ problem_list.append(problem)
65
+ return problem_list
66
+
67
+
68
+ if __name__ == '__main__':
69
+ import sys
70
+ import os.path as osp
71
+ sys.path.append(osp.join(osp.dirname(__file__), '../'))
72
+ import runner
73
+
74
+ runner.run(make_problem)
examples/vrptw/vrptw.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from greedrl import Problem
4
+ from greedrl.feature import *
5
+ from greedrl.variable import *
6
+
7
+ features = [continuous_feature('worker_weight_limit'),
8
+ continuous_feature('worker_ready_time'),
9
+ continuous_feature('worker_due_time'),
10
+ continuous_feature('worker_basic_cost'),
11
+ continuous_feature('worker_distance_cost'),
12
+ continuous_feature('task_demand'),
13
+ continuous_feature('task_weight'),
14
+ continuous_feature('task_ready_time'),
15
+ continuous_feature('task_due_time'),
16
+ continuous_feature('task_service_time'),
17
+ continuous_feature('distance_matrix')]
18
+
19
+ variables = [task_demand_now('task_demand_now', feature='task_demand'),
20
+ task_demand_now('task_demand_this', feature='task_demand', only_this=True),
21
+ feature_variable('task_weight'),
22
+ feature_variable('task_due_time'),
23
+ feature_variable('task_ready_time'),
24
+ feature_variable('task_service_time'),
25
+ worker_variable('worker_weight_limit'),
26
+ worker_variable('worker_due_time'),
27
+ worker_variable('worker_basic_cost'),
28
+ worker_variable('worker_distance_cost'),
29
+ worker_used_resource('worker_used_weight', task_require='task_weight'),
30
+ worker_used_resource('worker_used_time', 'distance_matrix', 'task_service_time', 'task_ready_time',
31
+ 'worker_ready_time'),
32
+ edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
33
+ edge_variable('distance_this_to_task', feature='distance_matrix', this_to_task=True),
34
+ edge_variable('distance_task_to_end', feature='distance_matrix', task_to_end=True)]
35
+
36
+
37
+ class Constraint:
38
+
39
+ def do_task(self):
40
+ return self.task_demand_this
41
+
42
+ def mask_task(self):
43
+ # 已经完成的任务
44
+ mask = self.task_demand_now <= 0
45
+ # 车辆容量限制
46
+ worker_weight_limit = self.worker_weight_limit - self.worker_used_weight
47
+ mask |= self.task_demand_now * self.task_weight > worker_weight_limit[:, None]
48
+
49
+ worker_used_time = self.worker_used_time[:, None] + self.distance_this_to_task
50
+ mask |= worker_used_time > self.task_due_time
51
+
52
+ worker_used_time = torch.max(worker_used_time, self.task_ready_time)
53
+ worker_used_time += self.task_service_time
54
+ worker_used_time += self.distance_task_to_end
55
+ mask |= worker_used_time > self.worker_due_time[:, None]
56
+
57
+ return mask
58
+
59
+ def finished(self):
60
+ return torch.all(self.task_demand_now <= 0, 1)
61
+
62
+
63
+ class Objective:
64
+
65
+ def step_worker_start(self):
66
+ return self.worker_basic_cost
67
+
68
+ def step_worker_end(self):
69
+ return self.distance_last_to_this * self.worker_distance_cost
70
+
71
+ def step_task(self):
72
+ return self.distance_last_to_this * self.worker_distance_cost
73
+
74
+
75
+ def make_problem_from_json(data):
76
+ if isinstance(data, str):
77
+ data = json.loads(data)
78
+
79
+ problem = Problem()
80
+ problem.worker_weight_limit = torch.tensor(data['worker_weight_limit'], dtype=torch.float32)
81
+ problem.worker_ready_time = torch.tensor(data['worker_ready_time'], dtype=torch.float32)
82
+ problem.worker_due_time = torch.tensor(data['worker_due_time'], dtype=torch.float32)
83
+ problem.worker_basic_cost = torch.tensor(data['worker_basic_cost'], dtype=torch.float32)
84
+ problem.worker_distance_cost = torch.tensor(data['worker_distance_cost'], dtype=torch.float32)
85
+
86
+ problem.task_demand = torch.tensor(data['task_demand'], dtype=torch.int32)
87
+ problem.task_weight = torch.tensor(data['task_weight'], dtype=torch.float32)
88
+ problem.task_ready_time = torch.tensor(data['task_ready_time'], dtype=torch.float32)
89
+ problem.task_due_time = torch.tensor(data['task_due_time'], dtype=torch.float32)
90
+ problem.task_service_time = torch.tensor(data['task_service_time'], dtype=torch.float32)
91
+
92
+ problem.distance_matrix = torch.tensor(data['distance_matrix'], dtype=torch.float32);
93
+
94
+ problem.features = features
95
+ problem.variables = variables
96
+ problem.constraint = Constraint
97
+ problem.objective = Objective
98
+
99
+ return problem
100
+
101
+
102
+ def make_problem(batch_count, batch_size=1, task_count=100):
103
+ assert batch_size == 1
104
+
105
+ NT = task_count
106
+ problem_list = []
107
+ for i in range(batch_count):
108
+ problem = Problem()
109
+ problem.id = i
110
+
111
+ problem.worker_weight_limit = torch.tensor([50], dtype=torch.float32)
112
+ problem.worker_ready_time = torch.tensor([0], dtype=torch.float32)
113
+ problem.worker_due_time = torch.tensor([1000000], dtype=torch.float32)
114
+ problem.worker_basic_cost = torch.tensor([100], dtype=torch.float32)
115
+ problem.worker_distance_cost = torch.tensor([1], dtype=torch.float32)
116
+
117
+ problem.task_demand = torch.randint(1, 10, (NT,), dtype=torch.int32)
118
+ problem.task_weight = torch.ones(NT, dtype=torch.float32)
119
+ problem.task_ready_time = torch.zeros(NT, dtype=torch.float32)
120
+ problem.task_due_time = torch.randint(10000, 100000, (NT,), dtype=torch.float32)
121
+ problem.task_service_time = torch.zeros(NT, dtype=torch.float32)
122
+
123
+ loc = torch.rand(NT + 1, 2, dtype=torch.float32)
124
+ problem.distance_matrix = torch.norm(loc[:, None, :] - loc[None, :, :], dim=2) * 1000
125
+ problem_list.append(problem)
126
+
127
+ problem.features = features
128
+ problem.variables = variables
129
+ problem.constraint = Constraint
130
+ problem.objective = Objective
131
+
132
+ return problem_list
133
+
134
+
135
+ if __name__ == '__main__':
136
+ import sys
137
+ import os.path as osp
138
+ sys.path.append(osp.join(osp.dirname(__file__), '../'))
139
+ import runner
140
+
141
+ runner.run(make_problem)
greedrl/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.c
2
+ version.py
greedrl/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from .solver import Problem, Solution, Solver
4
+ from .const import GRL_WORKER_START, GRL_WORKER_END, GRL_TASK, GRL_FINISH
5
+
6
+
7
+ greedrl = sys.modules[__name__]
8
+
greedrl/agent.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch import nn
4
+ from collections import OrderedDict
5
+ from torch.utils.checkpoint import checkpoint
6
+ from .feature import *
7
+ from .pyenv import PyEnv
8
+ from .encode import Encode
9
+ from .decode import Decode
10
+
11
+
12
+ class Agent(nn.Module):
13
+
14
+ def __init__(self, nn_args):
15
+ super(Agent, self).__init__()
16
+
17
+ self.nn_args = nn_args
18
+ self.vars_dim = sum(nn_args['variable_dim'].values())
19
+ self.steps_ratio = nn_args.setdefault('decode_steps_ratio', 1.0);
20
+
21
+ logit_clips = nn_args.setdefault('decode_logit_clips', 10.0);
22
+ if isinstance(logit_clips, str):
23
+ self.logit_clips = [float(v) for v in logit_clips.split(',')]
24
+ else:
25
+ self.logit_clips = [float(logit_clips)]
26
+
27
+ self.nn_encode = Encode(nn_args)
28
+ self.nn_decode = Decode(nn_args)
29
+
30
+ def nn_args_dict(self):
31
+ return self.nn_args
32
+
33
+ def forward(self, problem, batch_size, greedy=False, solution=None, memopt=0):
34
+ X, K, V = self.nn_encode(problem.feats, problem.batch_size,
35
+ problem.worker_num, problem.task_num, memopt)
36
+
37
+ return self.interact(problem, X, K, V, batch_size, greedy, solution, memopt)
38
+
39
+ def interact(self, problem, X, K, V, batch_size, greedy, solution, memopt):
40
+ NP = problem.batch_size
41
+ NW = problem.worker_num
42
+ NT = problem.task_num
43
+
44
+ sample_num = batch_size // NP
45
+ assert sample_num > 0 and batch_size % NP == 0
46
+
47
+ MyEnv = problem.environment
48
+ if MyEnv is None:
49
+ env = PyEnv(problem, batch_size, sample_num, self.nn_args)
50
+ else:
51
+ env = MyEnv(str(problem.device), problem.feats, batch_size,
52
+ sample_num, problem.worker_num, problem.task_num)
53
+
54
+ query = X.new_zeros(batch_size, X.size(-1))
55
+ state1 = X.new_zeros(batch_size, X.size(-1))
56
+ state2 = X.new_zeros(batch_size, X.size(-1))
57
+
58
+ p_list = []
59
+ NULL = X.new_ones(0)
60
+ p_index = torch.div(torch.arange(batch_size, device=X.device), sample_num, rounding_mode='trunc') # torch.arange(batch_size, device=X.device) // sample_num
61
+ if solution is not None:
62
+ solution = solution[:, :, 0:2].to(torch.int64).permute(1, 0, 2)
63
+ assert torch.all(solution >= 0) and solution.size(1) == batch_size
64
+ offset = torch.tensor([0, NW, NW + NW, NW + NW + NT], device=X.device)
65
+ chosen_list = solution[:, :, 1] + offset[solution[:, :, 0]]
66
+
67
+ mode = 0
68
+ sample_p = torch.rand(batch_size, device=X.device)
69
+ for chosen in chosen_list:
70
+ env_time = env.time()
71
+ clip = self.logit_clips[min(env_time, len(self.logit_clips) - 1)]
72
+ varfeat = env.make_feat() if self.vars_dim > 0 else NULL
73
+ state1, state2, chosen_p = self.decode(X, K, V, query, state1, state2,
74
+ varfeat, env.mask(), chosen, sample_p, clip, mode, memopt)
75
+ query = X[p_index, chosen]
76
+ p_list.append(chosen_p)
77
+ env.step(chosen)
78
+
79
+ assert env.all_finished(), 'not all finished!'
80
+ else:
81
+ mode = 1 if greedy else 2
82
+ min_env_time = int(self.steps_ratio * NT)
83
+ R = torch.rand(NT * 2, batch_size, device=X.device)
84
+ while True:
85
+ env_time = env.time()
86
+ if env_time > min_env_time and env_time % 3 == 0 and env.all_finished():
87
+ break
88
+
89
+ clip = self.logit_clips[min(env_time, len(self.logit_clips) - 1)]
90
+ sample_p = R[env_time % R.size(0)]
91
+ chosen = X.new_empty(batch_size, dtype=torch.int64)
92
+ varfeat = env.make_feat() if self.vars_dim > 0 else NULL
93
+ state1, state2, chosen_p = self.decode(X, K, V, query, state1, state2,
94
+ varfeat, env.mask(), chosen, sample_p, clip, mode, memopt)
95
+ query = X[p_index, chosen]
96
+ p_list.append(chosen_p)
97
+ env.step(chosen)
98
+
99
+ env.finalize()
100
+ return env, torch.stack(p_list, 1)
101
+
102
+ def decode(self, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p, clip, mode, memopt):
103
+ run_fn = self.decode_fn(clip, mode, memopt)
104
+ if self.training and memopt > 3:
105
+ return checkpoint(run_fn, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p)
106
+ else:
107
+ return run_fn(X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p)
108
+
109
+ def decode_fn(self, clip, mode, memopt):
110
+ memopt = 0 if memopt > 3 else memopt
111
+
112
+ def run_fn(X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p):
113
+ return self.nn_decode(X, K, V, query, state1, state2,
114
+ varfeat, mask, chosen, sample_p, clip, mode, memopt)
115
+
116
+ return run_fn
117
+
118
+
119
+ def parse_nn_args(problem, nn_args):
120
+ worker_dim = OrderedDict()
121
+ task_dim = OrderedDict()
122
+ edge_dim = OrderedDict()
123
+ variable_dim = OrderedDict()
124
+ embed_dict = OrderedDict()
125
+
126
+ def set_dim_by_name(name, k, dim):
127
+ if name.startswith("worker_task_"):
128
+ edge_dim[k] = dim
129
+ elif name.startswith("worker_"):
130
+ worker_dim[k] = dim
131
+ elif name.startswith("task_"):
132
+ task_dim[k] = dim
133
+ elif name.endswith("_matrix"):
134
+ edge_dim[k] = dim
135
+ else:
136
+ raise Exception("attribute can't be feature: {}".format(k))
137
+
138
+ feature_dict = make_feat_dict(problem)
139
+ variables = [var(problem, problem.batch_size, 1) for var in problem.variables]
140
+ variable_dict = dict([(var.name, var) for var in variables])
141
+ for k, f in feature_dict.items():
142
+ if isinstance(f, VariableFeature):
143
+ var = variable_dict[f.name]
144
+ assert hasattr(var, 'make_feat'), \
145
+ "{} cann't be variable feature, name:{}".format(type(var).__name__, k)
146
+ v = var.make_feat()
147
+ if v.dim() == 2:
148
+ variable_dim[k] = 1
149
+ else:
150
+ variable_dim[k] = v.size(-1)
151
+ elif isinstance(f, SparseLocalFeature):
152
+ edge_dim[k] = 1
153
+ set_dim_by_name(f.value, k, 1)
154
+ elif isinstance(f, LocalFeature):
155
+ edge_dim[k] = 1
156
+ set_dim_by_name(f.name, k, 1)
157
+ elif isinstance(f, LocalCategory):
158
+ edge_dim[k] = 1
159
+ elif isinstance(f, GlobalCategory):
160
+ set_dim_by_name(f.name, k, nn_args.setdefault('encode_hidden_dim', 128))
161
+ embed_dict[k] = f.size
162
+ elif isinstance(f, ContinuousFeature):
163
+ v = problem.feats[k]
164
+ if k.startswith("worker_task_") or k.endswith("_matrix"):
165
+ simple_dim = 3
166
+ else:
167
+ simple_dim = 2
168
+
169
+ if v.dim() == simple_dim:
170
+ set_dim_by_name(f.name, k, 1)
171
+ else:
172
+ set_dim_by_name(f.name, k, v.size(-1))
173
+ else:
174
+ raise Exception("unsupported feature type: {}".format(type(f)))
175
+
176
+ nn_args['worker_dim'] = worker_dim
177
+ nn_args['task_dim'] = task_dim
178
+ nn_args['edge_dim'] = edge_dim
179
+ nn_args['variable_dim'] = variable_dim
180
+ nn_args['embed_dict'] = embed_dict
181
+ nn_args['feature_dict'] = feature_dict
182
+ return nn_args
183
+
184
+
185
+ def make_feat_dict(problem):
186
+ feature_dict = OrderedDict()
187
+
188
+ def add(k, f):
189
+ _f = feature_dict.get(k)
190
+ if _f is None or _f == f:
191
+ feature_dict[k] = f
192
+ else:
193
+ "duplicated feature, name: {}, feature1: {}, feature2: {}".format(k, _f, f)
194
+
195
+ for f in problem.features:
196
+ if isinstance(f, VariableFeature):
197
+ add(':'.join(['var', f.name]), f)
198
+ elif isinstance(f, SparseLocalFeature):
199
+ add(':'.join([f.index, f.value]), f)
200
+ else:
201
+ add(f.name, f)
202
+
203
+ return feature_dict
greedrl/const.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+
2
+ GRL_WORKER_START = 0
3
+ GRL_WORKER_END = 1
4
+ GRL_TASK = 2
5
+ GRL_FINISH = 3
6
+
7
+
greedrl/decode.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from torch import nn
6
+ from torch.utils.checkpoint import checkpoint
7
+
8
+
9
+ class MultiHeadAttention(nn.Module):
10
+ def __init__(self, heads, hidden_dim):
11
+ super(MultiHeadAttention, self).__init__()
12
+
13
+ assert hidden_dim % heads == 0
14
+
15
+ self.heads = heads
16
+ head_dim = hidden_dim // heads
17
+ self.alpha = 1 / math.sqrt(head_dim)
18
+
19
+ self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
20
+ self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
21
+
22
+ for param in self.parameters():
23
+ stdv = 1. / math.sqrt(param.size(-1))
24
+ param.data.uniform_(-stdv, stdv)
25
+
26
+ def forward(self, q, K, V, mask):
27
+ batch_size, query_num, hidden_dim = q.size()
28
+
29
+ size = (self.heads, batch_size, query_num, -1)
30
+
31
+ q = q.reshape(-1, hidden_dim)
32
+ Q = torch.matmul(q, self.nn_Q).view(size)
33
+
34
+ value_num = V.size(2)
35
+ heads_batch = self.heads * batch_size
36
+ Q = Q.view(heads_batch, query_num, -1)
37
+ K = K.view(heads_batch, value_num, -1).transpose(1, 2)
38
+
39
+ S = masked_tensor(mask, self.heads)
40
+ S = S.view(heads_batch, query_num, value_num)
41
+ S.baddbmm_(Q, K, alpha=self.alpha)
42
+ S = S.view(self.heads, batch_size, query_num, value_num)
43
+
44
+ S = F.softmax(S, dim=-1)
45
+
46
+ x = torch.matmul(S, V).permute(1, 2, 0, 3)
47
+ x = x.reshape(batch_size, query_num, -1)
48
+ x = torch.matmul(x, self.nn_O)
49
+ return x
50
+
51
+
52
+ class Decode(nn.Module):
53
+
54
+ def __init__(self, nn_args):
55
+ super(Decode, self).__init__()
56
+
57
+ self.nn_args = nn_args
58
+
59
+ heads = nn_args['decode_atten_heads']
60
+ hidden_dim = nn_args['decode_hidden_dim']
61
+
62
+ self.heads = heads
63
+ self.alpha = 1 / math.sqrt(hidden_dim)
64
+
65
+ if heads > 0:
66
+ assert hidden_dim % heads == 0
67
+ head_dim = hidden_dim // heads
68
+ self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
69
+ self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
70
+ self.nn_mha = MultiHeadAttention(heads, hidden_dim)
71
+
72
+ decode_rnn = nn_args.setdefault('decode_rnn', 'LSTM')
73
+ assert decode_rnn in ('GRU', 'LSTM', 'NONE')
74
+ if decode_rnn == 'GRU':
75
+ self.nn_rnn_cell = nn.GRUCell(hidden_dim, hidden_dim)
76
+ elif decode_rnn == 'LSTM':
77
+ self.nn_rnn_cell = nn.LSTMCell(hidden_dim, hidden_dim)
78
+ else:
79
+ self.nn_rnn_cell = None
80
+
81
+ self.vars_dim = sum(nn_args['variable_dim'].values())
82
+ if self.vars_dim > 0:
83
+ atten_type = nn_args.setdefault('decode_atten_type', 'add')
84
+ assert atten_type == 'add', "must be addition attention when vars_dim > 0, {}".format(atten_type)
85
+ self.nn_A = nn.Parameter(torch.Tensor(self.vars_dim, hidden_dim))
86
+ self.nn_B = nn.Parameter(torch.Tensor(hidden_dim))
87
+ else:
88
+ atten_type = nn_args.setdefault('decode_atten_type', 'prod')
89
+
90
+ if atten_type == 'add':
91
+ self.nn_W = nn.Parameter(torch.Tensor(hidden_dim))
92
+ else:
93
+ self.nn_W = None
94
+
95
+ for param in self.parameters():
96
+ stdv = 1 / math.sqrt(param.size(-1))
97
+ param.data.uniform_(-stdv, stdv)
98
+
99
+ def forward(self, X, K, V, query, state1, state2, varfeat, mask, chosen, sample_p, clip, mode, memopt=0):
100
+ if self.training and memopt > 2:
101
+ state1, state2 = checkpoint(self.rnn_step, query, state1, state2)
102
+ else:
103
+ state1, state2 = self.rnn_step(query, state1, state2)
104
+
105
+ query = state1
106
+ NP = X.size(0)
107
+ NR = query.size(0) // NP
108
+ batch_size = query.size(0)
109
+ if self.heads > 0:
110
+ query = query.view(NP, NR, -1)
111
+ if self.training and memopt > 1:
112
+ query = checkpoint(self.nn_mha, query, K, V, mask)
113
+ else:
114
+ query = self.nn_mha(query, K, V, mask)
115
+
116
+ query = query.view(batch_size, -1)
117
+
118
+ if self.nn_W is None:
119
+ query = query.view(NP, NR, -1)
120
+ logit = masked_tensor(mask, 1)
121
+ logit = logit.view(NP, NR, -1)
122
+ X = X.permute(0, 2, 1)
123
+ logit.baddbmm_(query, X, alpha=self.alpha)
124
+ logit = logit.view(batch_size, -1)
125
+ else:
126
+ if self.training and self.vars_dim > 0 and memopt > 0:
127
+ logit = checkpoint(self.atten, query, X, varfeat, mask)
128
+ else:
129
+ logit = self.atten(query, X, varfeat, mask)
130
+
131
+ chosen_p = choose(logit, chosen, sample_p, clip, mode)
132
+ return state1, state2, chosen_p
133
+
134
+ def rnn_step(self, query, state1, state2):
135
+ if isinstance(self.nn_rnn_cell, nn.GRUCell):
136
+ state1 = self.nn_rnn_cell(query, state1)
137
+ elif isinstance(self.nn_rnn_cell, nn.LSTMCell):
138
+ state1, state2 = self.nn_rnn_cell(query, (state1, state2))
139
+ return state1, state2
140
+
141
+ def atten(self, query, keyvalue, varfeat, mask):
142
+ if self.vars_dim > 0:
143
+ varfeat = vfaddmm(varfeat, mask, self.nn_A, self.nn_B)
144
+ return atten(query, keyvalue, varfeat, mask, self.nn_W)
145
+
146
+
147
+ def choose(logit, chosen, sample_p, clip, mode):
148
+ mask = logit == -math.inf
149
+ logit = torch.tanh(logit) * clip
150
+ logit[mask] = -math.inf
151
+
152
+ if mode == 0:
153
+ pass
154
+ elif mode == 1:
155
+ chosen[:] = logit.argmax(1)
156
+ elif mode == 2:
157
+ p = logit.exp()
158
+ chosen[:] = torch.multinomial(p, 1).squeeze(1)
159
+ else:
160
+ raise Exception()
161
+
162
+ logp = logit.log_softmax(1)
163
+ logp = logp.gather(1, chosen[:, None])
164
+ logp = logp.squeeze(1)
165
+ return logp
166
+
167
+
168
+ def atten(query, keyvalue, varfeat, mask, weight):
169
+ batch_size = query.size(0)
170
+ NP, NK, ND = keyvalue.size()
171
+
172
+ query = query.view(NP, -1, 1, ND)
173
+ varfeat = varfeat.view(NP, -1, NK, ND)
174
+ keyvalue = keyvalue[:, None, :, :]
175
+ keyvalue = keyvalue + varfeat + query
176
+ keyvalue = torch.tanh(keyvalue)
177
+ keyvalue = keyvalue.view(-1, ND)
178
+
179
+ logit = masked_tensor(mask, 1).view(-1)
180
+ logit.addmv_(keyvalue, weight)
181
+ return logit.view(batch_size, -1)
182
+
183
+
184
+ def masked_tensor(mask, heads):
185
+ size = list(mask.size())
186
+ size.insert(0, heads)
187
+ mask = mask[None].expand(size)
188
+ result = mask.new_zeros(size, dtype=torch.float32)
189
+ result[mask] = -math.inf
190
+ return result
191
+
192
+
193
+ def vfaddmm(varfeat, mask, A, B):
194
+ varfeat = varfeat.permute(0, 2, 1)
195
+ return F.linear(varfeat, A.permute(1, 0), B)
196
+
greedrl/dense.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ from .utils import get_act
4
+ from .norm import Norm1D, Norm2D
5
+
6
+
7
+ class Dense(nn.Module):
8
+
9
+ def __init__(self, input_dim, output_dim, bias=True, norm1d='none', norm2d='none', act='none'):
10
+ super(Dense, self).__init__()
11
+ assert norm1d == 'none' or norm2d == 'none', "one of [norm1d, norm2d] must be none"
12
+
13
+ if norm1d != 'none':
14
+ self.nn_norm = Norm1D(input_dim, norm1d)
15
+ elif norm2d != 'none':
16
+ self.nn_norm = Norm2D(input_dim, norm2d)
17
+ else:
18
+ self.nn_norm = None
19
+
20
+ self.nn_act = get_act(act)
21
+ self.nn_linear = nn.Linear(input_dim, output_dim, bias)
22
+
23
+ def weight(self):
24
+ return self.nn_linear.weight
25
+
26
+ def forward(self, x):
27
+ if self.nn_norm is not None:
28
+ x = self.nn_norm(x)
29
+ x = self.nn_act(x)
30
+ x = self.nn_linear(x)
31
+ return x
greedrl/encode.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ from torch import nn
6
+ from torch.utils.checkpoint import checkpoint
7
+ from .norm import Norm1D, Norm2D
8
+ from .dense import Dense
9
+ from .utils import repeat
10
+ from .feature import *
11
+
12
+
13
+ class MultiHeadAttention(nn.Module):
14
+ def __init__(self, heads, hidden_dim):
15
+ super(MultiHeadAttention, self).__init__()
16
+
17
+ assert hidden_dim % heads == 0
18
+
19
+ self.heads = heads
20
+ head_dim = hidden_dim // heads
21
+ self.alpha = 1 / math.sqrt(head_dim)
22
+
23
+ self.nn_Q = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
24
+ self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
25
+ self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
26
+ self.nn_O = nn.Parameter(torch.Tensor(hidden_dim, hidden_dim))
27
+
28
+ for param in self.parameters():
29
+ stdv = 1. / math.sqrt(param.size(-1))
30
+ param.data.uniform_(-stdv, stdv)
31
+
32
+ def forward(self, x, edge):
33
+ batch_size, item_num, hidden_dim = x.size()
34
+ size = (self.heads, batch_size, item_num, -1)
35
+
36
+ x = x.reshape(-1, hidden_dim)
37
+ Q = torch.matmul(x, self.nn_Q).view(size)
38
+ K = torch.matmul(x, self.nn_K).view(size)
39
+ V = torch.matmul(x, self.nn_V).view(size)
40
+
41
+ heads_batch = self.heads * batch_size
42
+ Q = Q.view(heads_batch, item_num, -1)
43
+ K = K.view(heads_batch, item_num, -1).transpose(1, 2)
44
+
45
+ if edge is not None:
46
+ S = edge.view(heads_batch, item_num, item_num)
47
+ S = S.baddbmm(Q, K, alpha=self.alpha)
48
+ else:
49
+ S = Q.new_zeros(heads_batch, item_num, item_num)
50
+ S = S.baddbmm_(Q, K, alpha=self.alpha)
51
+
52
+ S = S.view(self.heads, batch_size, item_num, item_num)
53
+
54
+ S = F.softmax(S, dim=-1)
55
+
56
+ x = torch.matmul(S, V).permute(1, 2, 0, 3)
57
+ x = x.reshape(batch_size, item_num, -1)
58
+ x = torch.matmul(x, self.nn_O)
59
+ return x
60
+
61
+
62
+ class Encode(nn.Module):
63
+ def __init__(self, nn_args):
64
+ super(Encode, self).__init__()
65
+
66
+ self.nn_args = nn_args
67
+ self.worker_dim = nn_args['worker_dim']
68
+ self.task_dim = nn_args['task_dim']
69
+ self.edge_dim = nn_args['edge_dim']
70
+
71
+ self.embed_dict = nn_args['embed_dict']
72
+ self.feature_dict = nn_args['feature_dict']
73
+
74
+ layers = nn_args.setdefault('encode_layers', 3)
75
+ heads = nn_args.setdefault('encode_atten_heads', 8)
76
+ norm = nn_args.setdefault('encode_norm', 'instance')
77
+ hidden_dim = nn_args.setdefault('encode_hidden_dim', 128)
78
+ output_dim = nn_args.setdefault('decode_hidden_dim', 128)
79
+ output_heads = nn_args.setdefault('decode_atten_heads', 0)
80
+
81
+ self.heads = heads
82
+ self.layers = layers
83
+
84
+ worker_dim = max(1, sum(self.worker_dim.values()))
85
+ task_dim = max(1, sum(self.task_dim.values()))
86
+
87
+ self.nn_dense_worker_start = Dense(worker_dim, hidden_dim)
88
+ self.nn_dense_worker_end = Dense(worker_dim, hidden_dim)
89
+ self.nn_dense_task = Dense(task_dim, hidden_dim)
90
+
91
+ self.nn_norm_worker_task = Norm1D(hidden_dim, norm, True)
92
+
93
+ if len(self.edge_dim) > 0:
94
+ edge_dim = sum(self.edge_dim.values())
95
+ self.nn_dense_edge = Dense(edge_dim, heads)
96
+ self.nn_norm_edge = Norm2D(heads, norm, True)
97
+
98
+ nn_embed_dict = {}
99
+ for k, v in self.embed_dict.items():
100
+ nn_embed_dict[k] = nn.Embedding(v, hidden_dim)
101
+ self.nn_embed_dict = nn.ModuleDict(nn_embed_dict)
102
+
103
+ self.nn_attens = nn.ModuleList()
104
+ self.nn_denses = nn.ModuleList()
105
+ self.nn_norms1 = nn.ModuleList()
106
+ self.nn_norms2 = nn.ModuleList()
107
+ for i in range(layers):
108
+ self.nn_attens.append(MultiHeadAttention(heads, hidden_dim))
109
+ self.nn_denses.append(nn.Sequential(
110
+ Dense(hidden_dim, hidden_dim * 4),
111
+ Dense(hidden_dim * 4, hidden_dim, act='relu'),
112
+ ))
113
+ self.nn_norms1.append(Norm1D(hidden_dim, norm, True))
114
+ self.nn_norms2.append(Norm1D(hidden_dim, norm, True))
115
+
116
+ self.nn_finish = nn.Parameter(torch.Tensor(1, 1, hidden_dim))
117
+
118
+ if output_dim != hidden_dim:
119
+ self.nn_X = nn.Parameter(torch.Tensor(hidden_dim, output_dim))
120
+ else:
121
+ self.nn_X = None
122
+
123
+ if output_heads > 0:
124
+ assert output_dim % output_heads == 0
125
+ head_dim = output_dim // output_heads
126
+ self.nn_K = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
127
+ self.nn_V = nn.Parameter(torch.Tensor(heads, hidden_dim, head_dim))
128
+ else:
129
+ self.nn_K = None
130
+ self.nn_V = None
131
+
132
+ for param in self.parameters():
133
+ stdv = 1 / math.sqrt(param.size(-1))
134
+ param.data.uniform_(-stdv, stdv)
135
+
136
+ def forward(self, problem, batch_size, worker_num, task_num, memopt=0):
137
+ worker_start, worker_end = self.encode_worker(problem, batch_size, worker_num)
138
+ task = self.encode_task(problem, batch_size, task_num)
139
+ X = torch.cat([worker_start, worker_end, task], 1)
140
+ X = self.nn_norm_worker_task(X)
141
+
142
+ if len(self.edge_dim) > 0:
143
+ edge = self.encode_edge(problem, batch_size, worker_num, task_num)
144
+ edge = self.nn_norm_edge(edge)
145
+ edge = edge.permute(3, 0, 1, 2).contiguous()
146
+ else:
147
+ edge = None
148
+
149
+ #transformer encoding
150
+ for i in range(self.layers):
151
+ X = self.encode_layer(X, edge, i, memopt)
152
+
153
+ finish = repeat(self.nn_finish, X.size(0))
154
+ X = torch.cat([X, finish], 1)
155
+ if self.nn_X is not None:
156
+ X = torch.matmul(X, self.nn_X)
157
+
158
+ if self.nn_K is not None:
159
+ batch_size, item_num, hidden_dim = X.size()
160
+ size = (self.heads, batch_size, item_num, -1)
161
+ X2 = X.reshape(-1, hidden_dim)
162
+ K = torch.matmul(X2, self.nn_K).view(size)
163
+ V = torch.matmul(X2, self.nn_V).view(size)
164
+ else:
165
+ K = torch.ones(0)
166
+ V = torch.ones(0)
167
+ return X, K, V
168
+
169
+ def encode_layer(self, X, edge, i, memopt):
170
+ run_fn = self.encode_layer_fn(i, memopt)
171
+ if self.training and memopt > 6:
172
+ return checkpoint(run_fn, X, edge)
173
+ else:
174
+ return run_fn(X, edge)
175
+
176
+ def encode_layer_fn(self, i, memopt):
177
+ def run_fn(X, edge):
178
+ if self.training and memopt == 6:
179
+ X = X + checkpoint(self.nn_attens[i], X, edge)
180
+ else:
181
+ X = X + self.nn_attens[i](X, edge)
182
+ X = self.nn_norms1[i](X)
183
+
184
+ X = X + self.nn_denses[i](X)
185
+ X = self.nn_norms2[i](X)
186
+ return X
187
+
188
+ return run_fn
189
+
190
+ def encode_worker(self, problem, batch_size, worker_num):
191
+ feature_list = []
192
+ for k, dim in self.worker_dim.items():
193
+ f = self.feature_dict.get(k)
194
+ if isinstance(f, GlobalCategory):
195
+ v = problem[f.name]
196
+ v = self.nn_embed_dict[k](v.long())
197
+ elif isinstance(f, ContinuousFeature):
198
+ v = problem[f.name]
199
+ else:
200
+ raise Exception("unsupported feature type: {}".format(type(f)))
201
+
202
+ if v.dim() == 2:
203
+ v = v[:, :, None]
204
+
205
+ assert dim == v.size(-1), \
206
+ "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
207
+
208
+ feature_list.append(v.float())
209
+
210
+ if feature_list:
211
+ x = torch.cat(feature_list, 2)
212
+ else:
213
+ x = self.nn_finish.new_ones(batch_size, worker_num, 1)
214
+ return self.nn_dense_worker_start(x), self.nn_dense_worker_end(x)
215
+
216
+ def encode_task(self, problem, batch_size, task_num):
217
+ feature_list = []
218
+ for k, dim in self.task_dim.items():
219
+ f = self.feature_dict.get(k)
220
+ if isinstance(f, SparseLocalFeature):
221
+ v = problem[f.value]
222
+ assert v.dim() == 3, \
223
+ "sparse local feature's dimension must 2, feature:{}".format(k)
224
+ v = v.clamp(0, 1).sum(2, dtype=v.dtype)
225
+ elif isinstance(f, GlobalCategory):
226
+ v = problem[f.name]
227
+ v = self.nn_embed_dict[k](v.long())
228
+ elif isinstance(f, LocalFeature):
229
+ v = problem[f.name]
230
+ assert v.dim() == 3, \
231
+ "local feature's dimension must 2, feature:{}".format(k)
232
+ v = v.clamp(0, 1).sum(2, dtype=v.dtype)
233
+ elif isinstance(f, ContinuousFeature):
234
+ v = problem[f.name]
235
+ else:
236
+ raise Exception("unsupported feature type: {}".format(type(f)))
237
+
238
+ if v.dim() == 2:
239
+ v = v[:, :, None]
240
+
241
+ assert dim == v.size(-1), \
242
+ "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
243
+
244
+ feature_list.append(v.float())
245
+
246
+ if feature_list:
247
+ x = torch.cat(feature_list, 2)
248
+ else:
249
+ x = self.nn_finish.new_ones(batch_size, task_num, 1)
250
+ return self.nn_dense_task(x)
251
+
252
+ def encode_edge(self, problem, batch_size, worker_num, task_num):
253
+ NP = batch_size
254
+ NW = worker_num
255
+ NT = task_num
256
+ NWW = NW + NW
257
+ feature_list = []
258
+ for k, dim in self.edge_dim.items():
259
+ f = self.feature_dict.get(k)
260
+ if isinstance(f, LocalCategory):
261
+ assert f.name.startswith("task_")
262
+
263
+ v = problem[k]
264
+ v1 = v[:, :, None]
265
+ v2 = v[:, None, :]
266
+
267
+ v = torch.zeros(NP, NWW + NT, NWW + NT,
268
+ dtype=v.dtype, device=v.device)
269
+ v[:, NWW:, NWW:] = ((v1 == v2) & (v1 >= 0))
270
+ elif isinstance(f, LocalFeature):
271
+ assert f.name.startswith("task_")
272
+
273
+ v = problem[k].float()
274
+ dot_product = torch.matmul(v, v.transpose(-1, -2))
275
+ v_norm = v.norm(dim=2) + 1e-10
276
+ v1_norm = v_norm[:, :, None]
277
+ v2_norm = v_norm[:, None, :]
278
+
279
+ v = torch.zeros(NP, NWW + NT, NWW + NT,
280
+ dtype=v.dtype, device=v.device)
281
+ v[:, NWW:, NWW:] = dot_product / v1_norm / v2_norm
282
+ elif isinstance(f, SparseLocalFeature):
283
+ assert NP == 1
284
+ assert f.index.startswith("task_")
285
+ assert f.value.startswith("task_")
286
+
287
+ index = problem[f.index]
288
+ value = problem[f.value].float()
289
+
290
+ NV = index.max().item() + 1
291
+ spv = value.reshape(-1).tolist()
292
+ spi = index.reshape(-1).tolist()
293
+
294
+ device = value.device
295
+ spj = torch.arange(NT, device=device)
296
+ spj = spj[:, None].expand_as(index)
297
+ spj = spj.reshape(-1).tolist()
298
+
299
+ value1 = torch.sparse_coo_tensor([spj, spi], spv, (NT, NV), device=device)
300
+ value2 = torch.sparse_coo_tensor([spi, spj], spv, (NV, NT), device=device)
301
+
302
+ value1 = value1.coalesce()
303
+ value2 = value2.coalesce()
304
+ cosine = torch.sparse.mm(value1, value2).to_dense()
305
+
306
+ norm = value.norm(dim=-1).reshape(-1)
307
+ norm1 = norm[:, None].expand(-1, NT)
308
+ norm2 = norm[None, :].expand(NT, -1)
309
+ cosine = cosine / (norm1 * norm2 + 1e-10)
310
+
311
+ v = torch.zeros(NP, NWW + NT, NWW + NT,
312
+ dtype=value.dtype, device=value.device)
313
+ v[:, NWW:, NWW:] = cosine
314
+
315
+ elif isinstance(f, ContinuousFeature):
316
+ if f.name.endswith("_matrix"):
317
+ v = problem[k]
318
+ elif f.name.startswith("worker_task_"):
319
+ v = problem[k]
320
+ if v.dim() == 3:
321
+ new_v = torch.zeros(NP, NWW + NT, NWW + NT,
322
+ dtype=v.dtype, device=v.device)
323
+ else:
324
+ new_v = torch.zeros(NP, NWW + NT, NWW + NT, v.size(3),
325
+ dtype=v.dtype, device=v.device)
326
+ problem_index = torch.arange(NP, device=v.device)[:, None, None]
327
+ worker_index = torch.arange(NW, device=v.device)[None, :, None]
328
+ task_index = torch.arange(NT, device=v.device)[None, None, :] + NW + NW
329
+ new_v[problem_index, worker_index, task_index] = v
330
+ new_v[problem_index, task_index, worker_index] = v
331
+ new_v[problem_index, worker_index + NW, task_index] = v
332
+ new_v[problem_index, task_index, worker_index + NW] = v
333
+ v = new_v
334
+ else:
335
+ raise Exception("feature: {}".format(f.name))
336
+ else:
337
+ raise Exception("feature: {}, type: {}".format(k, type(f)))
338
+
339
+ if v.dim() == 3:
340
+ v = v[:, :, :, None]
341
+
342
+ assert dim == v.size(-1), \
343
+ "feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
344
+
345
+ feature_list.append(v.float())
346
+
347
+ x = torch.cat(feature_list, 3)
348
+ return self.nn_dense_edge(x)
349
+
greedrl/feature.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def continuous_feature(name):
2
+ return ContinuousFeature(name)
3
+
4
+
5
+ class ContinuousFeature:
6
+ def __init__(self, name):
7
+ self.name = name
8
+
9
+
10
+ def global_category(name, size):
11
+ return GlobalCategory(name, size)
12
+
13
+
14
+ class GlobalCategory:
15
+ def __init__(self, name, size):
16
+ self.name = name
17
+ self.size = size
18
+
19
+
20
+ def local_category(name):
21
+ return LocalCategory(name)
22
+
23
+
24
+ class LocalCategory:
25
+ def __init__(self, name):
26
+ assert name.startswith('task_'), \
27
+ "only task feature supported: {}".format(name)
28
+ self.name = name
29
+
30
+
31
+ def local_feature(name):
32
+ return LocalFeature(name)
33
+
34
+
35
+ class LocalFeature:
36
+ def __init__(self, name):
37
+ assert name.startswith('task_'), \
38
+ "only task feature supported: {}".format(name)
39
+ self.name = name
40
+
41
+
42
+ def sparse_local_feature(index, value):
43
+ return SparseLocalFeature(index, value)
44
+
45
+
46
+ class SparseLocalFeature:
47
+ def __init__(self, index, value):
48
+ assert index.startswith('task_'), \
49
+ "only task feature supported for index: {}".format(index)
50
+ assert value.startswith('task_'), \
51
+ "only task feature supported for value: {}".format(value)
52
+
53
+ self.index = index
54
+ self.value = value
55
+
56
+
57
+ def variable_feature(name):
58
+ return VariableFeature(name)
59
+
60
+
61
+ class VariableFeature:
62
+ def __init__(self, name):
63
+ self.name = name
greedrl/function.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import greedrl_c
2
+ from greedrl_c import task_group_priority
3
+ from greedrl_c import task_group_split
4
+
5
+
greedrl/norm.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from torch import nn
4
+
5
+
6
+ class Norm1D(nn.Module):
7
+
8
+ def __init__(self, dim, ntype='batch', affine=False):
9
+ super(Norm1D, self).__init__()
10
+ clazz_dict = {'batch': nn.BatchNorm1d, 'instance': nn.InstanceNorm1d}
11
+ self.nn_norm = clazz_dict[ntype](dim, eps=1e-10, affine=affine)
12
+
13
+ def forward(self, x):
14
+ return self.nn_norm(x.permute(0, 2, 1)).permute(0, 2, 1)
15
+
16
+
17
+ class Norm2D(nn.Module):
18
+
19
+ def __init__(self, dim, ntype='batch', affine=False):
20
+ super(Norm2D, self).__init__()
21
+ clazz_dict = {'batch': nn.BatchNorm2d, 'instance': nn.InstanceNorm2d}
22
+ self.nn_norm = clazz_dict[ntype](dim, eps=1e-10, affine=affine)
23
+
24
+ def forward(self, x):
25
+ return self.nn_norm(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
greedrl/pyenv.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import math
4
+
5
+ from collections import OrderedDict
6
+ from .const import *
7
+ from .utils import to_list
8
+ from .norm import Norm1D, Norm2D
9
+ from .variable import AttributeVariable, WorkerTaskSequence
10
+
11
+
12
+ class PyEnv(object):
13
+
14
+ def __init__(self, problem, batch_size, sample_num, nn_args):
15
+ super(PyEnv, self).__init__()
16
+
17
+ self._problem = problem
18
+ self._batch_size = batch_size
19
+ self._sample_num = sample_num
20
+ self._debug = -1
21
+
22
+ self._NW = problem.worker_num
23
+ self._NWW = problem.worker_num * 2
24
+ self._NT = problem.task_num
25
+ self._NWWT = self._NWW + self._NT
26
+
27
+ self._feats_dict = nn_args['feature_dict']
28
+ self._vars_dim = nn_args['variable_dim']
29
+
30
+ self._vars_dict = {}
31
+ self._vars = [var(problem, batch_size, sample_num) for var in problem.variables]
32
+ for variable in self._vars:
33
+ save_variable_version(variable)
34
+ assert variable.name not in self._vars_dict, \
35
+ "duplicated variable, name: {}".format(variable.name)
36
+ self._vars_dict[variable.name] = variable
37
+
38
+ self._constraint = problem.constraint()
39
+ self._objective = problem.objective()
40
+
41
+ self._worker_index = torch.full((self._batch_size,), -1,
42
+ dtype=torch.int64,
43
+ device=problem.device)
44
+
45
+ self._batch_index = torch.arange(self._batch_size,
46
+ dtype=torch.int64,
47
+ device=problem.device)
48
+
49
+ self._problem_index = torch.div(self._batch_index, sample_num, rounding_mode='trunc') # self._batch_index // sample_num
50
+
51
+ self._feasible = torch.ones(self._batch_size,
52
+ dtype=torch.bool,
53
+ device=problem.device)
54
+
55
+ self._cost = torch.zeros(self._batch_size, self._NT * 2,
56
+ dtype=torch.float32,
57
+ device=problem.device)
58
+
59
+ self._mask = torch.zeros(self._batch_size,
60
+ self._NWWT + 1,
61
+ dtype=torch.bool,
62
+ device=problem.device)
63
+
64
+ self._worker_task_sequence = torch.full((self._batch_size, self._NT * 2, 3), -1,
65
+ dtype=torch.int64,
66
+ device=problem.device)
67
+ self._step = 0
68
+ self.register_variables(self._constraint)
69
+ self._finished = self._constraint.finished()
70
+
71
+ if hasattr(self._constraint, 'mask_worker_start'):
72
+ self.register_variables(self._constraint)
73
+ mask_start = self._constraint.mask_worker_start()
74
+ else:
75
+ mask_start = False
76
+
77
+ self._mask[:, :self._NW] = mask_start
78
+ self._mask[:, self._NW:] = True
79
+
80
+ if self._debug >= 0:
81
+ print("\n$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$")
82
+ print("new env")
83
+ print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n")
84
+
85
+ def time(self):
86
+ return self._step
87
+
88
+ def step(self, chosen):
89
+ with torch.no_grad():
90
+ self._do_step(chosen)
91
+
92
+ def _do_step(self, chosen):
93
+ if self._debug >= 0:
94
+ print("----------------------------------------------------------------------")
95
+ feasible = self._feasible & ~self._mask[self._problem_index, chosen]
96
+ print("feasible={}".format(feasible[self._debug].tolist()))
97
+
98
+ is_start = (chosen >= 0) & (chosen < self._NW)
99
+ if torch.any(is_start):
100
+ b_index = self._batch_index[is_start]
101
+ p_index = self._problem_index[is_start]
102
+ w_index = chosen[is_start]
103
+ self.step_worker_start(b_index, p_index, w_index)
104
+
105
+ is_end = (chosen >= self._NW) & (chosen < self._NWW)
106
+ if torch.any(is_end):
107
+ b_index = self._batch_index[is_end]
108
+ p_index = self._problem_index[is_end]
109
+ w_index = chosen[is_end] - self._NW
110
+ self.step_worker_end(b_index, p_index, w_index)
111
+
112
+ is_task = (chosen >= self._NWW) & (chosen < self._NWWT)
113
+ if torch.any(is_task):
114
+ b_index = self._batch_index[is_task]
115
+ p_index = self._problem_index[is_task]
116
+ t_index = chosen[is_task] - self._NWW
117
+ step_task_b_index = b_index
118
+ self.step_task(b_index, p_index, t_index)
119
+ else:
120
+ step_task_b_index = None
121
+
122
+ is_finish = chosen == self._NWWT
123
+ if torch.any(is_finish):
124
+ b_index = self._batch_index[is_finish]
125
+ self._worker_task_sequence[b_index, self._step, 0] = GRL_FINISH
126
+ self._worker_task_sequence[b_index, self._step, 1] = 0
127
+ self._worker_task_sequence[b_index, self._step, 2] = -1
128
+
129
+ self.update_mask(step_task_b_index)
130
+
131
+ for var in self._vars:
132
+ check_variable_version(var)
133
+
134
+ if self._debug >= 0:
135
+ print("worker_task_sequence[{}]={}".format(self._step,
136
+ self._worker_task_sequence[self._debug, self._step].tolist()))
137
+ for var in self._vars:
138
+ if var.value is None:
139
+ print("{}={}".format(var.name, None))
140
+ elif isinstance(var, AttributeVariable):
141
+ print("{}={}".format(var.name, to_list(var.value)))
142
+ else:
143
+ print("{}={}".format(var.name, to_list(var.value[self._debug])))
144
+
145
+ self._step += 1
146
+ if self._step >= self._cost.size(1):
147
+ cost = torch.zeros(self._batch_size, self._step + self._NT,
148
+ dtype=torch.float32,
149
+ device=chosen.device)
150
+ cost[:, 0:self._step] = self._cost;
151
+ self._cost = cost
152
+
153
+ worker_task_sequence = torch.full((self._batch_size, self._step + self._NT, 3), -1,
154
+ dtype=torch.int64,
155
+ device=chosen.device)
156
+ worker_task_sequence[:, 0:self._step, :] = self._worker_task_sequence
157
+ self._worker_task_sequence = worker_task_sequence
158
+
159
+ def step_worker_start(self, b_index, p_index, w_index):
160
+ self._worker_task_sequence[b_index, self._step, 0] = GRL_WORKER_START
161
+ self._worker_task_sequence[b_index, self._step, 1] = w_index
162
+ self._worker_task_sequence[b_index, self._step, 2] = -1
163
+ for var in self._vars:
164
+ if hasattr(var, 'step_worker_start'):
165
+ var.step_worker_start(b_index, p_index, w_index)
166
+ save_variable_version(var)
167
+
168
+ if hasattr(self._objective, 'step_worker_start'):
169
+ self.register_variables(self._objective, b_index)
170
+ self.update_cost(self._objective.step_worker_start(), b_index)
171
+
172
+ self._worker_index[b_index] = w_index
173
+ self._mask[b_index, :self._NWW] = True
174
+ self._mask[b_index, self._NWW:] = False
175
+
176
+ def step_worker_end(self, b_index, p_index, w_index):
177
+ self._worker_task_sequence[b_index, self._step, 0] = GRL_WORKER_END
178
+ self._worker_task_sequence[b_index, self._step, 1] = w_index
179
+ self._worker_task_sequence[b_index, self._step, 2] = -1;
180
+
181
+ for var in self._vars:
182
+ if hasattr(var, 'step_worker_end'):
183
+ var.step_worker_end(b_index, p_index, w_index)
184
+ save_variable_version(var)
185
+
186
+ if hasattr(self._objective, 'step_worker_end'):
187
+ self.register_variables(self._objective, b_index)
188
+ self.update_cost(self._objective.step_worker_end(), b_index)
189
+
190
+ self._worker_index[b_index] = -1
191
+
192
+ self.register_variables(self._constraint, b_index)
193
+ self._finished[b_index] |= self._constraint.finished()
194
+ if hasattr(self._constraint, 'mask_worker_start'):
195
+ mask_start = self._constraint.mask_worker_start()
196
+ else:
197
+ mask_start = False
198
+
199
+ self._mask[b_index, :self._NW] = mask_start
200
+ self._mask[b_index, self._NW:] = True
201
+
202
+ def step_task(self, b_index, p_index, t_index):
203
+ self._worker_task_sequence[b_index, self._step, 0] = GRL_TASK
204
+ self._worker_task_sequence[b_index, self._step, 1] = t_index
205
+
206
+ for var in self._vars:
207
+ if not hasattr(var, 'step_task'):
208
+ continue
209
+ elif var.step_task.__code__.co_argcount == 4:
210
+ var.step_task(b_index, p_index, t_index)
211
+ else:
212
+ var.step_task(b_index, p_index, t_index, None)
213
+ save_variable_version(var)
214
+
215
+ if hasattr(self._constraint, 'do_task'):
216
+ self.register_variables(self._constraint, b_index)
217
+ done = self._constraint.do_task()
218
+ self._worker_task_sequence[b_index, self._step, 2] = done.long()
219
+
220
+ for var in self._vars:
221
+ if not hasattr(var, 'step_task'):
222
+ continue
223
+ elif var.step_task.__code__.co_argcount == 4:
224
+ pass
225
+ else:
226
+ check_variable_version(var)
227
+ var.step_task(b_index, p_index, t_index, done)
228
+ save_variable_version(var)
229
+ else:
230
+ done = None
231
+
232
+ if hasattr(self._objective, 'step_task'):
233
+ self.register_variables(self._objective, b_index)
234
+ self.update_cost(self._objective.step_task(), b_index)
235
+
236
+ if hasattr(self._constraint, 'mask_worker_end'):
237
+ self.register_variables(self._constraint, b_index)
238
+ mask_end = self._constraint.mask_worker_end()
239
+ else:
240
+ mask_end = False
241
+
242
+ w_index = self._NW + self._worker_index[b_index]
243
+ self._mask[b_index, w_index] = mask_end
244
+ self._mask[b_index, self._NWW:] = False
245
+ return done
246
+
247
+ def update_cost(self, cost, b_index=None):
248
+ if isinstance(cost, tuple):
249
+ cost, feasible = cost
250
+ if b_index is None:
251
+ self._feasible &= feasible
252
+ else:
253
+ self._feasible[b_index] &= feasible
254
+
255
+ if isinstance(cost, torch.Tensor):
256
+ cost = cost.float()
257
+ else:
258
+ assert type(cost) in (int, float), "unexpected cost's type: {}".format(type(cost))
259
+
260
+ if b_index is None:
261
+ self._cost[:, self._step] = cost
262
+ else:
263
+ self._cost[b_index, self._step] = cost
264
+
265
+ def update_mask(self, step_task_b_index):
266
+ self._mask |= self._finished[:, None]
267
+ self._mask[:, -1] = ~self._finished
268
+ self.register_variables(self._constraint)
269
+ self._mask[:, self._NWW:self._NWWT] |= self._constraint.mask_task()
270
+
271
+ if step_task_b_index is not None:
272
+ b_index = step_task_b_index
273
+ w_index = self._NW + self._worker_index[b_index]
274
+ task_mask = self._mask[b_index, self._NWW:self._NWWT]
275
+ self._mask[b_index, w_index] &= ~torch.all(task_mask, 1)
276
+
277
+ def batch_size():
278
+ return self._batch_size
279
+
280
+ def sample_num():
281
+ return self._sample_num
282
+
283
+ def mask(self):
284
+ return self._mask.clone()
285
+
286
+ def cost(self):
287
+ return self._cost[:, 0:self._step]
288
+
289
+ def feasible(self):
290
+ return self._feasible
291
+
292
+ def worker_task_sequence(self):
293
+ return self._worker_task_sequence[:, 0:self._step]
294
+
295
+ def var(self, name):
296
+ return self._vars_dict[name].value
297
+
298
+ def register_variables(self, obj, b_index=None, finished=False):
299
+ for var in self._vars:
300
+ if var.value is None or b_index is None \
301
+ or isinstance(var, AttributeVariable):
302
+ value = var.value
303
+ else:
304
+ value = var.value[b_index]
305
+ obj.__dict__[var.name] = value
306
+
307
+ if not hasattr(var, 'ext_values'):
308
+ continue
309
+
310
+ for k, v in var.ext_values.items():
311
+ k = var.name + '_' + k
312
+ obj.__dict__[k] = v[b_index]
313
+
314
+ def finished(self):
315
+ return self._finished
316
+
317
+ def all_finished(self):
318
+ return torch.all(self.finished())
319
+
320
+ def finalize(self):
321
+ self._worker_task_sequence[:, self._step, 0] = GRL_FINISH
322
+ self._worker_task_sequence[:, self._step, 1] = 0
323
+ self._worker_task_sequence[:, self._step, 2] = -1
324
+
325
+ for var in self._vars:
326
+ if hasattr(var, 'step_finish'):
327
+ var.step_finish(self.worker_task_sequence())
328
+
329
+ if hasattr(self._objective, 'step_finish'):
330
+ self.register_variables(self._objective, finished=True)
331
+ self.update_cost(self._objective.step_finish())
332
+
333
+ self._step += 1
334
+
335
+ def make_feat(self):
336
+ with torch.no_grad():
337
+ return self.do_make_feat()
338
+
339
+ def do_make_feat(self):
340
+ if not self._vars_dim:
341
+ return None
342
+
343
+ feature_list = []
344
+ for k, dim in self._vars_dim.items():
345
+ f = self._feats_dict[k]
346
+ var = self._vars_dict[f.name]
347
+ v = var.make_feat()
348
+ if v.dim() == 2:
349
+ v = v[:, :, None]
350
+
351
+ assert dim == v.size(-1), \
352
+ "feature dim error, feature: {}, expected: {}, actual: {}".format(k, dim, v.size(-1))
353
+ feature_list.append(v.float())
354
+
355
+ v = torch.cat(feature_list, 2)
356
+ u = v.new_zeros(v.size(0), self._NWW, v.size(2))
357
+ f = v.new_zeros(v.size(0), 1, v.size(2))
358
+ v = torch.cat([u, v, f], 1).permute(0, 2, 1)
359
+
360
+ v[self._mask[:, None, :].expand(v.size())] = 0
361
+
362
+ norm = v.new_ones(self._mask.size())
363
+ norm[self._mask] = 0
364
+ norm = norm.sum(1) + 1e-10
365
+ norm = norm[:, None, None]
366
+
367
+ avg = v.sum(-1, keepdim=True) / norm
368
+ v = v - avg
369
+
370
+ std = v.norm(dim=-1, keepdim=True) / norm + 1e-10
371
+ v = v / std
372
+ return v.contiguous()
373
+
374
+
375
+ def save_variable_version(var):
376
+ if isinstance(var.value, torch.Tensor):
377
+ var.__version__ = var.value._version
378
+
379
+
380
+ def check_variable_version(var):
381
+ if isinstance(var.value, torch.Tensor):
382
+ assert var.__version__ == var.value._version, \
383
+ "variable's value is modified, name: {}".format(var.name)
greedrl/solver.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import math
4
+ import copy
5
+ import time
6
+ import queue
7
+ import inspect
8
+ import torch
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ import torch.distributed as dist
12
+
13
+ from .agent import Agent, parse_nn_args
14
+ from .utils import repeat, get_default_device, cutime_stats
15
+ from .variable import TaskDemandNow
16
+
17
+ from torch.nn.utils import clip_grad_norm_, parameters_to_vector, vector_to_parameters
18
+ from torch.utils.data import Dataset, IterableDataset, DataLoader
19
+ from torch.optim.lr_scheduler import MultiStepLR
20
+
21
+
22
+ class Problem(object):
23
+ def __init__(self, isbatch=False):
24
+ self.isbatch = isbatch
25
+ self.features = []
26
+ self.environment = None
27
+
28
+ def pin_memory(self):
29
+ for k, v in self.feats.items():
30
+ self.feats[k] = v.pin_memory()
31
+ return self
32
+
33
+ def __getattr__(self, name):
34
+ if name not in ('solution'):
35
+ raise AttributeError()
36
+ return self.feats.get(name)
37
+
38
+
39
+ class Solution(object):
40
+ def __init__(self, cost=None):
41
+ self.cost = cost
42
+ self.worker_task_sequence = None
43
+
44
+
45
+ class WrapDataset(Dataset):
46
+ def __init__(self, dataset, solver):
47
+ self._dataset = [solver.to_batch(p) for p in dataset]
48
+
49
+ def __getitem__(self, index):
50
+ return self._dataset[index]
51
+
52
+ def __len__(self):
53
+ return len(self._dataset)
54
+
55
+
56
+ class WrapIterator:
57
+ def __init__(self, iterator, solver):
58
+ self._iterator = iterator
59
+ self._solver = solver
60
+
61
+ def __next__(self):
62
+ p = next(self._iterator)
63
+ p = self._solver.to_batch(p, False)
64
+ return p
65
+
66
+
67
+ class WrapIterableDataset(IterableDataset):
68
+ def __init__(self, dataset, solver):
69
+ self._dataset = dataset
70
+ self._solver = solver
71
+
72
+ def __iter__(self):
73
+ return WrapIterator(iter(self._dataset), self._solver)
74
+
75
+
76
+ class CyclicIterator:
77
+ def __init__(self, iterable):
78
+ self._iterable = iterable
79
+ self._iterator = iter(iterable)
80
+
81
+ def __iter__(self):
82
+ return self
83
+
84
+ def __next__(self):
85
+ try:
86
+ return next(self._iterator)
87
+ except StopIteration:
88
+ self._iterator = iter(self._iterable)
89
+ return next(self._iterator)
90
+
91
+
92
+ class BufferedIterator:
93
+ def __init__(self, iterator, size, reuse):
94
+ self._iterator = iterator
95
+ self._reuse = reuse
96
+ self._queue = queue.Queue(size)
97
+ self._buffer = []
98
+ self._iter_step = 0
99
+
100
+ def __next__(self):
101
+ if not self._queue.full() or self._iter_step % self._reuse == 0:
102
+ problem = next(self._iterator)
103
+ if self._queue.full():
104
+ index = self._queue.get()
105
+ self._buffer[index] = problem
106
+ else:
107
+ index = len(self._buffer)
108
+ self._buffer.append(problem)
109
+ self._queue.put(index)
110
+ self._iter_step += 1
111
+ index = torch.randint(0, len(self._buffer), (1,)).item()
112
+ return self._buffer[index]
113
+
114
+
115
+ class Solver(object):
116
+ def __init__(self, device=None, nn_args=None):
117
+
118
+ if device is None:
119
+ self.device = get_default_device()
120
+ elif device == 'cuda':
121
+ self.device = get_default_device()
122
+ assert self.device.type == 'cuda', 'no cuda device available!'
123
+ else:
124
+ self.device = torch.device(device)
125
+
126
+ if nn_args is None:
127
+ nn_args = {}
128
+ self.nn_args = nn_args
129
+
130
+ self.agent = None
131
+
132
+ def parse_nn_args(self, problem):
133
+ parse_nn_args(problem, self.nn_args)
134
+
135
+ def new_agent(self):
136
+ return Agent(self.nn_args)
137
+
138
+ def train(self, agent_filename, train_dataset, valid_dataset, **kwargs):
139
+ if dist.is_initialized():
140
+ torch.manual_seed(torch.initial_seed() + dist.get_rank() * 20000)
141
+
142
+ train_dataset_workers = kwargs.pop('train_dataset_workers', 1)
143
+ train_dataset_buffers = kwargs.pop('train_dataset_buffers', 2)
144
+ valid_dataset_workers = kwargs.pop('valid_dataset_workers', 1)
145
+ valid_dataset_buffers = kwargs.pop('valid_dataset_buffers', 2)
146
+
147
+ train_dataset = self.wrap_dataset(train_dataset, train_dataset_workers,
148
+ train_dataset_buffers, torch.initial_seed() + 1)
149
+ valid_dataset = self.wrap_dataset(valid_dataset, valid_dataset_workers,
150
+ valid_dataset_buffers, torch.initial_seed() + 10001)
151
+
152
+ if self.device.type == 'cuda':
153
+ with torch.cuda.device(cuda_or_none(self.device)):
154
+ self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)
155
+ else:
156
+ self.do_train(agent_filename, train_dataset, valid_dataset, **kwargs)
157
+
158
+ def do_train(self, agent_filename, train_dataset, valid_dataset, reuse_buffer=0, reuse_times=1, on_policy=True,
159
+ advpow=1, batch_size=512, topk_size=1, init_lr=0.0001, sched_lr=(int(1e10),), gamma_lr=0.5,
160
+ warmup_steps=100, log_steps=-1, optim_steps=1, valid_steps=100, max_steps=int(1e10), memopt=1):
161
+
162
+ for arg in inspect.getfullargspec(self.do_train)[0][1:]:
163
+ if arg not in ('train_dataset', 'valid_dataset'):
164
+ print("train_args: {} = {}".format(arg, locals()[arg]))
165
+
166
+ if log_steps < 0:
167
+ log_steps = valid_steps
168
+
169
+ train_dataset = CyclicIterator(train_dataset)
170
+ if reuse_buffer > 0:
171
+ train_dataset = BufferedIterator(train_dataset, reuse_buffer, reuse_times)
172
+
173
+ valid_dataset = list(valid_dataset)
174
+
175
+ if dist.is_initialized() and dist.get_rank() != 0:
176
+ dist.barrier()
177
+
178
+ if agent_filename is not None and os.path.exists(agent_filename):
179
+ saved_state = torch.load(agent_filename, map_location='cpu')
180
+ self.nn_args = saved_state['nn_args']
181
+ else:
182
+ saved_state = None
183
+ self.parse_nn_args(valid_dataset[0])
184
+
185
+ step = 0
186
+ start_step = 0
187
+ self.agent = self.new_agent().train()
188
+ self.agent.to(self.device)
189
+ self.print_nn_args()
190
+
191
+ best_agent = copy.deepcopy(self.agent).eval()
192
+ min_valid_cost = math.inf
193
+
194
+ optimizer = torch.optim.Adam(self.agent.parameters(), lr=init_lr)
195
+ scheduler = MultiStepLR(optimizer, milestones=sched_lr, gamma=gamma_lr)
196
+
197
+ def do_save_state(rng_state, cuda_rng_state):
198
+ if agent_filename is not None:
199
+ save_data = {'step': step, 'rng_state': rng_state}
200
+ if cuda_rng_state is not None:
201
+ save_data['cuda_rng_state'] = cuda_rng_state
202
+ save_data['nn_args'] = self.agent.nn_args_dict()
203
+ save_data['agent_state'] = self.agent.state_dict()
204
+ save_data['best_agent_state'] = best_agent.state_dict()
205
+ save_data['optimizer_state'] = optimizer.state_dict()
206
+ save_data['scheduler_state'] = scheduler.state_dict()
207
+ torch.save(save_data, agent_filename)
208
+
209
+ def valid_sched_save(step):
210
+ if dist.is_initialized():
211
+ params = parameters_to_vector(self.agent.parameters())
212
+ params_clone = params.clone()
213
+ dist.broadcast(params_clone, 0)
214
+ assert torch.all(params == params_clone)
215
+
216
+ rng_state = torch.get_rng_state()
217
+ cuda_rng_state = None
218
+ if self.device.type == 'cuda':
219
+ cuda_rng_state = torch.cuda.get_rng_state(self.device)
220
+
221
+ print("{} - step={}, validate...".format(time.strftime("%Y-%m-%d %H:%M:%S"), step))
222
+ sys.stdout.flush()
223
+
224
+ if self.device.type == 'cuda':
225
+ torch.cuda.synchronize(self.device)
226
+ start_time = time.time()
227
+ valid_result = self.validate(valid_dataset, batch_size)
228
+ avg_cost1, avg_cost2, avg_feasible = valid_result
229
+ if self.device.type == 'cuda':
230
+ torch.cuda.synchronize(self.device)
231
+
232
+ duration = time.time() - start_time
233
+
234
+ if step > 0:
235
+ scheduler.step()
236
+
237
+ if not dist.is_initialized() or dist.get_rank() == 0:
238
+ do_save_state(rng_state, cuda_rng_state)
239
+
240
+ strftime = time.strftime("%Y-%m-%d %H:%M:%S")
241
+ print("{} - step={}, cost=[{:.6g}, {:.6g}], feasible={:.0%}".format(
242
+ strftime, step, avg_cost1, avg_cost2, avg_feasible))
243
+ print("{} - step={}, min_valid_cost={:.6g}, time={:.3f}s".format(
244
+ strftime, step, min(min_valid_cost, avg_cost2), duration))
245
+ print("---------------------------------------------------------------------------------------")
246
+ sys.stdout.flush()
247
+ return avg_cost2
248
+
249
+ if saved_state is not None:
250
+ start_step = saved_state['step']
251
+
252
+ if not dist.is_initialized() or dist.get_rank() == 0:
253
+ torch.set_rng_state(saved_state['rng_state'])
254
+ if torch.cuda.is_available():
255
+ torch.cuda.set_rng_state(saved_state['cuda_rng_state'], self.device)
256
+
257
+ best_agent.load_state_dict(saved_state['best_agent_state'])
258
+ self.agent.load_state_dict(saved_state['best_agent_state'])
259
+
260
+ # if 'agent_state' in saved_state:
261
+ # self.agent.load_state_dict(saved_state['agent_state'])
262
+ # else:
263
+ # self.agent.load_state_dict(saved_state['best_agent_state'])
264
+
265
+ if 'optimizer_state' in saved_state:
266
+ optimizer.load_state_dict(saved_state['optimizer_state'])
267
+ if 'scheduler_state' in saved_state:
268
+ scheduler.load_state_dict(saved_state['scheduler_state'])
269
+ else:
270
+ if dist.is_initialized() and dist.get_rank() == 0:
271
+ rng_state = torch.get_rng_state()
272
+ cuda_rng_state = None
273
+ if self.device.type == 'cuda':
274
+ cuda_rng_state = torch.cuda.get_rng_state(self.device)
275
+ do_save_state(rng_state, cuda_rng_state)
276
+
277
+ if dist.is_initialized() and dist.get_rank() == 0:
278
+ dist.barrier()
279
+
280
+ for step in range(start_step, max_steps):
281
+ if step % valid_steps == 0:
282
+ valid_cost = valid_sched_save(step)
283
+ if valid_cost < min_valid_cost:
284
+ best_agent.load_state_dict(self.agent.state_dict())
285
+ min_valid_cost = valid_cost
286
+
287
+ start_time = time.time()
288
+
289
+ # problem
290
+ with torch.no_grad():
291
+ problem = next(train_dataset)
292
+ if step < warmup_steps:
293
+ batch_size_now = batch_size // 2
294
+ else:
295
+ batch_size_now = batch_size
296
+ problem = self.to_device(problem)
297
+
298
+ if not on_policy:
299
+ data_agent = best_agent
300
+ else:
301
+ data_agent = self.agent
302
+
303
+ data_agent.eval()
304
+
305
+ # solution
306
+ if topk_size > 1:
307
+ with torch.no_grad():
308
+ batch_size_topk = batch_size_now * topk_size
309
+ env, logp = data_agent(problem, batch_size_topk)
310
+ cost = env.cost().sum(1).float()
311
+ solution = env.worker_task_sequence()
312
+
313
+ NP = problem.batch_size
314
+ NK = batch_size_now // NP
315
+ NS = solution.size(1)
316
+
317
+ cost = cost.view(NP, -1)
318
+ cost, kidx = cost.topk(NK, 1, False, False)
319
+ cost = cost.view(-1)
320
+ kidx = kidx[:, :, None, None].expand(-1, -1, NS, 3)
321
+ solution = solution.view(NP, -1, NS, 3)
322
+ solution = solution.gather(1, kidx).view(-1, NS, 3)
323
+
324
+ elif not on_policy:
325
+ with torch.no_grad():
326
+ env, logp = data_agent(problem, batch_size_now)
327
+ cost = env.cost().sum(1).float()
328
+ solution = env.worker_task_sequence()
329
+ else:
330
+ self.agent.train()
331
+ env, logp = self.agent(problem, batch_size_now, memopt=memopt)
332
+ cost = env.cost().sum(1).float()
333
+ solution = env.worker_task_sequence()
334
+
335
+ self.agent.train()
336
+
337
+ # advantage
338
+ with torch.no_grad():
339
+ NP = problem.batch_size
340
+ if topk_size > 1:
341
+ baseline = cost.view(NP, -1).max(1)[0]
342
+ else:
343
+ baseline = cost.view(NP, -1).mean(1)
344
+ baseline = repeat(baseline, cost.size(0) // NP)
345
+ adv = (cost - baseline)[:, None]
346
+ adv_norm = adv.norm()
347
+ if adv_norm > 0:
348
+ adv = adv / adv.norm() * adv.size(0)
349
+ adv = adv.sign() * adv.abs().pow(advpow)
350
+
351
+ # backward
352
+ if topk_size > 1 or not on_policy:
353
+ env, logp = self.agent(problem, batch_size_now, solution=solution, memopt=memopt)
354
+
355
+ loss = adv * logp
356
+ loss = loss.mean()
357
+ loss.backward()
358
+
359
+ if step % optim_steps == 0:
360
+ if dist.is_initialized():
361
+ params = filter(lambda a: a.grad is not None, self.agent.parameters())
362
+ grad_list = [param.grad for param in params]
363
+ grad_vector = parameters_to_vector(grad_list)
364
+ dist.all_reduce(grad_vector, op=dist.ReduceOp.SUM)
365
+ vector_to_parameters(grad_vector, grad_list)
366
+
367
+ grad_norm = clip_grad_norm_(self.agent.parameters(), 1)
368
+ optimizer.step()
369
+ optimizer.zero_grad()
370
+
371
+ if step % log_steps == 0:
372
+ strftime = time.strftime("%Y-%m-%d %H:%M:%S")
373
+ lr = optimizer.param_groups[0]['lr']
374
+ duration = time.time() - start_time
375
+ with torch.no_grad():
376
+ p = logp.to(torch.float64).sum(1).exp().mean()
377
+ print("{} - step={}, grad={:.6g}, lr={:.6g}, p={:.6g}".format(
378
+ strftime, step, grad_norm, lr, p))
379
+
380
+ print("{} - step={}, cost={:.6g}, time={:.3f}s".format(strftime, step, cost.mean(), duration))
381
+ print("---------------------------------------------------------------------------------------")
382
+ sys.stdout.flush()
383
+
384
+ valid_sched_save(step)
385
+
386
+ def solve(self, problem, greedy=False, batch_size=512):
387
+ if self.device.type == 'cuda':
388
+ with torch.cuda.device(cuda_or_none(self.device)):
389
+ return self.do_solve(problem, greedy, batch_size)
390
+ else:
391
+ return self.do_solve(problem, greedy, batch_size)
392
+
393
+ def do_solve(self, problem, greedy, batch_size):
394
+ isbatch = problem.isbatch
395
+ problem = self.to_batch(problem)
396
+ problem = self.to_device(problem)
397
+
398
+ if self.agent is None:
399
+ self.parse_nn_args(problem)
400
+ self.agent = self.new_agent()
401
+ self.agent.to(self.device)
402
+
403
+ self.agent.eval()
404
+
405
+ with torch.no_grad():
406
+ env, prob = self.agent(problem, batch_size, greedy, problem.solution)
407
+
408
+ NP = problem.batch_size
409
+ NR = prob.size(0) // NP
410
+
411
+ prob = prob.view(NP, NR, -1)
412
+ cost = env.cost().sum(1).view(NP, NR)
413
+ feasible = env.feasible().view(NP, NR)
414
+ size = list(env.worker_task_sequence().size())
415
+ size = [NP, NR] + size[1:]
416
+ worker_task_sequence = env.worker_task_sequence().view(size)
417
+
418
+ p_index = torch.arange(NP)
419
+ base_cost = cost.max() + 1
420
+ cost[~feasible] += base_cost
421
+ cost, s_index = cost.min(1)
422
+ feasible = feasible[p_index, s_index]
423
+ cost[~feasible] -= base_cost
424
+ probability = prob[p_index, s_index].exp()
425
+ worker_task_sequence = worker_task_sequence[p_index, s_index]
426
+
427
+ if isbatch:
428
+ solution = Solution(cost)
429
+ solution.feasible = feasible
430
+ solution.probability = probability
431
+ solution.worker_task_sequence = worker_task_sequence
432
+ else:
433
+ solution = Solution(cost.item())
434
+ solution.feasible = feasible.item()
435
+ solution.probability = probability.squeeze(0)
436
+ solution.worker_task_sequence = worker_task_sequence.squeeze(0)
437
+
438
+ return solution
439
+
440
+ def load_agent(self, filename, strict=True):
441
+ if self.device.type == 'cuda':
442
+ with torch.cuda.device(cuda_or_none(self.device)):
443
+ self.do_load_agent(filename, strict)
444
+ else:
445
+ self.do_load_agent(filename, strict)
446
+
447
+ def do_load_agent(self, filename, strict=True):
448
+ saved_state = torch.load(filename, map_location='cpu')
449
+ self.nn_args = saved_state['nn_args']
450
+
451
+ self.agent = self.new_agent()
452
+ self.agent.to(self.device)
453
+ self.agent.load_state_dict(saved_state['best_agent_state'], strict)
454
+ self.print_nn_args()
455
+
456
+ def to_batch(self, problem, pin_memory=True):
457
+ assert not hasattr(problem, 'feats')
458
+
459
+ NW = 1
460
+ NT = 1
461
+ NP = 1
462
+ isbatch = problem.isbatch
463
+ for k, v in problem.__dict__.items():
464
+ if k.startswith("worker_"):
465
+ NW = len(v[0]) if isbatch else len(v)
466
+ elif k.startswith("task_"):
467
+ NP = len(v) if isbatch else 1
468
+ NT = len(v[0]) if isbatch else len(v)
469
+ NWW = NW * 2
470
+
471
+ new_problem = Problem(True)
472
+ new_problem.feats = {}
473
+ new_problem.device = 'cpu'
474
+
475
+ new_problem.batch_size = NP
476
+ new_problem.worker_num = NW
477
+ new_problem.task_num = NT
478
+
479
+ new_problem.features = problem.features
480
+
481
+ if type(self) == Solver:
482
+ new_problem.variables = problem.variables
483
+ new_problem.constraint = problem.constraint
484
+ new_problem.objective = problem.objective
485
+ new_problem.environment = problem.environment
486
+ else:
487
+ new_problem.variables = []
488
+ new_problem.constraints = problem.constraints
489
+ new_problem.oa_estimate_tasks = problem.oa_estimate_tasks
490
+ new_problem.oa_multiple_steps = problem.oa_multiple_steps
491
+
492
+ edge_size_list = ((NWW + NT, NWW + NT), (NW + NT, NW + NT))
493
+
494
+ def check_size(f, k, v):
495
+ assert f, "size error, feature: {}, size: {}".format(k, tuple(v.size()))
496
+
497
+ for k, v in problem.__dict__.items():
498
+ if k == 'solution' and v is not None:
499
+ v = to_tensor(k, v, isbatch)
500
+ check_size(v.dim() == 3 and v.size(-1) == 3, k, v)
501
+ elif k.startswith("worker_task_"):
502
+ v = to_tensor(k, v, isbatch)
503
+ check_size(v.dim() in (3, 4) and v.size()[1:3] == (NW, NT), k, v)
504
+ elif k.startswith("worker_"):
505
+ v = to_tensor(k, v, isbatch)
506
+ check_size(v.dim() in (2, 3) and v.size(1) == NW, k, v)
507
+ elif k.startswith("task_"):
508
+ v = to_tensor(k, v, isbatch)
509
+ check_size(v.dim() in (2, 3) and v.size(1) == NT, k, v)
510
+ elif k.endswith("_matrix"):
511
+ v = to_tensor(k, v, isbatch)
512
+ check_size(v.dim() in (3, 4) and v.size()[1:3] in edge_size_list, k, v)
513
+ if v.size()[1:3] == (NW + NT, NW + NT):
514
+ worker_index = torch.arange(NW)
515
+ task_index = torch.arange(NT) + NW
516
+ index = torch.cat([worker_index, worker_index, task_index])
517
+ index1 = index[:, None]
518
+ index2 = index[None, :]
519
+ v = v[:, index1, index2]
520
+ elif isinstance(v, np.ndarray):
521
+ v = torch.tensor(v)
522
+
523
+ if isinstance(v, torch.Tensor):
524
+ new_problem.feats[k] = v
525
+
526
+ if pin_memory and self.device.type == 'cuda':
527
+ new_problem.pin_memory()
528
+
529
+ return new_problem
530
+
531
+ def to_device(self, problem):
532
+
533
+ assert hasattr(problem, 'feats')
534
+
535
+ new_problem = copy.copy(problem)
536
+ new_problem.device = self.device
537
+ new_problem.feats = {}
538
+
539
+ non_blocking = self.device.type == 'cuda'
540
+ for k, v in problem.feats.items():
541
+ v = v.to(self.device, non_blocking=non_blocking)
542
+ new_problem.feats[k] = v
543
+
544
+ return new_problem
545
+
546
+ def validate(self, problem_list, batch_size):
547
+ self.agent.eval()
548
+ with torch.no_grad():
549
+ valid_result = self.do_validate(problem_list, batch_size)
550
+
551
+ self.agent.train()
552
+ return valid_result
553
+
554
+ def do_validate(self, problem_list, batch_size):
555
+ total_cost1 = 0
556
+ total_cost2 = 0
557
+ total_feasible = 0
558
+ total_problem = 0
559
+ start_time = time.time()
560
+ for problem in problem_list:
561
+ problem = self.to_device(problem)
562
+ env, _, = self.agent(problem, batch_size)
563
+
564
+ NP = problem.batch_size
565
+ cost = env.cost().sum(1).view(NP, -1)
566
+ cost1, _ = cost.min(1)
567
+ cost2 = cost.mean(1)
568
+ feasible = env.feasible().view(NP, -1)
569
+ feasible = torch.any(feasible, 1)
570
+
571
+ total_cost1 += cost1.sum().item()
572
+ total_cost2 += cost2.sum().item()
573
+ total_feasible += feasible.int().sum().item()
574
+ total_problem += NP
575
+
576
+ if dist.is_initialized():
577
+ data = [total_cost1, total_cost2, total_feasible, total_problem]
578
+ data = torch.tensor(data, device=self.device)
579
+ dist.all_reduce(data, op=dist.ReduceOp.SUM)
580
+ total_cost1, total_cost2, total_feasible, total_problem = data.tolist()
581
+
582
+ avg_cost1 = total_cost1 / total_problem
583
+ avg_cost2 = total_cost2 / total_problem
584
+ avg_feasible = total_feasible / total_problem
585
+
586
+ return avg_cost1, avg_cost2, avg_feasible
587
+
588
+ def wrap_dataset(self, dataset, workers, buffers, seed):
589
+ if isinstance(dataset, IterableDataset):
590
+ dataset = WrapIterableDataset(dataset, self)
591
+ dataset = DataLoader(dataset, batch_size=None, pin_memory=True,
592
+ num_workers=workers, prefetch_factor=buffers,
593
+ worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id))
594
+ else:
595
+ if self.device.type == 'cuda':
596
+ with torch.cuda.device(cuda_or_none(self.device)):
597
+ dataset = WrapDataset(dataset, self)
598
+ dataset = DataLoader(dataset, batch_size=None, pin_memory=True, shuffle=True)
599
+ else:
600
+ dataset = WrapDataset(dataset, self)
601
+ dataset = DataLoader(dataset, batch_size=None, pin_memory=True, shuffle=True)
602
+
603
+ return dataset
604
+
605
+ def print_nn_args(self):
606
+ for key, value in self.nn_args.items():
607
+ if type(value) in [int, float, str, bool]:
608
+ print("nn_args: {} = {}".format(key, value))
609
+ sys.stdout.flush()
610
+
611
+
612
+ def to_tensor(key, value, isbatch):
613
+ if isinstance(value, torch.Tensor):
614
+ tensor = value.to('cpu')
615
+ else:
616
+ tensor = torch.tensor(value, device='cpu')
617
+
618
+ if not isbatch:
619
+ tensor = tensor[None]
620
+
621
+ return tensor
622
+
623
+
624
+ def cuda_or_none(device):
625
+ return device if device.type == 'cuda' else None
greedrl/utils.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import torch
4
+
5
+ act_dict = {}
6
+ act_dict['none'] = lambda x: x
7
+ act_dict['relu'] = torch.relu
8
+ act_dict['tanh'] = torch.tanh
9
+ act_dict['sigmoid'] = torch.sigmoid
10
+
11
+
12
+ def get_act(act):
13
+ return act_dict[act]
14
+
15
+
16
+ def to_list(var):
17
+ if isinstance(var, dict):
18
+ return {k: to_list(v) for k, v in var.items()}
19
+ elif isinstance(var, list):
20
+ return [to_list(v) for v in var]
21
+ elif isinstance(var, tuple):
22
+ return (to_list(v) for v in var)
23
+ elif isinstance(var, torch.Tensor):
24
+ return var.tolist()
25
+ else:
26
+ return var
27
+
28
+
29
+ def repeat(tensor, size, dim=0):
30
+ return tensor.repeat_interleave(size, dim)
31
+
32
+
33
+ def get_default_device():
34
+ if not torch.cuda.is_available():
35
+ return torch.device("cpu")
36
+
37
+ cmd = 'nvidia-smi -q -d Memory | grep -A4 GPU | grep Free'
38
+ with os.popen(cmd) as result:
39
+ max_free_mem = 0
40
+ max_cuda_index = -1
41
+ for i, line in enumerate(result):
42
+ free_mem = int(line.strip().split()[2])
43
+ if free_mem > max_free_mem:
44
+ max_free_mem = free_mem
45
+ max_cuda_index = i
46
+
47
+ return torch.device("cuda:{}".format(max_cuda_index))
48
+
49
+
50
+ def cumem_stats(device, msg):
51
+ torch.cuda.empty_cache()
52
+ print("{}, device:{}, memory_allocated: {:.3f}G".format(msg, device,
53
+ torch.cuda.memory_allocated(device) / (1024 * 1024 * 1024)))
54
+
55
+
56
+ cutime_stats_time = None
57
+
58
+
59
+ def cutime_stats(device, msg=''):
60
+ global cutime_stats_time
61
+ torch.cuda.synchronize(device)
62
+ if cutime_stats_time is not None:
63
+ print("{} time: {:.6f}s".format(msg, time.time() - cutime_stats_time))
64
+
65
+ cutime_stats_time = time.time()
greedrl/variable.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import functools
3
+
4
+ from .utils import repeat
5
+
6
+
7
+ class VarMeta(object):
8
+ def __init__(self, clazz, **kwargs):
9
+ self.clazz = clazz
10
+ self._kwargs = kwargs
11
+ for k, v in kwargs.items():
12
+ setattr(self, k, v)
13
+
14
+ def __call__(self, problem, batch_size, sample_num):
15
+ kwargs = self._kwargs.copy()
16
+ kwargs['problem'] = problem.feats
17
+ kwargs['batch_size'] = batch_size
18
+ kwargs['sample_num'] = sample_num
19
+ kwargs['worker_num'] = problem.worker_num
20
+ kwargs['task_num'] = problem.task_num
21
+ return self.clazz(**kwargs)
22
+
23
+
24
+ def attribute_variable(name, attribute=None):
25
+ return VarMeta(AttributeVariable, name=name, attribute=attribute)
26
+
27
+
28
+ class AttributeVariable:
29
+ def __init__(self, name, attribute, problem, batch_size, sample_num, worker_num, task_num):
30
+ if attribute is None:
31
+ attribute = name;
32
+
33
+ self.name = name
34
+ self.value = problem[attribute]
35
+
36
+
37
+ def feature_variable(name, feature=None):
38
+ return VarMeta(FeatureVariable, name=name, feature=feature)
39
+
40
+
41
+ class FeatureVariable:
42
+ def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
43
+ if feature is None:
44
+ feature = name
45
+
46
+ assert feature == 'id' or feature.startswith("worker_") or feature.startswith("task_")
47
+
48
+ self.name = name
49
+ self.feature = problem[feature]
50
+ self.value = repeat(self.feature, sample_num)
51
+
52
+
53
+ def task_variable(name, feature=None):
54
+ return VarMeta(TaskVariable, name=name, feature=feature)
55
+
56
+
57
+ class TaskVariable:
58
+ def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
59
+ if feature is None:
60
+ feature = name
61
+
62
+ assert feature.startswith("task_")
63
+
64
+ self.name = name
65
+ self.feature = problem[feature]
66
+
67
+ size = list(self.feature.size())
68
+ size[0] = batch_size
69
+ del size[1]
70
+ self.value = self.feature.new_zeros(size)
71
+
72
+ def step_task(self, b_index, p_index, t_index):
73
+ self.value[b_index] = self.feature[p_index, t_index]
74
+
75
+
76
+ def worker_variable(name, feature=None):
77
+ return VarMeta(WorkerVariable, name=name, feature=feature)
78
+
79
+
80
+ class WorkerVariable:
81
+ def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
82
+ if feature is None:
83
+ feature = name
84
+
85
+ assert feature.startswith("worker_")
86
+
87
+ self.name = name
88
+ self.feature = problem[feature]
89
+
90
+ size = list(self.feature.size())
91
+ size[0] = batch_size
92
+ del size[1]
93
+ self.value = self.feature.new_zeros(size)
94
+
95
+ def step_worker_start(self, b_index, p_index, w_index):
96
+ self.value[b_index] = self.feature[p_index, w_index]
97
+
98
+
99
+ def worker_task_variable(name, feature=None):
100
+ return VarMeta(WorkerTaskVariable, name=name, feature=feature)
101
+
102
+
103
+ class WorkerTaskVariable:
104
+ def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
105
+ if feature is None:
106
+ feature = name
107
+
108
+ assert feature.startswith("worker_task_")
109
+
110
+ self.name = name
111
+ self.feature = problem[feature]
112
+
113
+ size = list(self.feature.size())
114
+ size[0] = batch_size
115
+
116
+ del size[1]
117
+ self._feature = self.feature.new_zeros(size)
118
+
119
+ del size[2]
120
+ self.value = self.feature.new_zeros(size)
121
+
122
+ def step_worker_start(self, b_index, p_index, w_index):
123
+ self._feature[b_index] = self.feature[p_index, w_index]
124
+
125
+ def step_task(self, b_index, p_index, t_index):
126
+ self.value[b_index] = self._feature[b_index, t_index]
127
+
128
+
129
+ def worker_task_group(name, feature=None):
130
+ return VarMeta(WorkerTaskGroup, name=name, feature=feature)
131
+
132
+
133
+ class WorkerTaskGroup:
134
+ def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
135
+ if feature is None:
136
+ feature = name
137
+
138
+ assert feature.startswith("task_")
139
+
140
+ self.name = name
141
+ self.feature = problem[feature].long()
142
+
143
+ NG = self.feature.max() + 1
144
+ assert torch.all(self.feature >= 0)
145
+
146
+ self.value = self.feature.new_zeros(batch_size, NG)
147
+
148
+ def step_worker_start(self, b_index, p_index, w_index):
149
+ self.value[b_index] = 0
150
+
151
+ def step_task(self, b_index, p_index, t_index):
152
+ group = self.feature[p_index, t_index]
153
+ self.value[b_index, group] += 1;
154
+
155
+
156
+ def worker_task_item(name, item_id, item_num):
157
+ return VarMeta(WorkerTaskItem, name=name, item_id=item_id, item_num=item_num)
158
+
159
+
160
+ class WorkerTaskItem:
161
+ def __init__(self, name, item_id, item_num, problem, batch_size, sample_num, worker_num, task_num):
162
+ assert item_id.startswith('task_')
163
+ assert item_num.startswith('task_')
164
+
165
+ self.name = name
166
+ self.item_id = repeat(problem[item_id], sample_num).long()
167
+ self.item_num = repeat(problem[item_num], sample_num)
168
+
169
+ assert torch.all(self.item_id >= 0)
170
+
171
+ size = [0, 0]
172
+ size[0] = self.item_id.size(0)
173
+ size[1] = self.item_id.max() + 1
174
+ self.value = self.item_num.new_zeros(size)
175
+
176
+ def step_worker_start(self, b_index, p_index, w_index):
177
+ self.value[b_index] = 0
178
+
179
+ def step_task(self, b_index, p_index, t_index):
180
+ item_id = self.item_id[b_index, t_index]
181
+ item_num = self.item_num[b_index, t_index]
182
+ self.value[b_index[:, None], item_id] += item_num
183
+
184
+ def make_feat(self):
185
+ NT = self.item_id.size(1)
186
+ v = self.value[:, None, :]
187
+ v = v.expand(-1, NT, -1)
188
+
189
+ v = v.gather(2, self.item_id).clamp(0, 1)
190
+ v = self.item_num.clamp(0, 1) - v
191
+ return v.clamp(0, 1).sum(2)
192
+
193
+
194
+ def task_demand_now(name, feature=None, only_this=False):
195
+ return VarMeta(TaskDemandNow, name=name, feature=feature, only_this=only_this)
196
+
197
+
198
+ class TaskDemandNow:
199
+ def __init__(self, name, feature, only_this, problem, batch_size, sample_num, worker_num, task_num):
200
+
201
+ if feature is None:
202
+ feature = name
203
+
204
+ assert feature.startswith("task_")
205
+
206
+ self.name = name
207
+ self.only_this = only_this
208
+ self._value = repeat(problem[feature], sample_num)
209
+
210
+ assert self._value.dtype in \
211
+ (torch.int8, torch.int16, torch.int32, torch.int64)
212
+ assert torch.all(self._value >= 0)
213
+
214
+ if only_this:
215
+ size = self._value.size(0)
216
+ self.value = self._value.new_zeros(size)
217
+ else:
218
+ self.value = self._value
219
+
220
+ def step_task(self, b_index, p_index, t_index, done):
221
+ if done is not None:
222
+ self._value[b_index, t_index] -= done
223
+
224
+ if self.only_this:
225
+ self.value[b_index] = self._value[b_index, t_index]
226
+ else:
227
+ self.value = self._value
228
+
229
+
230
+ def worker_count_now(name, feature=None):
231
+ return VarMeta(WorkerCountNow, name=name, feature=feature)
232
+
233
+
234
+ class WorkerCountNow:
235
+ def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
236
+ if feature is None:
237
+ feature = name
238
+
239
+ assert feature.startswith("worker_")
240
+
241
+ self.name = name
242
+ self.value = repeat(problem[feature], sample_num)
243
+
244
+ assert self.value.dtype in \
245
+ (torch.int8, torch.int16, torch.int32, torch.int64)
246
+ assert torch.all(self.value >= 0)
247
+
248
+ def step_worker_start(self, b_index, p_index, w_index):
249
+ self.value[b_index, w_index] -= 1
250
+
251
+
252
+ def edge_variable(name, feature, last_to_this=False,
253
+ this_to_task=False, task_to_end=False, last_to_loop=False):
254
+ return VarMeta(EdgeVariable, name=name, feature=feature,
255
+ last_to_this=last_to_this, this_to_task=this_to_task, task_to_end=task_to_end,
256
+ last_to_loop=last_to_loop)
257
+
258
+
259
+ class EdgeVariable:
260
+ def __init__(self, name, feature, last_to_this, this_to_task, task_to_end, last_to_loop,
261
+ problem, batch_size, sample_num, worker_num, task_num):
262
+
263
+ assert feature.endswith("_matrix")
264
+
265
+ flags = [last_to_this, this_to_task, task_to_end, last_to_loop]
266
+ assert flags.count(True) == 1 and flags.count(False) == 3
267
+
268
+ if feature is None:
269
+ feature = name
270
+
271
+ self.name = name
272
+ self.last_to_this = last_to_this
273
+ self.this_to_task = this_to_task
274
+ self.task_to_end = task_to_end
275
+ self.last_to_loop = last_to_loop
276
+
277
+ self.worker_num = worker_num
278
+ self.task_num = task_num
279
+
280
+ self.feature = problem[feature]
281
+
282
+ size = list(self.feature.size())
283
+ size[0] = batch_size
284
+ del size[1:3]
285
+
286
+ if self.this_to_task or self.task_to_end:
287
+ size.insert(1, task_num)
288
+ self.value = self.feature.new_zeros(size)
289
+ else:
290
+ self.value = self.feature.new_zeros(size)
291
+
292
+ self.end_index = self.feature.new_zeros(size[0], dtype=torch.int64)
293
+ self.loop_index = self.feature.new_zeros(size[0], dtype=torch.int64)
294
+ self.last_index = self.feature.new_zeros(size[0], dtype=torch.int64)
295
+ self.task_index = (torch.arange(task_num) + worker_num * 2)[None, :]
296
+
297
+ def step_worker_start(self, b_index, p_index, w_index):
298
+ if self.last_to_this:
299
+ self.value[b_index] = 0
300
+ self.last_index[b_index] = w_index
301
+ elif self.this_to_task:
302
+ self.do_this_to_task(b_index, p_index, w_index)
303
+ elif self.task_to_end:
304
+ self.end_index[b_index] = w_index + self.worker_num
305
+ self.do_task_to_end(b_index, p_index)
306
+ elif self.last_to_loop:
307
+ self.value[b_index] = 0
308
+ self.last_index[b_index] = w_index
309
+
310
+ def step_worker_end(self, b_index, p_index, w_index):
311
+ this_index = w_index + self.worker_num
312
+ if self.last_to_this:
313
+ self.do_last_to_this(b_index, p_index, this_index)
314
+ elif self.this_to_task:
315
+ self.do_this_to_task(b_index, p_index, this_index)
316
+ elif self.task_to_end:
317
+ pass
318
+ elif self.last_to_loop:
319
+ self.do_last_to_loop(b_index, p_index)
320
+
321
+ def step_task(self, b_index, p_index, t_index):
322
+ this_index = t_index + self.worker_num * 2
323
+ if self.last_to_this:
324
+ self.do_last_to_this(b_index, p_index, this_index)
325
+ self.last_index[b_index] = this_index
326
+ elif self.this_to_task:
327
+ self.do_this_to_task(b_index, p_index, this_index)
328
+ elif self.task_to_end:
329
+ pass
330
+ elif self.last_to_loop:
331
+ last_index = self.last_index[b_index]
332
+ loop_index = self.loop_index[b_index]
333
+ self.loop_index[b_index] = torch.where(last_index < self.worker_num, this_index, loop_index)
334
+ self.last_index[b_index] = this_index
335
+
336
+ def do_last_to_this(self, b_index, p_index, this_index):
337
+ last_index = self.last_index[b_index]
338
+ self.value[b_index] = self.feature[p_index, last_index, this_index]
339
+
340
+ def do_this_to_task(self, b_index, p_index, this_index):
341
+ p_index2 = p_index[:, None]
342
+ this_index2 = this_index[:, None]
343
+ task_index2 = self.task_index
344
+ self.value[b_index] = self.feature[p_index2, this_index2, task_index2]
345
+
346
+ def do_task_to_end(self, b_index, p_index):
347
+ p_index2 = p_index[:, None]
348
+ task_index2 = self.task_index
349
+ end_index = self.end_index[b_index]
350
+ end_index2 = end_index[:, None]
351
+ self.value[b_index] = self.feature[p_index2, task_index2, end_index2]
352
+
353
+ def do_last_to_loop(self, b_index, p_index):
354
+ loop_index = self.loop_index[b_index]
355
+ last_index = self.last_index[b_index]
356
+ self.value[b_index] = self.feature[p_index, last_index, loop_index]
357
+
358
+ def make_feat(self):
359
+ assert self.this_to_task or self.task_to_end, \
360
+ "one of [this_to_task, task_to_end] must be true"
361
+ return self.value.clone()
362
+
363
+
364
+ def worker_used_resource(name, edge_require=None, task_require=None, task_ready=None, worker_ready=None, task_due=None):
365
+ return VarMeta(WorkerUsedResource, name=name, edge_require=edge_require, task_require=task_require,
366
+ task_ready=task_ready, worker_ready=worker_ready, task_due=task_due)
367
+
368
+
369
+ class WorkerUsedResource:
370
+ def __init__(self, name, edge_require, task_require, task_ready, worker_ready, task_due,
371
+ problem, batch_size, sample_num, worker_num, task_num):
372
+
373
+ assert edge_require is None or edge_require.endswith("_matrix"), "unsupported edge: {}".format(edge_require)
374
+ assert task_require is None or task_require.startswith("task_"), "unsupported task_require: {}".format(
375
+ task_require)
376
+ assert task_ready is None or task_ready.startswith("task_"), "unsupported task_service: {}".format(task_ready)
377
+ assert worker_ready is None or worker_ready.startswith("worker_") and not worker_ready.startswith(
378
+ "worker_task_")
379
+ assert task_due is None or task_due.startswith("task_"), "unsupported task_due: {}".format(task_due)
380
+
381
+ self.name = name
382
+
383
+ self.worker_num = worker_num
384
+ self.task_num = task_num
385
+
386
+ if edge_require is None:
387
+ self.edge_require = None
388
+ else:
389
+ self.edge_require = problem[edge_require]
390
+ self.last_index = self.edge_require.new_zeros(batch_size, dtype=torch.int64)
391
+
392
+ if task_require is None:
393
+ self.task_require = None
394
+ else:
395
+ self.task_require = problem[task_require]
396
+ self.task_require2 = repeat(self.task_require, sample_num)
397
+
398
+ if task_ready is None:
399
+ self.task_ready = None
400
+ else:
401
+ self.task_ready = problem[task_ready]
402
+
403
+ if worker_ready is None:
404
+ self.worker_ready = None
405
+ else:
406
+ self.worker_ready = problem[worker_ready]
407
+
408
+ if task_due is None:
409
+ self.task_due = None
410
+ else:
411
+ self.task_due = problem[task_due]
412
+
413
+ tenors = [self.edge_require, self.task_require, self.task_ready, self.worker_ready]
414
+ tenors = list(filter(lambda x: x is not None, tenors))
415
+ assert tenors, "at least one of edge_require, task_require, task_ready, worker_ready is required!"
416
+
417
+ size = list(tenors[0].size())
418
+ size[0] = batch_size
419
+ if self.edge_require is None:
420
+ del size[1]
421
+ else:
422
+ del size[1:3]
423
+
424
+ self.value = tenors[0].new_zeros(size)
425
+
426
+ def step_worker_start(self, b_index, p_index, w_index):
427
+ if self.worker_ready is None:
428
+ self.value[b_index] = 0
429
+ else:
430
+ self.value[b_index] = self.worker_ready[p_index, w_index]
431
+
432
+ if self.edge_require is not None:
433
+ self.last_index[b_index] = w_index
434
+
435
+ def step_worker_end(self, b_index, p_index, w_index):
436
+ if self.edge_require is not None:
437
+ last_index = self.last_index[b_index]
438
+ this_index = w_index + self.worker_num
439
+ self.value[b_index] += self.edge_require[p_index, last_index, this_index]
440
+ self.last_index[b_index] = this_index;
441
+
442
+ def step_task(self, b_index, p_index, t_index, done):
443
+ if done is None:
444
+ if self.edge_require is not None:
445
+ last_index = self.last_index[b_index]
446
+ this_index = t_index + (self.worker_num * 2)
447
+ self.value[b_index] += self.edge_require[p_index, last_index, this_index]
448
+ self.last_index[b_index] = this_index
449
+
450
+ if self.task_ready is not None:
451
+ self.value[b_index] = torch.max(self.value[b_index], self.task_ready[p_index, t_index])
452
+
453
+ else:
454
+ if self.task_require is not None:
455
+ if self.value.dim() == 2:
456
+ done = done[:, None]
457
+ self.value[b_index] += self.task_require[p_index, t_index] * done
458
+
459
+ def make_feat(self):
460
+ assert self.value.dim() == 2, \
461
+ "value's dim must be 2, actual: {}".format(self.value.dim())
462
+ assert self.task_require is not None, "task_require is required"
463
+
464
+ v = self.value[:, None, :] + self.task_require2
465
+ return v.clamp(0, 1).sum(2, dtype=v.dtype)
466
+
467
+
468
+ def worker_task_sequence(name):
469
+ return VarMeta(WorkerTaskSequence, name=name)
470
+
471
+
472
+ class WorkerTaskSequence:
473
+ def __init__(self, name, problem, batch_size, sample_num, worker_num, task_num):
474
+ self.name = name
475
+ self.value = None
476
+
477
+ def step_finish(self, worker_task_seq):
478
+ self.value = worker_task_seq
images/GREEDRL-Framwork.png ADDED
images/GREEDRL-Framwork_en.png ADDED
images/GREEDRL-Logo-Original-640.png ADDED
images/GREEDRL-Network.png ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==1.12.1+cu113
2
+ torchvision==0.13.1+cu113
3
+ torchaudio==0.12.1
4
+ numpy==1.24.2
5
+ Cython==0.29.34
6
+ ortools==9.6.2534
7
+
setup.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import sys
4
+ import time
5
+ import subprocess
6
+
7
+ from distutils import sysconfig
8
+ from setuptools import setup, Extension, find_packages
9
+ from Cython.Build import build_ext, cythonize
10
+
11
+
12
+ class CMakeExtension(Extension):
13
+ def __init__(self, name, sourcedir=''):
14
+ Extension.__init__(self, name, sources=[])
15
+ self.sourcedir = os.path.abspath(sourcedir)
16
+
17
+
18
+ class CMakeBuild(build_ext):
19
+ def build_extension(self, ext):
20
+ if isinstance(ext, CMakeExtension):
21
+ extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
22
+ if not extdir.endswith(os.path.sep):
23
+ extdir += os.path.sep
24
+
25
+ if not os.path.exists(self.build_temp):
26
+ os.makedirs(self.build_temp)
27
+
28
+ subprocess.check_call(['cmake', ext.sourcedir, '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir],
29
+ cwd=self.build_temp)
30
+ subprocess.check_call(['cmake', '--build', '.', '--', 'VERBOSE=1', '-j8'], cwd=self.build_temp)
31
+ else:
32
+ super().build_extension(ext)
33
+
34
+
35
+ ext_modules = [CMakeExtension('greedrl_c')]
36
+
37
+ setup(
38
+ name='greedrl',
39
+ version='1.0.0',
40
+ packages=find_packages(),
41
+ ext_modules=ext_modules,
42
+ cmdclass={'build_ext': CMakeBuild},
43
+ install_requires=["torch==1.12.1+cu113"],
44
+ )
test/all_test.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from solver_test import *
2
+ from function_test import *
3
+
4
+ if __name__ == '__main__':
5
+
6
+ unittest.main()
7
+
test/basetest.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import unittest
3
+
4
+
5
+ class TestCase(unittest.TestCase):
6
+ def tearDown(self):
7
+ torch.cuda.empty_cache()
8
+
test/function_test.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ import torch
4
+ import unittest
5
+ import basetest
6
+
7
+ from greedrl import Solver
8
+ from greedrl.function import *
9
+
10
+ device = Solver().device
11
+
12
+
13
+ class TestFunction(basetest.TestCase):
14
+
15
+ def test_task_group_split(self):
16
+ group = torch.ones((8, 8), dtype=torch.int32)
17
+ group[:, 0:4] = 0
18
+ value = torch.zeros((8, 8), dtype=torch.bool)
19
+ value[:, 0:4] = True
20
+ result = task_group_split(group, value)
21
+ assert not torch.any(result)
22
+
23
+ value[:, 0:2] = False
24
+ result = task_group_split(group, value)
25
+ assert torch.all(result)
26
+
27
+ def test_task_group_split2(self):
28
+ group = torch.randint(48, (1024, 1000), dtype=torch.int32)
29
+ value = torch.randint(2, (1024, 1000), dtype=torch.int8) <= 0
30
+ self.do_test(task_group_split, group, value)
31
+
32
+ def test_task_group_priority(self):
33
+ group = torch.ones((8, 8), dtype=torch.int32)
34
+ group[:, 0:4] = 0
35
+ priority = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3], dtype=torch.int32)
36
+ priority = priority[None, :].expand(8, -1).clone()
37
+ value = torch.zeros((8, 8), dtype=torch.bool)
38
+ value[:, 4:6] = True
39
+
40
+ result = task_group_priority(group, priority, value)
41
+ expected = torch.tensor([False, True, True, True, True, True, False, True])
42
+ expected = expected[None, :].expand(8, -1)
43
+ assert torch.all(result == expected)
44
+
45
+ def test_task_group_priority2(self):
46
+ group = torch.randint(48, (1024, 1000), dtype=torch.int32)
47
+ value = torch.randint(2, (1024, 1000), dtype=torch.int8) < 1
48
+ priority = torch.randint(2, (1024, 1000), dtype=torch.int32)
49
+ self.do_test(task_group_priority, group, priority, value)
50
+
51
+ def do_test(self, function, *args):
52
+ print("\ntest {} ...".format(function.__name__))
53
+ start = time.time()
54
+ result1 = function(*args)
55
+ print("time: {:.6f}s, device: {}".format(time.time() - start, args[0].device))
56
+
57
+ args = [arg.to(device) for arg in args]
58
+ result1 = result1.to(device)
59
+
60
+ function(*args)
61
+ self.sync_device(device)
62
+
63
+ start = time.time()
64
+ result2 = function(*args)
65
+ self.sync_device(device)
66
+ print("time: {:.6f}s, device: {} ".format(time.time() - start, args[0].device))
67
+
68
+ if result1.is_floating_point():
69
+ assert torch.all(torch.abs(result1 - result2) < 1e-6)
70
+ else:
71
+ assert torch.all(result1 == result2)
72
+
73
+ def sync_device(self, device):
74
+ if device.type == 'cuda':
75
+ torch.cuda.synchronize(device)
76
+
77
+
78
+ if __name__ == '__main__':
79
+ unittest.main()
test/solver_test.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os.path as osp
3
+ import torch
4
+ import unittest
5
+
6
+ import basetest
7
+ from greedrl import Solver
8
+ from greedrl.const import *
9
+
10
+ sys.path.append(osp.join(osp.dirname(osp.abspath(__file__)), "../"))
11
+ from examples.cvrp import cvrp
12
+
13
+
14
+ class TestSolver(basetest.TestCase):
15
+ def test(self):
16
+ problem_list = cvrp.make_problem(1)
17
+
18
+ nn_args = {}
19
+ nn_args['decode_rnn'] = 'GRU'
20
+ solver = Solver(None, nn_args)
21
+
22
+ solver.train(None, problem_list, problem_list,
23
+ batch_size=32, max_steps=5, memopt=10)
24
+
25
+ solver.train(None, problem_list, problem_list,
26
+ batch_size=32, max_steps=5, memopt=10, topk_size=10)
27
+
28
+ solver.train(None, problem_list, problem_list,
29
+ batch_size=32, max_steps=5, memopt=10, on_policy=False)
30
+
31
+ solution = solver.solve(problem_list[0], batch_size=8)
32
+ assert torch.all(solution.worker_task_sequence[:, -1, 0] == GRL_FINISH)
33
+ problem_list[0].solution = solution.worker_task_sequence[:, 0:-1, :]
34
+
35
+ solution2 = solver.solve(problem_list[0], batch_size=1)
36
+ assert torch.all(solution.worker_task_sequence == solution2.worker_task_sequence)
37
+
38
+
39
+ if __name__ == '__main__':
40
+ unittest.main()