Spaces:
Sleeping
Sleeping
Rajarshi Roy
committed on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
def convert_google_sheet_url(url):
    """Turn a Google Sheets browser URL into its CSV export URL.

    The spreadsheet id is taken from the ``/spreadsheets/d/<id>`` segment.
    If the part after the id carries a ``gid=<digits>`` parameter, either in
    the query string (``/edit?gid=0``) or in the fragment (``/edit#gid=0``),
    that tab id is forwarded to the export URL; otherwise it is omitted.

    Args:
        url: Text containing a ``https://docs.google.com/spreadsheets/d/...``
            link.

    Returns:
        ``url`` with the Sheets link replaced by
        ``https://docs.google.com/spreadsheets/d/<id>/export?[gid=<gid>&]format=csv``.
        Text that contains no Sheets link is returned unchanged.
    """
    # Local import: the module does not import `re` at file level.
    import re

    match = re.search(
        r'https://docs\.google\.com/spreadsheets/d/([a-zA-Z0-9_-]+)(/edit.*)?',
        url,
    )
    if match is None:
        return url

    sheet_id = match.group(1)
    tail = match.group(2) or ''

    # Newer share links look like "/edit?gid=123#gid=123"; the previous
    # pattern only recognised "/edit#gid=123" and silently dropped the tab id
    # for the newer form, exporting the wrong tab.
    gid_match = re.search(r'[?&#]gid=(\d+)', tail)
    gid_part = f'gid={gid_match.group(1)}&' if gid_match else ''

    export_url = (
        f'https://docs.google.com/spreadsheets/d/{sheet_id}/export?'
        f'{gid_part}format=csv'
    )
    return url[:match.start()] + export_url + url[match.end():]
|
14 |
+
|
15 |
+
# Imported here because the module has no file-level imports for these
# names, yet the statements below need them at import time.
import ast

import pandas as pd

# Replace with your modified URL
# url = "https://docs.google.com/spreadsheets/d/1dlTjKJrGVwRDU8m-hT53IdSluRAsWXftnx5uRqnq4yE/edit?gid=0#gid=0"
url = "https://docs.google.com/spreadsheets/d/1MY0-DOitMZGnib73BAaSKg0TI7i5V1CXP8dF6jAgKWc/edit?gid=293606167#gid=293606167"

new_url = convert_google_sheet_url(url)

# Downloads the sheet as CSV every time the module is imported.
df = pd.read_csv(new_url)

# df1 keeps the sheet's original orientation, indexed by 'Categories'.
df1 = df.copy()
df1.set_index('Categories', inplace=True)

# df becomes the transposed sheet: the first transposed row supplies the
# column labels and the row labelled "Categories" is then dropped.
# NOTE(review): assumes 'Categories' is the first CSV column, so that
# iloc[0] is that row -- confirm against the sheet layout.
transposed_df = df.transpose()
transposed_df.columns = transposed_df.iloc[0]
df = transposed_df.drop(["Categories"])

# Empty cells become the string "[]" so they parse to empty lists below.
df = df.fillna("[]")
df1 = df1.fillna("[]")

# Parse every string cell as a Python literal. This covers every column
# of the frame; nothing is skipped. Non-string cells are left untouched.
for col in df.columns:
    df[col] = df[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

for col in df1.columns:
    df1[col] = df1[col].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)

cols = df.columns

# First column of the transposed frame, keeping only non-empty cells.
column_data = df[cols[0]]
filtered_column_data = column_data[column_data.apply(lambda x: x != [])]
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
def get_score(avg_kl_div, kl_div, missing, extra, common):
    """Combine KL divergence with missing/extra/common counts into one score.

    Computes ``factor * ((1.5*extra)**2 + (1.5*missing)**2) / (1.5*common)**2 - 2)``
    where ``factor`` is ``kl_div / avg_kl_div``.

    Args:
        avg_kl_div: Average KL divergence used to normalise ``kl_div``.
        kl_div: KL divergence for this item; ``-1`` is a sentinel that is
            replaced by ``avg_kl_div`` (i.e. the factor becomes 1).
        missing: Missing-entity count.
        extra: Extra-entity count.
        common: Common-entity count (the denominator).

    Returns:
        The score as a float.

    Raises:
        ZeroDivisionError: If ``common`` is 0, or if ``kl_div`` is not the
            sentinel and ``avg_kl_div`` is 0.
    """
    weight_missing = 1.5
    weight_extra = 1.5

    weighted_extra = (weight_extra * extra) ** 2
    weighted_missing = (weight_missing * missing) ** 2
    # NOTE(review): the denominator is scaled by the *extra* weight, so the
    # weights cancel to (e**2 + m**2) / c**2. The original also defined an
    # unused common weight (Wc = 1); confirm which scaling is intended.
    weighted_common = (weight_extra * common) ** 2

    if kl_div == -1:
        # Sentinel means "use the average", so the ratio is 1 by definition.
        # Computing avg/avg here used to crash when the average was 0.
        kl_div_factor = 1.0
    else:
        kl_div_factor = kl_div / avg_kl_div

    return kl_div_factor * (((weighted_extra + weighted_missing) / weighted_common) - 2)
|
69 |
+
def get_individual_score(avg_kl_div, kl_div, e_or_m, common):
    """Score one item from its extra-or-missing count and its common count.

    Computes
    ``avg_kl_div + sqrt((1 + e_or_m/common) * e_or_m * (e_or_m + 1) / 2)``.

    Args:
        avg_kl_div: Average KL divergence; used as the additive base.
        kl_div: Accepted for interface compatibility; it does not influence
            the result.
        e_or_m: Count of extra (or missing) entities.
        common: Count of common entities.

    Returns:
        The score as a float.

    Raises:
        ZeroDivisionError: If ``common`` is 0.
    """
    # The previous version also computed kl_div / avg_kl_div (and a weight)
    # without using them. That dead division raised ZeroDivisionError whenever
    # avg_kl_div was 0, so it has been removed.
    penalty = ((1 + (e_or_m / common)) * (e_or_m * (e_or_m + 1)) / 2) ** 0.5
    return avg_kl_div + penalty
|
77 |
+
|
78 |
+
|
79 |
+
def get_entity_scores(ans4):
    """Build trimmed, sorted missing/extra score lists for one model.

    Each entry of ``ans4`` is indexed positionally: ``entry[0]`` is the KL
    divergence (``-1`` entries are excluded from the average), ``entry[1]``
    feeds the missing score, ``entry[2]`` feeds the extra score and
    ``entry[3]`` is passed as the common count.

    Returns:
        ``(missing_scores, extra_scores)``; each list is sorted ascending and
        cut to its first ``int(0.95 * len)`` items, which drops the largest
        values.
    """
    # Average KL divergence over the entries that carry a real value.
    valid_kl = [entry[0] for entry in ans4 if entry[0] != -1]
    mean_kl = sum(valid_kl) / len(valid_kl) if valid_kl else 0

    extra_scores = sorted(
        get_individual_score(mean_kl, entry[0], entry[2], entry[3])
        for entry in ans4
    )
    missing_scores = sorted(
        get_individual_score(mean_kl, entry[0], entry[1], entry[3])
        for entry in ans4
    )

    keep_missing = int(0.950 * len(missing_scores))
    keep_extra = int(0.95 * len(extra_scores))
    return missing_scores[:keep_missing], extra_scores[:keep_extra]
|
108 |
+
|
109 |
+
|
110 |
+
# Name of the first column of the transposed sheet; it is used as the plot
# title and as the stem of the output file name.
compare = df.columns[0]
column_data = df[compare]


def _is_non_empty(cell):
    """Return whether a parsed cell differs from the empty list."""
    return cell != []


# Keep only the rows whose cell is not an empty list.
filtered_column_data = column_data[column_data.apply(_is_non_empty)]

# Cell values and their row labels, in matching order.
variables = filtered_column_data.to_list()
models = filtered_column_data.index.to_list()

color_schemes = [
    '#d60000',  # Red
    '#2f5282',  # Navy Blue
    '#f15cd8',  # Pink
    '#66abb7',  # Light Teal
    '#ce7391',  # Rose
    '#6bdb7a',  # Light Green
    '#ea8569',  # Coral
    '#b36cc9',  # Lavender
    '#ffd700',  # Gold
    '#ff7f0e',  # Orange
    '#1f77b4',  # Blue
    '#2ca02c',  # Green
]

# One colour per model, in palette order; models beyond the palette length
# receive no entry in color_dict.
colors = color_schemes[:len(models)]

values_dict = dict(zip(models, variables))
color_dict = dict(zip(models, colors))
|
140 |
+
|
141 |
+
|
142 |
+
# plot_grouped_3d_kde(values_dict, models, color_dict, compare)
|
143 |
+
|
144 |
+
|
145 |
+
import numpy as np
|
146 |
+
import plotly.graph_objects as go
|
147 |
+
from scipy.stats import gaussian_kde
|
148 |
+
import plotly.express as px
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
def adjust_kde_range(data, increment=25, threshold=0.00005):
    """Evaluate a Gaussian KDE of ``data`` on a grid wide enough for its tails.

    The grid starts at ``[min(data) - increment, max(data) + increment]`` and
    is widened by ``increment`` on both sides until the density at both grid
    ends is below ``threshold``.

    Args:
        data: Sample values passed to ``scipy.stats.gaussian_kde``.
        increment: Padding added to each side initially and on each widening.
        threshold: Density value both grid ends must fall below.

    Returns:
        ``(x_values, y_values)``: 1000 evenly spaced grid points and the KDE
        evaluated at them.
    """
    density = gaussian_kde(data)
    lower = min(data) - increment
    upper = max(data) + increment

    grid = np.linspace(lower, upper, 1000)
    heights = density(grid)

    # Widen symmetrically until both tails are (numerically) flat.
    while not (heights[0] < threshold and heights[-1] < threshold):
        lower -= increment
        upper += increment
        grid = np.linspace(lower, upper, 1000)
        heights = density(grid)

    return grid, heights
|
173 |
+
|
174 |
+
|
175 |
+
def compute_kde_ranges(missing_scores, extra_scores):
    """Compute KDE curves for missing and (negated) extra scores plus axis ranges.

    Args:
        missing_scores: Sequence of missing-entity scores.
        extra_scores: Sequence of extra-entity scores; they are negated before
            the KDE is computed so the extra curve sits on the opposite side
            of the x axis from the missing curve.

    Returns:
        ``(x_missing, y_missing, x_extra, y_extra, x_range, y_range)`` where
        ``x_range`` is ``[min, max]`` over both x grids and ``y_range`` is
        ``[-max(y_extra), 1.25 * max(y_missing)]``.
    """
    missing = np.array(missing_scores)
    mirrored_extra = -np.array(extra_scores)

    # Each KDE is evaluated on a grid that is widened until its tails vanish.
    x_missing, y_missing = adjust_kde_range(missing)
    x_extra, y_extra = adjust_kde_range(mirrored_extra)

    # Unused intermediate values from the earlier version were dropped; only
    # the combined x extent and the two density peaks are needed.
    x_range = [
        min(np.min(x_missing), np.min(x_extra)),
        max(np.max(x_missing), np.max(x_extra)),
    ]

    # The lower bound is the negated extra peak; the upper bound leaves 25%
    # headroom above the missing peak.
    y_range = [-np.max(y_extra), np.max(y_missing) * 1.25]

    return x_missing, y_missing, x_extra, y_extra, x_range, y_range
|
200 |
+
|
201 |
+
|
202 |
+
def calculate_ticks(x_min, x_max, num_ticks=20):
    """Return ``num_ticks`` evenly spaced tick values from ``x_min`` to ``x_max``.

    Both endpoints are included, so the spacing is
    ``(x_max - x_min) / (num_ticks - 1)``.

    Args:
        x_min: First tick value.
        x_max: Last tick value.
        num_ticks: Number of ticks to generate.

    Returns:
        A 1-D NumPy float array of length ``num_ticks``.
    """
    # np.linspace is used instead of np.arange(x_min, x_max + step, step):
    # with a fractional step, arange's length depends on rounding and could
    # yield an extra tick beyond x_max. arange also failed for
    # x_min == x_max (zero step) and num_ticks == 1 (division by zero).
    return np.linspace(x_min, x_max, num_ticks)
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
def plot_filled_surface(x, z, y_level, color):
    """
    Create a 3D mesh to fill the surface between the KDE curve and the 0-axis.

    Args:
        x: X coordinates of the curve.
        z: Curve heights at ``x`` (same length as ``x``).
        y_level: Constant Y coordinate of the plane the fill lies in.
        color: Mesh colour passed to plotly.

    Returns:
        A ``go.Mesh3d`` trace with opacity 0.5 in legend group ``'filling'``.
    """
    x = np.asarray(x)
    z = np.asarray(z)
    num_pts = len(x)

    # Vertices 0..n-1 lie on the curve; vertices n..2n-1 lie on the baseline
    # directly beneath them. The baseline previously used reversed x values
    # while the triangle indices assumed unreversed order, which connected
    # each curve segment to the mirrored end of the baseline.
    x_full = np.concatenate([x, x])
    z_full = np.concatenate([z, np.zeros_like(z)])
    y_full = np.full_like(x_full, y_level)  # Flat Y plane (constant for each model)

    curve_left = np.arange(num_pts - 1)
    curve_right = curve_left + 1
    base_left = curve_left + num_pts
    base_right = curve_right + num_pts

    # Each strip between neighbouring samples is a quad split into two
    # triangles that share the curve_right--base_left diagonal:
    #   (curve_left, curve_right, base_left) and
    #   (base_left, base_right, curve_right).
    i = np.concatenate([curve_left, base_left])
    j = np.concatenate([curve_right, base_right])
    k = np.concatenate([base_left, curve_right])

    return go.Mesh3d(
        x=x_full, y=y_full, z=z_full,
        i=i, j=j, k=k,
        opacity=0.5,
        color=color,
        showscale=False,
        legendgroup='filling'
    )
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
def plot_kde_3d(values_dict, models, color_dict, compare):
    """Build a 3D figure with one missing/extra KDE pair per model.

    For every model the missing-score KDE is drawn above the zero line and
    the extra-score KDE below it, each in its own Y plane, with a translucent
    fill and a coloured zero line that carries the model name in the legend.

    Args:
        values_dict: Maps each model name to the entries handed to
            ``get_entity_scores``.
        models: Model names; their order fixes each model's Y position.
        color_dict: Maps model names to colours; models without an entry use
            ``'rgba(0, 0, 0, 0.5)'``.
        compare: Label used in the figure title and as the output file stem.

    Returns:
        The ``go.Figure``. As a side effect the figure is also written to
        ``f"{compare}.html"`` in the current working directory.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if not models:
        # Without this the failure surfaced later as an opaque
        # "min() arg is an empty sequence" ValueError.
        raise ValueError("plot_kde_3d requires at least one model")

    fig = go.Figure()

    model_y_positions = {model: i for i, model in enumerate(models)}
    default_color = 'rgba(0, 0, 0, 0.5)'

    x_ranges = []
    y_ranges = []

    for model in models:
        missing_scores, extra_scores = get_entity_scores(values_dict[model])

        # Compute KDE and ranges for missing and extra scores
        x_m, y_m, x_e, y_e, x_range, y_range = compute_kde_ranges(missing_scores, extra_scores)

        # Collected so the axes can span every model.
        x_ranges.append(x_range)
        y_ranges.append(y_range)

        color = color_dict.get(model, default_color)
        y_pos = model_y_positions[model]

        # Fills between each KDE curve and the zero line; the extra curve is
        # mirrored below zero.
        fig.add_trace(plot_filled_surface(x_m, y_m, y_pos, color))
        fig.add_trace(plot_filled_surface(x_e, -y_e, y_pos, color))

        # Curve outlines: missing in blue, extra in red.
        fig.add_trace(go.Scatter3d(
            x=x_m,
            y=[y_pos] * len(x_m),
            z=y_m,
            mode='lines',
            line=dict(color='blue'),
            showlegend=False
        ))

        fig.add_trace(go.Scatter3d(
            x=x_e,
            y=[y_pos] * len(x_e),
            z=-y_e,
            mode='lines',
            line=dict(color='red'),
            showlegend=False
        ))

    # Global axis limits across all models
    x_min = min(r[0] for r in x_ranges)
    x_max = max(r[1] for r in x_ranges)
    y_min = min(r[0] for r in y_ranges)
    y_max = max(r[1] for r in y_ranges)

    # Tick positions; Y ticks are the model indices labelled with model names.
    x_ticks = calculate_ticks(np.floor(x_min), np.ceil(x_max))
    y_ticks = list(model_y_positions.values())
    # Dict keys and values iterate in the same order, so the keys are already
    # the label for each tick (this replaces a quadratic reverse lookup).
    y_tick_labels = list(model_y_positions.keys())
    z_ticks = calculate_ticks(y_min, y_max)

    # Zero-density line per model; this is the trace that shows in the legend.
    for model in models:
        color = color_dict.get(model, default_color)
        y_pos = model_y_positions[model]
        fig.add_trace(go.Scatter3d(
            x=[x_min, x_max],
            y=[y_pos, y_pos],
            z=[0, 0],
            mode='lines',
            line=dict(color=color),
            name=model,
        ))

    # Update layout for 3D plot
    fig.update_layout(
        title=f'3D KDE Plots for {compare}',
        scene=dict(
            xaxis_title='Score',
            yaxis_title='Model',
            zaxis_title='Density',
            xaxis=dict(
                range=[x_min, x_max],
                tickvals=x_ticks,
                ticktext=[f'{tick:.2f}' for tick in x_ticks]
            ),
            yaxis=dict(
                tickvals=y_ticks,
                ticktext=y_tick_labels
            ),
            zaxis=dict(
                range=[y_min, y_max],
                tickvals=z_ticks,
                ticktext=[f'{tick:.4f}' for tick in z_ticks]
            ),
            camera=dict(
                eye=dict(x=1.25, y=1.25, z=1.25)
            )
        ),
        autosize=True,
        width=1200*.75,
        height=800*.75
    )

    # Save the plot as an HTML file.
    # NOTE(review): `compare` is used verbatim as the file stem; a label
    # containing path separators would change where the file is written.
    filename = f"{compare}.html"
    fig.write_html(filename)

    return fig
|
355 |
+
|
356 |
+
|
357 |
+
|
358 |
+
# Path and title for a saved HTML plot.
# NOTE(review): neither name is referenced anywhere else in the visible file;
# plot_kde_3d writes to f"{compare}.html" rather than to html_file_path.
# Confirm whether these are still needed.
html_file_path = '3d_plot.html'
title = 'My 3D Plot'
|
361 |
+
|
362 |
+
def display_plot():
    """Build and return the 3D KDE figure from the module-level data.

    Takes no arguments; it forwards the module-level ``values_dict``,
    ``models``, ``color_dict`` and ``compare`` to ``plot_kde_3d``.
    """
    return plot_kde_3d(values_dict, models, color_dict, compare)
|
365 |
+
|
366 |
+
|
367 |
+
# Define the Gradio interface
|
368 |
+
# NOTE(review): `gr` is not imported anywhere in the visible file, so this
# statement raises NameError as written; presumably `import gradio as gr` is
# intended -- confirm and add it at the top of the file.
# The interface takes no inputs and renders the figure returned by
# display_plot in a gr.Plot output.
interface = gr.Interface(
    fn=display_plot,
    inputs=[],
    outputs=gr.Plot(),
    title='Plotly 3D Plot in Gradio',
    description='This app displays a 3D Plotly plot directly in the Gradio interface.',
    live=False
)

# Launch the Gradio app only when the file is run as a script.
if __name__ == "__main__":
    interface.launch()
|
380 |
+
|
381 |
+
|