lvwerra HF staff committed on
Commit
17ab7e0
·
verified ·
1 Parent(s): baa6366

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import urllib.parse

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
7
+
8
+
9
+
10
def read_google_sheet(sheet_id, sheet_name):
    """Fetch one tab of a Google spreadsheet as a pandas DataFrame.

    The tab is requested through the spreadsheet's CSV export URL and parsed
    with ``pd.read_csv``.

    Args:
        sheet_id: Spreadsheet identifier, inserted into the URL path.
        sheet_name: Name of the tab; percent-encoded before it is placed in
            the query string.

    Returns:
        The parsed DataFrame, or ``None`` if reading fails for any reason
        (the error is printed to stdout).
    """
    quoted_tab = urllib.parse.quote(sheet_name)
    csv_url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/gviz/tq?tqx=out:csv&sheet={quoted_tab}"

    try:
        return pd.read_csv(csv_url)
    except Exception as err:
        # Best effort: report the failure and hand None back to the caller.
        print(f"An error occurred: {err}")
        return None
24
+
25
+ # Function to generate tick values and labels
26
def log2_ticks(values):
    """Return integer log2 tick positions and their labels for an axis.

    Args:
        values: Log2-scaled data (a pandas Series, NumPy array or any
            sequence convertible to a float array).

    Returns:
        A ``(tick_vals, tick_text)`` tuple. ``tick_vals`` is a float array
        holding every integer exponent from ``floor(min)`` to ``ceil(max)``
        inclusive; ``tick_text`` is the list of matching ``2**exponent``
        labels formatted without decimals. Both are empty when ``values``
        contains no finite entry.
    """
    # NaN and +/-inf (e.g. log2(0)) would make np.arange fail, so only the
    # finite entries take part in the range computation.
    finite = np.asarray(values, dtype=float)
    finite = finite[np.isfinite(finite)]
    if finite.size == 0:
        return np.array([]), []

    min_val, max_val = np.floor(finite.min()), np.ceil(finite.max())
    tick_vals = np.arange(min_val, max_val + 1)
    tick_text = [f"{2**val:.0f}" for val in tick_vals]
    return tick_vals, tick_text
32
+
33
# Load the benchmark results: one spreadsheet tab per cluster size.
sheet_id = "1g07tdGf9ocOZ8XZgLGepI5Q4u6ZH961J0T9O9P64rYw"
sheet_names = [f"{count} node{'' if count == 1 else 's'}" for count in [1, 8]]

# Stack all tabs into one table and shorten the two longest column names.
df = pd.concat(
    [read_google_sheet(sheet_id, tab) for tab in sheet_names]
).rename(
    columns={
        "micro_batch_size": "mbs",
        "batch_accumulation_per_replica": "gradacc",
    }
)

# NOTE(review): -1 presumably marks runs without a valid measurement; confirm.
df["tok/s/gpu"] = df["tok/s/gpu"].replace(-1, 0)

# NOTE(review): the factor 8 is presumably the number of GPUs per node; confirm.
df["throughput"] = df["tok/s/gpu"] * df["nnodes"] * 8
41
+
42
+
43
+
44
def get_figure(nodes, hide_nans):
    """Build the parallel-coordinates figure for one cluster size.

    Args:
        nodes: Value matched against the ``nnodes`` column of the module-level
            ``df`` to select the rows to plot.
        hide_nans: When truthy, rows containing any missing value are dropped
            before plotting.

    Returns:
        A plotly ``Figure`` with one log2-scaled axis each for ``dp``, ``tp``,
        ``pp``, ``mbs`` and ``gradacc`` plus a linear ``throughput`` axis;
        lines are coloured by throughput.
    """
    log_columns = ['dp', 'tp', 'pp', 'mbs', 'gradacc']

    # Work on the rows whose node count matches the requested one.
    selected = df[df["nnodes"] == nodes].reset_index(drop=True)
    if hide_nans:
        selected = selected.dropna()

    # Every column except throughput is displayed on a log2 scale.
    for name in log_columns:
        selected[f'log_{name}'] = np.log2(selected[name])

    # One axis per log-scaled column, labelled with the un-logged values.
    dimensions = []
    for name in log_columns:
        log_series = selected[f'log_{name}']
        tick_vals, tick_text = log2_ticks(log_series)
        dimensions.append({
            "range": [log_series.min(), log_series.max()],
            "label": name,
            "values": log_series,
            "tickvals": tick_vals,
            "ticktext": tick_text,
        })

    # Throughput keeps its linear scale and also drives the line colour.
    throughput = selected['throughput']
    lowest, highest = throughput.min(), throughput.max()
    dimensions.append({
        "range": [lowest, highest],
        "label": 'throughput',
        "values": throughput,
    })

    line_style = {
        "color": throughput,
        "colorscale": 'GnBu',
        "showscale": True,
        "cmin": lowest,
        "cmax": highest,
    }
    fig = go.Figure(data=go.Parcoords(line=line_style, dimensions=dimensions))

    fig.update_layout(
        title="3D parallel setup throughput ",
        plot_bgcolor='white',
        paper_bgcolor='white',
    )
    return fig
98
+
99
+
100
# Assemble the page: two selectors on one row and the plot underneath.
with gr.Blocks() as demo:
    title = gr.Markdown("# 3D parallel benchmark")
    with gr.Row():
        nnodes = gr.Dropdown(choices=[1, 8], label="Number of nodes", value=8)
        hide_nan = gr.Dropdown(choices=[False, True], label="Hide NaNs", value=False)

    plot = gr.Plot()

    # Draw the figure on page load and redraw it whenever a selector changes.
    selectors = [nnodes, hide_nan]
    demo.load(get_figure, selectors, [plot])
    for selector in selectors:
        selector.change(get_figure, selectors, [plot])

demo.launch(show_api=False)