andrewrreed HF staff commited on
Commit
6bd3956
·
1 Parent(s): 32bb93d

add annotation for # models and days till crossover

Browse files
Files changed (2) hide show
  1. app.py +25 -1
  2. utils.py +31 -0
app.py CHANGED
@@ -15,6 +15,8 @@ from utils import (
15
  get_constants,
16
  update_release_date_mapping,
17
  format_data,
 
 
18
  )
19
 
20
  ###################
@@ -145,6 +147,7 @@ def filter_df(min_score, max_models_per_month, set_selector, org_selector):
145
  .apply(lambda x: x.nlargest(max_models_per_month, "rating"))
146
  .reset_index(drop=True)
147
  )
 
148
  return filtered_df
149
 
150
 
@@ -175,6 +178,27 @@ def build_plot(toggle_annotations, filtered_df):
175
 
176
  fig.update_traces(marker=dict(size=10, opacity=0.6))
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  if toggle_annotations:
179
  # get the points to annotate (only the highest rated model per month per license)
180
  idx_to_annotate = filtered_df.groupby(["Month-Year", "License"])[
@@ -285,7 +309,7 @@ with gr.Blocks(
285
  filtered_df = gr.State()
286
  with gr.Group():
287
  with gr.Tab("Plot"):
288
- plot = gr.Plot()
289
  with gr.Tab("Raw Data"):
290
 
291
  display_df = gr.DataFrame()
 
15
  get_constants,
16
  update_release_date_mapping,
17
  format_data,
18
+ get_trendlines,
19
+ find_crossover_point,
20
  )
21
 
22
  ###################
 
147
  .apply(lambda x: x.nlargest(max_models_per_month, "rating"))
148
  .reset_index(drop=True)
149
  )
150
+
151
  return filtered_df
152
 
153
 
 
178
 
179
  fig.update_traces(marker=dict(size=10, opacity=0.6))
180
 
181
+ # calculate days until crossover
182
+ trend1, trend2 = get_trendlines(fig)
183
+ crossover = find_crossover_point(
184
+ b1=trend1[0], m1=trend1[1], b2=trend2[0], m2=trend2[1]
185
+ )
186
+ days_til_crossover = (
187
+ pd.to_datetime(crossover, unit="s") - pd.Timestamp.today()
188
+ ).days
189
+
190
+ # add annotation with number of models and days til crossover
191
+ fig.add_annotation(
192
+ xref="paper",
193
+ yref="paper", # use paper coordinates
194
+ x=-0.05,
195
+ y=1.13,
196
+ text=f"Number of models: {len(filtered_df)}<br>Days til crossover: {days_til_crossover}",
197
+ showarrow=False,
198
+ font=dict(size=14, color="white"),
199
+ bgcolor="rgba(0,0,0,0.5)",
200
+ )
201
+
202
  if toggle_annotations:
203
  # get the points to annotate (only the highest rated model per month per license)
204
  idx_to_annotate = filtered_df.groupby(["Month-Year", "License"])[
 
309
  filtered_df = gr.State()
310
  with gr.Group():
311
  with gr.Tab("Plot"):
312
+ plot = gr.Plot(show_label=False)
313
  with gr.Tab("Raw Data"):
314
 
315
  display_df = gr.DataFrame()
utils.py CHANGED
@@ -4,6 +4,7 @@ from datetime import datetime
4
  from typing import Literal, List
5
 
6
  import pandas as pd
 
7
  from huggingface_hub import HfFileSystem, hf_hub_download
8
 
9
  # from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/monitor/monitor.py#L389
@@ -174,3 +175,33 @@ def format_data(df):
174
  df["Month-Year"] = df["Release Date"].dt.to_period("M")
175
  df["rating"] = df["rating"].round()
176
  return df.reset_index(drop=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from typing import Literal, List
5
 
6
  import pandas as pd
7
+ import plotly.express as px
8
  from huggingface_hub import HfFileSystem, hf_hub_download
9
 
10
  # from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/monitor/monitor.py#L389
 
175
  df["Month-Year"] = df["Release Date"].dt.to_period("M")
176
  df["rating"] = df["rating"].round()
177
  return df.reset_index(drop=True)
178
+
179
+
180
+ def get_trendlines(fig):
181
+
182
+ trend_lines = px.get_trendline_results(fig)
183
+
184
+ return [
185
+ trend_lines.iloc[i]["px_fit_results"].params.tolist()
186
+ for i in range(len(trend_lines))
187
+ ]
188
+
189
+
190
+ def find_crossover_point(b1, m1, b2, m2):
191
+ """
192
+ Determine the X value at which two trendlines will cross over.
193
+
194
+ Parameters:
195
+ m1 (float): Slope of the first trendline.
196
+ b1 (float): Intercept of the first trendline.
197
+ m2 (float): Slope of the second trendline.
198
+ b2 (float): Intercept of the second trendline.
199
+
200
+ Returns:
201
+ float: The X value where the two trendlines cross.
202
+ """
203
+ if m1 == m2:
204
+ raise ValueError("The trendlines are parallel and do not cross.")
205
+
206
+ x_crossover = (b2 - b1) / (m1 - m2)
207
+ return x_crossover