Wan Xinyi commited on
Commit
4835f75
·
1 Parent(s): 29c9647

Render schedule

Browse files
Files changed (4) hide show
  1. app.py +62 -29
  2. description1.md +1 -26
  3. description2.md +24 -0
  4. svg_event.py +367 -0
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import gradio as gr
2
  import auto_schedule
3
  import v_schedule
4
-
 
 
5
  def greet(name, is_morning, temperature):
6
  salutation = "Good morning" if is_morning else "Good evening"
7
  greeting = f"{salutation} {name}. It is {temperature} degrees today"
@@ -11,7 +13,7 @@ def greet(name, is_morning, temperature):
11
  def percentage(x):
12
  return f"{x*100:.2f}%"
13
 
14
- def get_schedule_time_and_image(result):
15
  result = [
16
  list(filter(lambda x: x.type in {'F', 'B', 'W'}, r)) for r in result
17
  ]
@@ -20,13 +22,34 @@ def get_schedule_time_and_image(result):
20
  max([x.completion_time for x in stage]) - min([x.start_time for x in stage]) for stage in result
21
  ]
22
  )
23
- return time, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def calculate(p, m, f, b, w, c, mem):
26
- baseline_time=(f+b+w)*m + (f+b+w+c)*(p-1)
27
- baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
28
- baseline_acceleration=percentage(0)
29
- baseline_image=None
 
 
 
 
 
 
30
 
31
 
32
  zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
@@ -37,29 +60,39 @@ def calculate(p, m, f, b, w, c, mem):
37
  max_mem=mem * 2,
38
  print_scaling=1000
39
  ))
40
- zb_time,zb_image=get_schedule_time_and_image(zb_result)
 
41
 
42
  zb_bubble=percentage(zb_time/(f+b+w)/m - 1)
43
- zb_acceleration=percentage(baseline_time/zb_time - 1)
44
 
45
- zbv_graph = v_schedule.PipelineGraph(
46
- n_stage=p,
47
- n_micro=m,
48
- f_cost=f/2,
49
- b_cost=b/2,
50
- w_cost=w/2,
51
- c_cost=c,
52
- f_mem=2,
53
- b_mem=-1,
54
- w_mem=-1,
55
- max_mem=mem * 4,
56
- )
57
- zbv_result = zbv_graph.get_v_schedule()
 
 
 
 
 
 
 
 
 
 
58
 
59
- zbv_time,zbv_image = get_schedule_time_and_image(zbv_result)
60
- zbv_bubble=percentage(zbv_time/(f+b+w)/m - 1)
61
- zbv_acceleration=percentage(baseline_time/zbv_time - 1)
62
- zbv_image=None
63
 
64
  return [baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image]
65
 
@@ -87,8 +120,8 @@ with gr.Blocks() as demo:
87
  print("update")
88
  if s=="custom":
89
  return mem
90
- return p*int(s[:-1])
91
- memsel=gr.Radio(choices=["1p", "2p", "3p", "custom"], value="1p")
92
  mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
93
  memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
94
  p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
@@ -115,7 +148,7 @@ with gr.Blocks() as demo:
115
  with gr.Column(scale=4):
116
  zb_image=gr.Image(None, interactive=False, label="Schedule Image")
117
  with gr.Group():
118
- gr.Markdown("Zero Bubble V Schedule")
119
  with gr.Row():
120
  with gr.Column(scale=1):
121
  zbv_time=gr.Textbox("", label="Longest Stage Time")
 
1
  import gradio as gr
2
  import auto_schedule
3
  import v_schedule
4
+ from PIL import Image
5
+ from svg_event import render_manual_graph
6
+ import pathlib
7
  def greet(name, is_morning, temperature):
8
  salutation = "Good morning" if is_morning else "Good evening"
9
  greeting = f"{salutation} {name}. It is {temperature} degrees today"
 
13
  def percentage(x):
14
  return f"{x*100:.2f}%"
15
 
16
+ def get_schedule_time(result):
17
  result = [
18
  list(filter(lambda x: x.type in {'F', 'B', 'W'}, r)) for r in result
19
  ]
 
22
  max([x.completion_time for x in stage]) - min([x.start_time for x in stage]) for stage in result
23
  ]
24
  )
25
+ return time
26
+
27
+ img_queue = []
28
+ def get_schedule_image(result, max_time):
29
+ result = [
30
+ list(filter(lambda x: x.type in {'F', 'B', 'W'}, r)) for r in result
31
+ ]
32
+ svg = render_manual_graph(result, max_time, len(result[0]) <= 72)
33
+ img_queue.append(svg)
34
+ if len(img_queue) > 32:
35
+ poped = img_queue.pop(0)
36
+ pathlib.Path(poped).unlink()
37
+
38
+ return pathlib.Path(svg)
39
+
40
+
41
 
42
  def calculate(p, m, f, b, w, c, mem):
43
+ if mem < p:
44
+ baseline_time=None
45
+ baseline_bubble=None
46
+ baseline_acceleration=None
47
+ baseline_image=None
48
+ else:
49
+ baseline_time=(f+b+w)*m + (f+b+w+c)*(p-1)
50
+ baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
51
+ baseline_acceleration=percentage(0)
52
+ baseline_image=None
53
 
54
 
55
  zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
 
60
  max_mem=mem * 2,
61
  print_scaling=1000
62
  ))
63
+
64
+ zb_time=get_schedule_time(zb_result)
65
 
66
  zb_bubble=percentage(zb_time/(f+b+w)/m - 1)
67
+ zb_acceleration=percentage(baseline_time/zb_time - 1) if baseline_time is not None else None
68
 
69
+ if mem < p:
70
+ zbv_time=None
71
+ zbv_bubble=None
72
+ zbv_acceleration=None
73
+ zbv_image=None
74
+ else:
75
+ zbv_graph = v_schedule.PipelineGraph(
76
+ n_stage=p,
77
+ n_micro=m,
78
+ f_cost=f/2,
79
+ b_cost=b/2,
80
+ w_cost=w/2,
81
+ c_cost=c,
82
+ f_mem=2,
83
+ b_mem=-1,
84
+ w_mem=-1,
85
+ max_mem=mem * 4,
86
+ )
87
+ zbv_result = zbv_graph.get_v_schedule()
88
+
89
+ zbv_time = get_schedule_time(zbv_result)
90
+ zbv_bubble=percentage(zbv_time/(f+b+w)/m - 1)
91
+ zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
92
 
93
+ max_time = max([baseline_time, zb_time, zbv_time])
94
+ zb_image = get_schedule_image(zb_result, max_time)
95
+ zbv_image = get_schedule_image(zbv_result, max_time)
 
96
 
97
  return [baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image]
98
 
 
120
  print("update")
121
  if s=="custom":
122
  return mem
123
+ return p*int(s[:1])
124
+ memsel=gr.Radio(choices=["1p (Same as 1F1B)", "2p", "3p", "custom"], value="1p (Same as 1F1B)")
125
  mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
126
  memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
127
  p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
 
148
  with gr.Column(scale=4):
149
  zb_image=gr.Image(None, interactive=False, label="Schedule Image")
150
  with gr.Group():
151
+ gr.Markdown("Zero Bubble V Schedule (ZBV)")
152
  with gr.Row():
153
  with gr.Column(scale=1):
154
  zbv_time=gr.Textbox("", label="Longest Stage Time")
description1.md CHANGED
@@ -6,29 +6,4 @@ Our paper is coming soon.
6
 
7
  Try out our implementation based on Megatron on [https://github.com/sail-sg/zero-bubble-pipeline-parallelism](https://github.com/sail-sg/zero-bubble-pipeline-parallelism)
8
 
9
- Experiments shows zero bubble pipeline parallelism can accelerate training up to 30% with a similar memory comsumption. A detailed table of experiments is coming soon.
10
-
11
- ## Zero Bubble Schedules
12
- The key of achieving zero bubble is to breaking a backward pass into a B pass and W pass. B on one stage will only depend on the B on its next stage, compared to depending on both B and W of in 1F1B.
13
-
14
- ![image](https://hackmd.io/_uploads/Bkc7CL7N6.png)
15
-
16
- ### Comparision of Schedules
17
- * 1F1B
18
- ![image](https://hackmd.io/_uploads/Hkq-gD7N6.png)
19
- * ZB1P
20
- ![image](https://hackmd.io/_uploads/Hy2GxwmEa.png)
21
- * ZB2P
22
- ![image](https://hackmd.io/_uploads/S10QgvmV6.png)
23
- * ZBV - Each device is assigned to exactly 2 chunks (virtual stages), where white text colors represent the first chunk and black text colors represent the second chunk. The sequence of dependencies among model chunks follows a ”V” shape pattern for both the forward and backward passes.
24
- ![image](https://hackmd.io/_uploads/Sk9uyY4ra.png)
25
-
26
-
27
-
28
-
29
- | Comparison assuming T_F=T_B=T_W | 1F1B | ZB1P | ZB2P | ZBV (Recommended) |
30
- | ----------------------------------------------------- | ------- | -------- | ---- | --- |
31
- | Bubble Rate | (p-1)/m | (p-1)/3m | 0 | 0 |
32
- | Activation Memory <br> (Compared to 1F1B) | 1x | 1x | 2x | 1x |
33
- | Pipeline Communication Volume <br> (Compared to 1F1B) | 1x | 1x | 1x | 2x |
34
-
 
6
 
7
  Try out our implementation based on Megatron on [https://github.com/sail-sg/zero-bubble-pipeline-parallelism](https://github.com/sail-sg/zero-bubble-pipeline-parallelism)
8
 
9
+ Experiments shows zero bubble pipeline parallelism can accelerate training up to 30% with a similar memory comsumption. A detailed table of experiments is coming soon.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
description2.md CHANGED
@@ -1,3 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
 
3
  ## Optimizer Post Validation
 
1
+ ## Zero Bubble Schedules
2
+ The key of achieving zero bubble is to breaking a backward pass into a B pass and W pass. B on one stage will only depend on the B on its next stage, compared to depending on both B and W of in 1F1B.
3
+
4
+ ![image](https://hackmd.io/_uploads/Bkc7CL7N6.png)
5
+
6
+ ### Comparision of Schedules
7
+ * 1F1B
8
+ ![image](https://hackmd.io/_uploads/Hkq-gD7N6.png)
9
+ * ZB1P
10
+ ![image](https://hackmd.io/_uploads/Hy2GxwmEa.png)
11
+ * ZB2P
12
+ ![image](https://hackmd.io/_uploads/S10QgvmV6.png)
13
+ * ZBV - Each device is assigned to exactly 2 chunks (virtual stages), where white text colors represent the first chunk and black text colors represent the second chunk. The sequence of dependencies among model chunks follows a ”V” shape pattern for both the forward and backward passes.
14
+ ![image](https://hackmd.io/_uploads/rkfUVYNrp.png)
15
+
16
+
17
+
18
+
19
+ | Comparison assuming T_F=T_B=T_W | 1F1B | ZB1P | ZB2P | ZBV (Recommended) |
20
+ | ----------------------------------------------------- | ------- | -------- | ---- | --- |
21
+ | Bubble Rate | (p-1)/m | (p-1)/3m | 0 | 0 |
22
+ | Activation Memory <br> (Compared to 1F1B) | 1x | 1x | 2x | 1x |
23
+ | Pipeline Communication Volume <br> (Compared to 1F1B) | 1x | 1x | 1x | 2x |
24
+
25
 
26
 
27
  ## Optimizer Post Validation
svg_event.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import numpy as np
4
+ import drawsvg as draw
5
+ import colorsys
6
+ import tempfile, os
7
+
8
+ def filter_time(e, start, end):
9
+ # return e["start_time"] >= start and (end is None or e["completion_time"] <= end)
10
+ # Check completion_time here to include the last long W
11
+ return e.completion_time >= start and (end is None or e.completion_time <= end)
12
+
13
+
14
+ def load_json_data(filename, start=0, end=None, time_scale=1):
15
+ with open(filename) as f:
16
+ data = json.loads(f.read())
17
+ fbw_types = {"F", "B", "W", "Optimizer"}
18
+ return [[{
19
+ "type": e["type"],
20
+ "start_time": int(max(e["start_time"] - start, 0)) * time_scale,
21
+ "completion_time": int(e["completion_time"] - start) * time_scale,
22
+ "minibatch": e.get("minibatch", None),
23
+ } for e in dev_evs
24
+ if e["type"] in fbw_types and filter_time(e, start, end)
25
+ ] for dev_evs in data]
26
+
27
+
28
+ ENABLE_BORDER = True
29
+ ENABLE_BATCH_ID = True
30
+ ENABLE_EDGE_BLUR = False
31
+ SCALE_FACTOR = 2
32
+ S = SCALE_FACTOR
33
+
34
+ # TIME_PER_UNIT = 300 // SCALE_FACTOR
35
+ TIME_PER_UNIT = 4000 // SCALE_FACTOR
36
+
37
+
38
+ def to_color_fmt(c):
39
+ # c = to_greyscale(c)
40
+ return f"#{hex(c[0])[2:]}{hex(c[1])[2:]}{hex(c[2])[2:]}"
41
+
42
+
43
+ GREYSCALE_WEIGHTS = np.array([0.299, 0.587, 0.114])
44
+
45
+
46
+ def to_greyscale(color):
47
+ c = np.dot(GREYSCALE_WEIGHTS, color[:3].astype(float)).astype(int)
48
+ return np.array([c, c, c, 255])
49
+
50
+
51
+ COLOR_VALUE_MAP = {
52
+ "F": np.array([57, 122, 242]),
53
+ "B": np.array([62, 181, 191]),
54
+ # "B": np.array([68, 211, 218]), # sea color
55
+ # "W": to_color_fmt(np.array([47, 158, 73, 255])),
56
+ "W": np.array([41, 137, 64]),
57
+ # "W": np.array([224, 240, 231]), # sea color
58
+ # "Optimizer": to_color_fmt(np.array([255, 240, 197, 255])),
59
+ "Optimizer": np.array([255, 217, 102]),
60
+ }
61
+
62
+
63
+ COLOR_MAP = {k: to_color_fmt(v) for k, v in COLOR_VALUE_MAP.items()}
64
+
65
+
66
+ # BORDER_SIZE = SCALE_FACTOR // 2
67
+ BORDER_SIZE = 1
68
+ SPAN_HEIGHT = SCALE_FACTOR * 10
69
+ FONT_SIZE = SCALE_FACTOR * 10
70
+ TITLE_WIDTH = SCALE_FACTOR * 60
71
+ CENTER_TITLE_HEIGHT = SPAN_HEIGHT * 6
72
+
73
+ WHITE = to_color_fmt(np.array([255, 255, 255, 255]))
74
+ BLACK = to_color_fmt(np.array([0, 0, 0, 255]))
75
+
76
+
77
+ class DrawCtx:
78
+ def __init__(self, d, oy, ox):
79
+ assert not isinstance(d, DrawCtx)
80
+ self.d = d
81
+ self.oy = oy
82
+ self.ox = ox
83
+
84
+ @classmethod
85
+ def from_base_ctx(cls, base_ctx, oy, ox):
86
+ assert isinstance(base_ctx, DrawCtx)
87
+ return cls(base_ctx.d, base_ctx.oy + oy, base_ctx.ox + ox)
88
+
89
+ def width(self):
90
+ return self.d.width
91
+
92
+ def height(self):
93
+ return self.d.height
94
+
95
+ def line(self, sy, sx, ey, ex, width=None):
96
+ self.d.append(draw.Line(
97
+ self.ox + sx,
98
+ self.oy + sy,
99
+ self.ox + ex,
100
+ self.oy + ey,
101
+ stroke='black',
102
+ stroke_width=width or BORDER_SIZE,
103
+ ))
104
+
105
+ def rect(self, sy, sx, h, w, color):
106
+ self.d.append(draw.Rectangle(
107
+ self.ox + sx,
108
+ self.oy + sy,
109
+ w, h,
110
+ fill=color,
111
+ shape_rendering="geometricPrecision",
112
+ ))
113
+
114
+ def rect_frame(self, sy, sx, h, w):
115
+ self.d.append(draw.Rectangle(
116
+ self.ox + sx,
117
+ self.oy + sy,
118
+ w, h,
119
+ fill="none",
120
+ stroke=BLACK,
121
+ stroke_width=BORDER_SIZE,
122
+ ))
123
+
124
+ def text(self, y, x, text, anchor="middle", font_scale=1, fill='black'):
125
+ font_size = FONT_SIZE * font_scale
126
+ tl = len(text) * font_size // 2
127
+ self.d.append(draw.Text(
128
+ text, font_size,
129
+ self.ox + x,
130
+ # Magic 3 to make it vertical center
131
+ self.oy + y + font_size - 3,
132
+ textLength=tl, lengthAdjust='spacing',
133
+ text_anchor=anchor,
134
+ font_family="Times New Roman",
135
+ fill=fill,
136
+ # font_style="oblique",
137
+ # font_family="Computer Modern Roman",
138
+ ))
139
+
140
+
141
+ def change_color_sat(c, percentage):
142
+ c = c.astype(float) / 255.0
143
+ (h, s, v) = colorsys.rgb_to_hsv(c[0], c[1], c[2])
144
+ s *= percentage
145
+ r, g, b = colorsys.hsv_to_rgb(h, s, v)
146
+ c = np.array([r, g, b]) * 255
147
+ return c.astype(int)
148
+
149
+
150
+ def draw_experiment_and_schedule(exp_events, sched_events, output_filename, tail=10):
151
+ exp_canvas_info = CanvasInfo(exp_events, tail, 0)
152
+ sched_canvas_info = CanvasInfo(sched_events, tail, 0, False)
153
+ width = max(exp_canvas_info.get_canvas_size()[1], sched_canvas_info.get_canvas_size()[1])
154
+ height = exp_canvas_info.get_canvas_size()[0] + sched_canvas_info.get_canvas_size()[0]
155
+
156
+ include_w = True
157
+
158
+ # d = draw.Drawing(width, sched_canvas_info.get_canvas_size()[0], origin="top-left")
159
+ d = draw.Drawing(width, height, origin="top-left")
160
+ ctx = DrawCtx(d, 0, 0)
161
+ plot_events(ctx, sched_events, "", sched_canvas_info, include_w, include_o=False, include_info=False)
162
+ # plot_events(ctx, sched_events, "", sched_canvas_info, include_w, include_o=False)
163
+ # d.save_svg("pics/schedule.svg")
164
+
165
+ # d = draw.Drawing(width, sched_canvas_info.get_canvas_size()[0], origin="top-left")
166
+ # exp_ctx = DrawCtx(d, 0, 0)
167
+ exp_ctx = DrawCtx.from_base_ctx(ctx, sched_canvas_info.get_canvas_size()[0], 0)
168
+ plot_events(exp_ctx, exp_events, "", exp_canvas_info, include_w, include_o=True)
169
+ # plot_events(exp_ctx, exp_events, "", exp_canvas_info, include_w, include_o=True)
170
+ d.save_svg(output_filename)
171
+
172
+
173
+ def draw_events(events, output_filename, include_w=True, include_o=True, tail=50):
174
+ canvas_info = CanvasInfo(events, tail, center_title_height=0, enable_info=True)
175
+ max_len = canvas_info.max_len
176
+ # height = canvas_info.height
177
+ # info_height = canvas_info.info_height
178
+ height, width = canvas_info.get_canvas_size()
179
+
180
+ d = draw.Drawing(width, height, origin="top-left")
181
+ ctx = DrawCtx(d, 0, 0)
182
+
183
+ plot_events(ctx, events, "", canvas_info, include_w, include_o)
184
+ d.save_svg(output_filename)
185
+
186
+
187
+ class CanvasInfo:
188
+ def __init__(self, events, tail, center_title_height=CENTER_TITLE_HEIGHT, enable_info=True):
189
+ last_time = max(max([e["completion_time"] for e in dev_evs]) for dev_evs in events)
190
+ self.max_len = (last_time + TIME_PER_UNIT - 1) // TIME_PER_UNIT + tail
191
+
192
+ self.height = SPAN_HEIGHT * len(events) + BORDER_SIZE * (len(events) + 1)
193
+ color_text_row_height = int(SPAN_HEIGHT * 1.6)
194
+ self.color_text_height = color_text_row_height + BORDER_SIZE
195
+ self.info_height = SPAN_HEIGHT + color_text_row_height + 3 * BORDER_SIZE
196
+ if not enable_info:
197
+ self.info_height /= 2
198
+ self.center_title_height = center_title_height
199
+ # self.center_title_height = 0
200
+
201
+ def get_canvas_size(self):
202
+ # height, width
203
+ return self.height + self.info_height + self.center_title_height, self.max_len + TITLE_WIDTH
204
+
205
+
206
+ def plot_events(ctx, events, title_text: str, canvas_info: CanvasInfo, include_w=True, include_o=True, include_info=True):
207
+ max_len = canvas_info.max_len
208
+ height = canvas_info.height
209
+ color_text_height = canvas_info.color_text_height
210
+ info_height = canvas_info.info_height
211
+
212
+ data_ctx = DrawCtx.from_base_ctx(ctx, 0, TITLE_WIDTH)
213
+
214
+ for i, evs in enumerate(events):
215
+ h = i * SPAN_HEIGHT + (i + 1) * BORDER_SIZE
216
+ for e in evs:
217
+ start = BORDER_SIZE + e["start_time"] // TIME_PER_UNIT
218
+ end = BORDER_SIZE + e["completion_time"] // TIME_PER_UNIT
219
+ if start == end or not ENABLE_EDGE_BLUR:
220
+ plot_span(data_ctx, start, end, h, COLOR_MAP[e["type"]])
221
+ else:
222
+ plot_span(data_ctx, start + 1, end - 1, h, COLOR_MAP[e["type"]])
223
+ # plot_span(data_ctx, start, end - 1, h, COLOR_MAP[e["type"]])
224
+ c = change_color_sat(
225
+ COLOR_VALUE_MAP[e["type"]],
226
+ (e["start_time"] / TIME_PER_UNIT) % 1.0)
227
+ plot_span(data_ctx, start, start + 1, h, to_color_fmt(c))
228
+ c = change_color_sat(
229
+ COLOR_VALUE_MAP[e["type"]],
230
+ (e["completion_time"] / TIME_PER_UNIT) % 1.0)
231
+ plot_span(data_ctx, end - 1, end, h, to_color_fmt(c))
232
+
233
+ if ENABLE_BATCH_ID:
234
+ minibatch = str(e["minibatch"])
235
+ center = (start + end) // 2
236
+ data_ctx.text(h, center, minibatch, font_scale=0.7, fill='black' if e["chunk"] == 0 else 'white')
237
+ if ENABLE_BORDER:
238
+ data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
239
+
240
+ if ENABLE_BORDER:
241
+ data_ctx.line(0, 0, 0, max_len - 1)
242
+ data_ctx.line(0, 0, height, 0)
243
+ data_ctx.line(0, max_len - 1, height, max_len - 1)
244
+
245
+ dev_title_ctx = DrawCtx.from_base_ctx(ctx, 0, 0)
246
+ ndev = len(events)
247
+ add_devices(dev_title_ctx, ndev)
248
+
249
+ if not include_info:
250
+ return
251
+
252
+ info_height = ndev * SPAN_HEIGHT + (ndev + 1) * BORDER_SIZE
253
+ info_ctx = DrawCtx.from_base_ctx(ctx, info_height, 0)
254
+ add_info(info_ctx, color_text_height, include_w, include_o)
255
+
256
+ if title_text:
257
+ center_title_ctx = DrawCtx.from_base_ctx(info_ctx, canvas_info.info_height, 0)
258
+ add_center_title(center_title_ctx, title_text)
259
+
260
+
261
+ def plot_span(ctx, start, end, h, color, ):
262
+ ctx.rect(h, start, SPAN_HEIGHT, end - start, color)
263
+ if ENABLE_BORDER:
264
+ ctx.rect_frame(h-BORDER_SIZE, start, SPAN_HEIGHT + BORDER_SIZE, end - start)
265
+
266
+
267
+ def add_devices(ctx, devs):
268
+ for i in range(devs):
269
+ h = i * SPAN_HEIGHT + (i + 1) * BORDER_SIZE
270
+ ctx.text(h, 6 * SCALE_FACTOR, "Device {}".format(i), "left")
271
+
272
+
273
+ def add_info(ctx, color_text_height, include_w=True, include_o=True):
274
+ div = 4 + int(include_w) + int(include_o)
275
+ f_start = ctx.width() // div
276
+ b_start = ctx.width() // div * 2
277
+ w_start = ctx.width() // div * 3
278
+ o_start = ctx.width() // div * 4
279
+
280
+ block_w = 25 * SCALE_FACTOR
281
+ plot_span(ctx, f_start, f_start+block_w, color_text_height + BORDER_SIZE, COLOR_MAP["F"])
282
+ plot_span(ctx, b_start, b_start+block_w, color_text_height + BORDER_SIZE, COLOR_MAP["B"])
283
+ if include_w:
284
+ plot_span(ctx, w_start, w_start+block_w, color_text_height + BORDER_SIZE, COLOR_MAP["W"])
285
+ if include_o:
286
+ plot_span(ctx, o_start, o_start+block_w, color_text_height + BORDER_SIZE, COLOR_MAP["Optimizer"])
287
+
288
+ ctx.text(0, 6 * SCALE_FACTOR, "Time", "left")
289
+ draw_arrow(ctx, SPAN_HEIGHT // 2 + BORDER_SIZE + 1, 65 * SCALE_FACTOR, 50 * SCALE_FACTOR)
290
+
291
+ block_w = 30 * SCALE_FACTOR
292
+ ctx.text(color_text_height, f_start + block_w, "F", "left")
293
+ ctx.text(color_text_height, b_start + block_w,
294
+ "B", "left")
295
+ if include_w:
296
+ ctx.text(color_text_height, w_start + block_w, "W", "left")
297
+ if include_o:
298
+ ctx.text(color_text_height, o_start + block_w, "Optimizer Step", "left")
299
+
300
+
301
+ def add_center_title(ctx: DrawCtx, text):
302
+ ctx.text(CENTER_TITLE_HEIGHT / 4, ctx.width() / 2,
303
+ text, "middle", 2)
304
+
305
+
306
+ def draw_arrow(ctx: DrawCtx, start_y, start_x, width, thickness=2):
307
+ b = thickness * (SCALE_FACTOR // 2)
308
+ ctx.line(start_y, start_x, start_y, start_x + width, b)
309
+ ctx.line(start_y, start_x + width, start_y - 3*b, start_x + width - 3*b)
310
+ ctx.line(start_y, start_x + width, start_y + 3*b, start_x + width - 3*b)
311
+
312
+
313
+ def render_manual_graph(data, longest_time, enable_batch_id = False):
314
+ global ENABLE_BORDER
315
+ global ENABLE_BATCH_ID
316
+ ENABLE_BORDER = True
317
+ ENABLE_BATCH_ID = enable_batch_id
318
+ fbw_types = {"F", "B", "W", "Optimizer"}
319
+ start = 0
320
+ end = None
321
+ time_scale= 1024 / longest_time * TIME_PER_UNIT
322
+ events = [[{
323
+ "type": e.type,
324
+ "start_time": int(max(e.start_time - start, 0)) * time_scale,
325
+ "completion_time": int(e.completion_time - start) * time_scale,
326
+ "minibatch": e.minibatch,
327
+ "chunk": e.chunk if hasattr(e, "chunk") else 0,
328
+ } for e in dev_evs
329
+ if e.type in fbw_types and filter_time(e, start, end)
330
+ ] for dev_evs in data]
331
+ # events = load_json_data("std-schedule.json")
332
+ # global TIME_PER_UNIT
333
+ # global ENABLE_BATCH_ID
334
+ # global ENABLE_BORDER
335
+ # global SCALE_FACTOR
336
+ # SCALE_FACTOR = 8
337
+ # ENABLE_BATCH_ID = False
338
+ # ENABLE_BORDER = False
339
+ # TIME_PER_UNIT *= 7
340
+ #events = load_json_data("no-bb-schedule.json")
341
+
342
+ path = os.path.join(tempfile.mkdtemp(), 'a.svg')
343
+ draw_events(events, path, include_w=True, include_o=False, tail=50)
344
+ return path
345
+
346
+
347
+ def render_experiment_graph():
348
+ global ENABLE_BORDER
349
+ global ENABLE_BATCH_ID
350
+ global TIME_PER_UNIT
351
+ ENABLE_BORDER = False
352
+ ENABLE_BATCH_ID = False
353
+ TIME_PER_UNIT = 200 // SCALE_FACTOR
354
+ TIME_PER_UNIT *= 12000
355
+ start_time = 1100000000 + 10000000
356
+ # iter_time = 1600000000
357
+ iter_time = 1290000000
358
+ end_time = start_time + iter_time
359
+ exp_events = load_json_data("20-09-zero/zero-events.json", start_time, end_time)
360
+ # draw_events(events, "pics/experiment.svg")
361
+ sched_events = load_json_data("schedule.json", time_scale=1000)
362
+ draw_experiment_and_schedule(exp_events, sched_events, "pics/exp.svg")
363
+ # draw_events(events, "pics/schedule.svg", include_w=True, include_o=False)
364
+
365
+
366
+ # render_manual_graph()
367
+ # render_experiment_graph()