|
import gradio as gr |
|
import auto_schedule |
|
import v_schedule |
|
|
|
def greet(name, is_morning, temperature):
    """Build a greeting and convert ``temperature`` from Fahrenheit to Celsius.

    Returns a ``(greeting, celsius)`` tuple; ``celsius`` is rounded to two
    decimals. Not referenced by the calculator UI in this file.
    """
    if is_morning:
        salutation = "Good morning"
    else:
        salutation = "Good evening"
    message = "{} {}. It is {} degrees today".format(salutation, name, temperature)
    return message, round((temperature - 32) * 5 / 9, 2)
|
|
|
def percentage(x):
    """Format the ratio ``x`` as a percentage with two decimals (0.5 -> "50.00%")."""
    scaled = x * 100
    return "{:.2f}%".format(scaled)
|
|
|
def get_schedule_time_and_image(result):
    """Return the longest per-stage busy span of a schedule and its image.

    ``result`` has one entry per pipeline stage; each entry is an iterable of
    scheduled nodes exposing ``type``, ``start_time`` and ``completion_time``.
    Only nodes of type ``'F'``, ``'B'`` or ``'W'`` are considered; all other
    node types are ignored.

    Returns:
        ``(time, image)`` where ``time`` is the maximum over stages of
        ``last completion_time - first start_time`` and ``image`` is always
        ``None`` (schedule rendering is not implemented here). A stage with
        no F/B/W nodes contributes nothing; if no stage has any, ``time`` is 0.
    """
    compute_types = {'F', 'B', 'W'}
    spans = []
    for stage in result:
        nodes = [node for node in stage if node.type in compute_types]
        if not nodes:
            # min()/max() on an empty sequence would raise ValueError.
            continue
        first_start = min(node.start_time for node in nodes)
        last_end = max(node.completion_time for node in nodes)
        spans.append(last_end - first_start)
    return max(spans, default=0), None
|
|
|
def _validate_inputs(p, m, f, b, w, c, mem):
    """Raise ``gr.Error`` for inputs that would crash or make the metrics undefined."""
    fields = {
        "number of stages (p)": p,
        "number of microbatches (m)": m,
        "time of F": f,
        "time of B": b,
        "time of W": w,
        "time of P2P communication": c,
        "memory limit": mem,
    }
    # Gradio passes None when a Number field has been cleared.
    missing = [label for label, value in fields.items() if value is None]
    if missing:
        raise gr.Error("Missing value for: " + ", ".join(missing))
    if p < 1 or m < 1:
        raise gr.Error("The number of stages and microbatches must both be at least 1.")
    if min(f, b, w, c) < 0:
        raise gr.Error("Times of F, B, W and communication must not be negative.")
    if f + b + w <= 0:
        # Bubble rates divide by (F+B+W).
        raise gr.Error("The sum of the F, B and W times must be positive.")


def calculate(p, m, f, b, w, c, mem):
    """Compare 1F1B, Zero Bubble (ZB) and Zero Bubble V (ZBV) schedules.

    Args:
        p: number of pipeline stages.
        m: number of microbatches.
        f, b, w: time of one F, B and W pass on a stage. For ZBV these are the
            combined time of the two virtual stages on a stage, so each
            virtual stage is given half.
        c: time of one point-to-point communication.
        mem: activation memory limit, in units of pending F on a stage.

    Returns:
        A 12-element list, four values per schedule in the order 1F1B, ZB,
        ZBV: longest stage time, bubble rate string, acceleration-vs-1F1B
        string, and schedule image (currently always ``None``). The order
        matches the ``outputs`` wired to the Calculate button.

    Raises:
        gr.Error: if an input is missing or would make the metrics undefined.
    """
    _validate_inputs(p, m, f, b, w, c, mem)
    compute = f + b + w

    # Closed-form 1F1B time: m * (F+B+W) plus (p - 1) * (F+B+W+C).
    baseline_time = compute * m + (compute + c) * (p - 1)

    def bubble(time):
        # Extra time relative to m * (F+B+W) of pure computation, as a ratio.
        return percentage(time / compute / m - 1)

    def acceleration(time):
        # Speed-up relative to the 1F1B baseline computed above.
        return percentage(baseline_time / time - 1)

    baseline_bubble = bubble(baseline_time)
    baseline_acceleration = percentage(0)
    baseline_image = None

    # NOTE(review): the memory scale factors (x2 here, x4 for ZBV below)
    # presumably convert the UI's "pending F" units into each scheduler's
    # internal units — confirm against auto_schedule / v_schedule.
    zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
        cost_f=f,
        cost_b=b,
        cost_w=w,
        cost_comm=c,
        max_mem=mem * 2,
        print_scaling=1000,
    ))
    zb_time, zb_image = get_schedule_time_and_image(zb_result)
    zb_bubble = bubble(zb_time)
    zb_acceleration = acceleration(zb_time)

    zbv_graph = v_schedule.PipelineGraph(
        n_stage=p,
        n_micro=m,
        f_cost=f / 2,
        b_cost=b / 2,
        w_cost=w / 2,
        c_cost=c,
        f_mem=2,
        b_mem=-1,
        w_mem=-1,
        max_mem=mem * 4,
    )
    zbv_result = zbv_graph.get_v_schedule()
    # Use the helper's image as-is (the ZB path does the same); the original
    # overwrote it with None, which would discard a rendered image.
    zbv_time, zbv_image = get_schedule_time_and_image(zbv_result)
    zbv_bubble = bubble(zbv_time)
    zbv_acceleration = acceleration(zbv_time)

    return [
        baseline_time, baseline_bubble, baseline_acceleration, baseline_image,
        zb_time, zb_bubble, zb_acceleration, zb_image,
        zbv_time, zbv_bubble, zbv_acceleration, zbv_image,
    ]
|
|
|
# Matches the formula used by ``calculate``: time / (F+B+W) / m - 1.
BUBBLE_RATE_LABEL = "Bubble Rate. Calculated as (longest stage time/(F+B+W)/m - 1)."


def update_mem(p, s, mem):
    """Return the memory limit implied by the radio selection ``s``.

    ``s`` is either a ``"<k>p"`` preset, giving a limit of ``k * p``, or
    ``"custom"``, in which case the current value ``mem`` is kept.
    """
    if s == "custom" or p is None:
        # Keep the current value; p is None while the stage field is empty.
        return mem
    return p * int(s[:-1])


def _result_group(title):
    """Create the output widgets for one schedule; call inside ``gr.Blocks``.

    Returns ``(time, bubble, acceleration, image)`` components, in the order
    ``calculate`` emits values for each schedule.
    """
    with gr.Group():
        gr.Markdown(title)
        with gr.Row():
            with gr.Column(scale=1):
                time_box = gr.Textbox("", label="Longest Stage Time")
                bubble_box = gr.Textbox("", label=BUBBLE_RATE_LABEL)
                acceleration_box = gr.Textbox("", label="Acceleration compared to 1F1B")
            with gr.Column(scale=4):
                image = gr.Image(None, interactive=False, label="Schedule Image")
    return time_box, bubble_box, acceleration_box, image


with gr.Blocks() as demo:
    gr.Markdown("Zero bubble pipeline parallel bubble calculator")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("Basic Parameters")
                with gr.Row():
                    p = gr.Number(label="Number of stages (p)", value=4, interactive=True, precision=0)
                    m = gr.Number(label="Number of microbatches (m)", value=12, interactive=True, precision=0)
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
                with gr.Row():
                    f = gr.Number(label="Time of F", value=8, interactive=True, precision=0)
                    b = gr.Number(label="Time of B", value=8, interactive=True, precision=0)
                    w = gr.Number(label="Time of W", value=8, interactive=True, precision=0)
                    c = gr.Number(label="Time of one P2P communication", value=1, interactive=True, precision=0)
            with gr.Group():
                gr.Markdown("Activation memory limit.")
                memsel = gr.Radio(choices=["1p", "2p", "3p", "custom"], value="1p")
                # Initial value equals the default "1p" preset (1 * p).
                mem = gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
                # Recompute the limit whenever the preset or the stage count changes.
                memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
                p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)

    button = gr.Button("Calculate")

    baseline_time, baseline_bubble, baseline_acceleration, baseline_image = _result_group("1F1B")
    zb_time, zb_bubble, zb_acceleration, zb_image = _result_group("Zero Bubble Schedule")
    zbv_time, zbv_bubble, zbv_acceleration, zbv_image = _result_group("Zero Bubble V Schedule")

    button.click(
        calculate,
        inputs=[p, m, f, b, w, c, mem],
        outputs=[
            baseline_time, baseline_bubble, baseline_acceleration, baseline_image,
            zb_time, zb_bubble, zb_acceleration, zb_image,
            zbv_time, zbv_bubble, zbv_acceleration, zbv_image,
        ],
    )

demo.launch()
|
|