Rajarshi Roy commited on
Commit
08949f8
·
verified ·
1 Parent(s): b944a57

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +381 -0
app.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ def convert_google_sheet_url(url):
3
+ # Regular expression to match and capture the necessary part of the URL
4
+ pattern = r'https://docs\.google\.com/spreadsheets/d/([a-zA-Z0-9-_]+)(/edit#gid=(\d+)|/edit.*)?'
5
+
6
+ # Replace function to construct the new URL for CSV export
7
+ # If gid is present in the URL, it includes it in the export URL, otherwise, it's omitted
8
+ replacement = lambda m: f'https://docs.google.com/spreadsheets/d/{m.group(1)}/export?' + (f'gid={m.group(3)}&' if m.group(3) else '') + 'format=csv'
9
+
10
+ # Replace using regex
11
+ new_url = re.sub(pattern, replacement, url)
12
+
13
+ return new_url
14
+
15
+ # Replace with your modified URL
16
+ # url = "https://docs.google.com/spreadsheets/d/1dlTjKJrGVwRDU8m-hT53IdSluRAsWXftnx5uRqnq4yE/edit?gid=0#gid=0"
17
+ url = "https://docs.google.com/spreadsheets/d/1MY0-DOitMZGnib73BAaSKg0TI7i5V1CXP8dF6jAgKWc/edit?gid=293606167#gid=293606167"
18
+
19
+ new_url = convert_google_sheet_url(url)
20
+
21
+
22
+ df = pd.read_csv(new_url)
23
+
24
+ # Set 'Categories' column as index
25
+ df1 = df.copy()
26
+ df1.set_index('Categories', inplace=True)
27
+
28
+ transposed_df = df.transpose()
29
+ transposed_df.columns = transposed_df.iloc[0]
30
+ df = transposed_df.drop(["Categories"])
31
+
32
+
33
+ df = df.fillna("[]")
34
+ df1 = df1.fillna("[]")
35
+
36
+
37
+ # Convert the string representation of lists into actual lists for all relevant columns
38
+ for col in df.columns: # Skip the first column which is 'Categories'
39
+ df[col] = df[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
40
+
41
+
42
+ # Convert the string representation of lists into actual lists for all relevant columns
43
+ for col in df1.columns: # Skip the first column which is 'Categories'
44
+ df1[col] = df1[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
45
+
46
+
47
+ cols = df.columns
48
+
49
+ # Get the specific column while filtering out empty cells
50
+ column_data = df[cols[0]]
51
+
52
+ # Filter out the empty lists ([])
53
+ filtered_column_data = column_data[column_data.apply(lambda x: x != [])]
54
+
55
+
56
+
57
+ def get_score(avg_kl_div,kl_div,missing,extra,common):
58
+ Wc=1
59
+ Wm=1.5
60
+ We=1.5
61
+ WeE=(We*extra)**2
62
+ WeM=(Wm*missing)**2
63
+ WeC=(We*common)**2
64
+ if kl_div==-1:
65
+ kl_div=avg_kl_div
66
+ kl_div_factor=kl_div/avg_kl_div
67
+ ans=kl_div_factor*(((WeE+WeM)/WeC)-2)# (e**2 -c**2)/c**2 +(m**2-c**2)/c**2 => (0-1)*[((e**2+m**2)/c**2 -2)] => ((rank*y/a)m(m+1)/2))
68
+ return ans
69
+ def get_individual_score(avg_kl_div,kl_div,e_or_m,common):
70
+ if kl_div==-1:
71
+ kl_div=avg_kl_div
72
+ kl_div_factor=kl_div/avg_kl_div
73
+ weight=1.5
74
+ ans=avg_kl_div + ((1+(e_or_m/common))*(((e_or_m)*(e_or_m+1)))/2)**0.5 # X +- [(1+b/a)*n**2*y]
75
+ # ans = kl_div_factor*((((weight*e_or_m)**2)/(common**2))-1)
76
+ return ans
77
+
78
+
79
+ def get_entity_scores(ans4):
80
+ # Calculate average KL divergence
81
+ tt = 0
82
+ avg_kl_div = 0
83
+ for t in ans4:
84
+ if t[0] != -1:
85
+ avg_kl_div += t[0]
86
+ tt += 1
87
+
88
+ # Avoid division by zero
89
+ if tt > 0:
90
+ avg_kl_div /= tt
91
+ else:
92
+ avg_kl_div = 0
93
+
94
+ extra_entity_score = []
95
+ missing_entity_score = []
96
+
97
+ for t in ans4:
98
+ extra_entity_score.append(get_individual_score(avg_kl_div, t[0], t[2], t[3]))
99
+ missing_entity_score.append(get_individual_score(avg_kl_div, t[0], t[1], t[3]))
100
+
101
+ extra_entity_score.sort()
102
+ missing_entity_score.sort()
103
+
104
+ return (
105
+ missing_entity_score[:int(0.950 * len(missing_entity_score))],
106
+ extra_entity_score[:int(0.95 * len(extra_entity_score))]
107
+ )
108
+
109
+
110
+ compare = df.columns[0]
111
+ column_data = df[compare]
112
+
113
+ # Filter out the empty lists ([])
114
+ filtered_column_data = column_data[column_data.apply(lambda x: x != [])]
115
+
116
+ # Display the filtered column data
117
+ variables = filtered_column_data.to_list()
118
+ models = filtered_column_data.index.to_list()
119
+
120
+ color_schemes = [
121
+ '#d60000', # Red
122
+ '#2f5282', # Navy Blue
123
+ '#f15cd8', # Pink
124
+ '#66abb7', # Light Teal
125
+ '#ce7391', # Rose
126
+ '#6bdb7a', # Light Green
127
+ '#ea8569', # Coral
128
+ '#b36cc9', # Lavender
129
+ '#ffd700', # Gold
130
+ '#ff7f0e', # Orange
131
+ '#1f77b4', # Blue
132
+ '#2ca02c', # Green
133
+ ]
134
+
135
+
136
+ colors = color_schemes[:len(models)]
137
+
138
+ values_dict = {model: var for var, model in zip(variables, models)}
139
+ color_dict = {model: color for model, color in zip(models, colors)}
140
+
141
+
142
+ # plot_grouped_3d_kde(values_dict, models, color_dict, compare)
143
+
144
+
145
+ import numpy as np
146
+ import plotly.graph_objects as go
147
+ from scipy.stats import gaussian_kde
148
+ import plotly.express as px
149
+
150
+
151
+
152
+ def adjust_kde_range(data, increment=25, threshold=0.00005):
153
+ kde = gaussian_kde(data)
154
+ min_x, max_x = min(data) - increment, max(data) + increment
155
+
156
+ # Keep expanding the range until both tails get close to zero
157
+ while True:
158
+ x_values = np.linspace(min_x, max_x, 1000)
159
+ y_values = kde(x_values)
160
+
161
+ # # Check the values at the tails
162
+ # print(y_values[0], y_values[-1])
163
+ # print(x_values[0], x_values[-1], "\n")
164
+
165
+ if y_values[0] < threshold and y_values[-1] < threshold:
166
+ break # Stop if both tails are below the threshold
167
+
168
+ # Extend the range
169
+ min_x -= increment
170
+ max_x += increment
171
+
172
+ return x_values, y_values
173
+
174
+
175
+ def compute_kde_ranges(missing_scores, extra_scores):
176
+ data1 = np.array(missing_scores)
177
+ data2 = -np.array(extra_scores) # Negate extra scores for alignment
178
+
179
+ # Compute KDE for missing scores with extended range
180
+ x_missing, y_missing = adjust_kde_range(data1)
181
+
182
+ # Compute KDE for extra scores with extended range
183
+ x_extra, y_extra = adjust_kde_range(data2)
184
+
185
+ # Calculate axis limits
186
+ Val_x_extra = [max(x_extra)]
187
+ Val_x_miss = [x_missing[np.argmax(y_missing)]]
188
+
189
+ peak_extra = max(y_extra)
190
+ peak_miss = max(y_missing)
191
+
192
+ # Calculate the x and y axis ranges
193
+ min_x = min(min(x_missing), min(x_extra))
194
+ max_x = max(max(x_missing), max(x_extra))
195
+ x_range = [min_x, max_x]
196
+
197
+ y_range = [-peak_extra, peak_miss * 1.25]
198
+
199
+ return x_missing, y_missing, x_extra, y_extra, x_range, y_range
200
+
201
+
202
+ def calculate_ticks(x_min, x_max, num_ticks=20):
203
+ # Calculate the total range
204
+ total_range = x_max - x_min
205
+
206
+ # Determine the interval between ticks
207
+ interval = total_range / (num_ticks - 1) # We need num_ticks - 1 intervals
208
+
209
+ # Generate tick values
210
+ ticks = np.arange(x_min, x_max + interval, interval)
211
+
212
+ return ticks
213
+
214
+
215
+
216
+
217
+ def plot_filled_surface(x, z, y_level, color):
218
+ """
219
+ Create a 3D mesh to fill the surface between the KDE curve and the 0-axis.
220
+ """
221
+ x_full = np.concatenate([x, x[::-1]]) # X-axis values, with reverse for baseline
222
+ z_full = np.concatenate([z, np.zeros_like(z)]) # Z-axis (KDE and baseline at 0)
223
+ y_full = np.full_like(x_full, y_level) # Flat Y plane (constant for each model)
224
+
225
+ num_pts = len(x)
226
+ i = np.arange(num_pts - 1)
227
+ j = i + 1
228
+ k = i + num_pts
229
+
230
+ i = np.concatenate([i, i + num_pts])
231
+ j = np.concatenate([j, j + num_pts])
232
+ k = np.concatenate([k, i[:len(i)//2]])
233
+
234
+ return go.Mesh3d(
235
+ x=x_full, y=y_full, z=z_full,
236
+ i=i, j=j, k=k,
237
+ opacity=0.5,
238
+ color=color,
239
+ showscale=False,
240
+ legendgroup='filling'
241
+ )
242
+
243
+
244
+
245
+ def plot_kde_3d(values_dict, models, color_dict, compare):
246
+
247
+ # values_dict, models, color_dict, compare = (values_dict, models, color_dict, 'Comparison Title')
248
+ fig = go.Figure()
249
+
250
+ model_y_positions = {model: i for i, model in enumerate(models)}
251
+
252
+ x_ranges = []
253
+ y_ranges = []
254
+
255
+ for model in models:
256
+ missing_scores, extra_scores = get_entity_scores(values_dict[model])
257
+
258
+ # Compute KDE and ranges for missing and extra scores
259
+ x_m, y_m, x_e, y_e, x_range, y_range = compute_kde_ranges(missing_scores, extra_scores)
260
+
261
+ # Append ranges for global limits
262
+ x_ranges.append(x_range)
263
+ y_ranges.append(y_range)
264
+
265
+ # Get color for this model
266
+ color = color_dict.get(model, 'rgba(0, 0, 0, 0.5)') # Default color if not found
267
+
268
+ # Create filled surfaces between KDE curves and zero line
269
+ fig.add_trace(plot_filled_surface(x_m, y_m, model_y_positions[model], color))
270
+ fig.add_trace(plot_filled_surface(x_e, -y_e, model_y_positions[model], color))
271
+
272
+ # Plot the KDE lines (for visualization of the curves)
273
+ fig.add_trace(go.Scatter3d(
274
+ x=x_m,
275
+ y=[model_y_positions[model]] * len(x_m),
276
+ z=y_m,
277
+ mode='lines',
278
+ line=dict(color='blue'),
279
+ showlegend=False
280
+ ))
281
+
282
+ fig.add_trace(go.Scatter3d(
283
+ x=x_e,
284
+ y=[model_y_positions[model]] * len(x_e),
285
+ z=-y_e,
286
+ mode='lines',
287
+ line=dict(color='red'),
288
+ showlegend=False # Hide legend for extra scores to combine with missing scores
289
+ ))
290
+
291
+ # Compute global x and y limits
292
+ x_min = min(r[0] for r in x_ranges)
293
+ x_max = max(r[1] for r in x_ranges)
294
+ y_min = min(r[0] for r in y_ranges)
295
+ y_max = max(r[1] for r in y_ranges)
296
+
297
+ # Define x, y, z axis tick intervals
298
+ x_ticks = calculate_ticks(np.floor(x_min), np.ceil(x_max))
299
+ y_ticks = list(model_y_positions.values())
300
+ z_ticks = calculate_ticks(y_min, y_max)
301
+
302
+ # Add a line through the 0-axis of density for each model
303
+ for model in models:
304
+ color = color_dict.get(model, 'rgba(0, 0, 0, 0.5)')
305
+ fig.add_trace(go.Scatter3d(
306
+ x=[x_min, x_max],
307
+ y=[model_y_positions[model], model_y_positions[model]],
308
+ z=[0, 0],
309
+ mode='lines',
310
+ # line=dict(color=color, width=2, dash='dash'),
311
+ line=dict(color=color),
312
+ name=model,
313
+
314
+ # showlegend=False
315
+ ))
316
+
317
+ # Update layout for 3D plot
318
+ fig.update_layout(
319
+ title=f'3D KDE Plots for {compare}',
320
+ scene=dict(
321
+ xaxis_title='Score',
322
+ yaxis_title='Model',
323
+ zaxis_title='Density',
324
+ xaxis=dict(
325
+ range=[x_min, x_max],
326
+ tickvals=x_ticks,
327
+ ticktext=[f'{tick:.2f}' for tick in x_ticks]
328
+ ),
329
+ yaxis=dict(
330
+ tickvals=y_ticks,
331
+ ticktext=[list(model_y_positions.keys())[list(model_y_positions.values()).index(tick)] for tick in y_ticks]
332
+ ),
333
+ zaxis=dict(
334
+ range=[y_min, y_max],
335
+ tickvals=z_ticks,
336
+ ticktext=[f'{tick:.4f}' for tick in z_ticks]
337
+ ),
338
+ camera=dict(
339
+ eye=dict(x=1.25, y=1.25, z=1.25)
340
+ )
341
+ ),
342
+ autosize=True,
343
+ width=1200*.75,
344
+ height=800*.75
345
+ )
346
+
347
+ # Save the plot as an HTML file
348
+ # plot = px.scatter(x=range(10), y=range(10))
349
+ filename = f"{compare}.html"
350
+ fig.write_html(filename)
351
+
352
+ # fig.show()
353
+
354
+ return fig
355
+
356
+
357
+
358
+ # Path to your saved HTML file
359
+ html_file_path = '3d_plot.html'
360
+ title = 'My 3D Plot'
361
+
362
+ def display_plot():
363
+ fig = plot_kde_3d(values_dict, models, color_dict, compare)
364
+ return fig
365
+
366
+
367
+ # Define the Gradio interface
368
+ interface = gr.Interface(
369
+ fn=display_plot,
370
+ inputs=[],
371
+ outputs=gr.Plot(),
372
+ title='Plotly 3D Plot in Gradio',
373
+ description='This app displays a 3D Plotly plot directly in the Gradio interface.',
374
+ live=False
375
+ )
376
+
377
+ # Launch the Gradio app
378
+ if __name__ == "__main__":
379
+ interface.launch()
380
+
381
+