Benjamin Bossan commited on
Commit
2174ccd
·
1 Parent(s): d3818ef

Add special fields for skops template

Browse files

- It's now possible to enter metrics in text area
- It's now possible to toggle model diagram

Files changed (3) hide show
  1. edit.py +96 -21
  2. start.py +3 -0
  3. tasks.py +21 -0
edit.py CHANGED
@@ -37,6 +37,7 @@ from skops.card._model_card import PlotSection, split_subsection_names
37
 
38
  from utils import iterate_key_section_content, process_card_for_rendering
39
  from tasks import (
 
40
  AddSectionTask,
41
  AddFigureTask,
42
  DeleteSectionTask,
@@ -279,33 +280,107 @@ def add_download_model_card_button():
279
  )
280
 
281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  def edit_input_form():
283
  if "task_state" not in st.session_state:
284
  st.session_state.task_state = TaskState()
285
 
286
  with st.sidebar:
287
- col_0, col_1, col_2, *_ = st.columns([2, 2, 2, 2])
288
- undo_disabled = not bool(st.session_state.task_state.done_list)
289
- redo_disabled = not bool(st.session_state.task_state.undone_list)
290
- with col_0:
291
- name = f"UNDO ({len(st.session_state.task_state.done_list)})"
292
- tip = "Undo the last edit"
293
- st.button(name, on_click=undo_last, disabled=undo_disabled, help=tip)
294
- with col_1:
295
- name = f"REDO ({len(st.session_state.task_state.undone_list)})"
296
- tip = "Redo the last undone edit"
297
- st.button(name, on_click=redo_last, disabled=redo_disabled, help=tip)
298
- with col_2:
299
- tip = "Undo all edits"
300
- st.button("Reset", on_click=reset_model_card, help=tip)
301
-
302
- col_0, col_1, *_ = st.columns([2, 2, 2, 2])
303
- with col_0:
304
- add_download_model_card_button()
305
- with col_1:
306
- tip = "Start over from scratch (lose all progress)"
307
- st.button("Delete", on_click=delete_model_card, help=tip)
308
 
 
309
  if "model_card" in st.session_state:
310
  display_sections(st.session_state.model_card)
311
 
 
37
 
38
  from utils import iterate_key_section_content, process_card_for_rendering
39
  from tasks import (
40
+ AddMetricsTask,
41
  AddSectionTask,
42
  AddFigureTask,
43
  DeleteSectionTask,
 
280
  )
281
 
282
 
283
+ def display_edit_buttons():
284
+ # first row: undo + redo + reset
285
+ col_0, col_1, col_2, *_ = st.columns([2, 2, 2, 2])
286
+ undo_disabled = not bool(st.session_state.task_state.done_list)
287
+ redo_disabled = not bool(st.session_state.task_state.undone_list)
288
+ with col_0:
289
+ name = f"UNDO ({len(st.session_state.task_state.done_list)})"
290
+ tip = "Undo the last edit"
291
+ st.button(name, on_click=undo_last, disabled=undo_disabled, help=tip)
292
+ with col_1:
293
+ name = f"REDO ({len(st.session_state.task_state.undone_list)})"
294
+ tip = "Redo the last undone edit"
295
+ st.button(name, on_click=redo_last, disabled=redo_disabled, help=tip)
296
+ with col_2:
297
+ tip = "Undo all edits"
298
+ st.button("Reset", on_click=reset_model_card, help=tip)
299
+
300
+ # second row: download + delete
301
+ col_0, col_1, *_ = st.columns([2, 2, 2, 2])
302
+ with col_0:
303
+ add_download_model_card_button()
304
+ with col_1:
305
+ tip = "Start over from scratch (lose all progress)"
306
+ st.button("Delete", on_click=delete_model_card, help=tip)
307
+
308
+
309
+ def _update_model_diagram():
310
+ val = st.session_state.get("special_model_diagram", True)
311
+ model_card = st.session_state.model_card
312
+ model_card.model_diagram = val
313
+
314
+ # TODO: this may no longer be necesssary once this issue is solved:
315
+ # https://github.com/skops-dev/skops/issues/292
316
+ if val:
317
+ model_card.add_model_plot()
318
+ else:
319
+ model_card.delete("Model description/Training Procedure/Model Plot")
320
+
321
+
322
+ def _parse_metrics(metrics: str) -> dict[str, str | float]:
323
+ # parse metrics from text area, one per line, into a dict
324
+ metrics_table = {}
325
+ for line in metrics.splitlines():
326
+ line = line.strip()
327
+ val: str | float
328
+ name, _, val = line.partition("=")
329
+ try:
330
+ # try to coerce to float but don't error if it fails
331
+ val = float(val.strip())
332
+ except ValueError:
333
+ pass
334
+ metrics_table[name.strip()] = val
335
+ return metrics_table
336
+
337
+
338
+ def _update_metrics():
339
+ metrics = st.session_state.get("special_metrics_text", {})
340
+ model_card = st.session_state.model_card
341
+ metrics_table = _parse_metrics(metrics)
342
+
343
+ # check if any change
344
+ if metrics_table == model_card._metrics:
345
+ return
346
+
347
+ task = AddMetricsTask(model_card, metrics_table)
348
+ st.session_state.task_state.add(task)
349
+
350
+
351
+ def display_skops_special_fields():
352
+ st.checkbox(
353
+ "Show model diagram",
354
+ value=True,
355
+ on_change=_update_model_diagram,
356
+ key="special_model_diagram",
357
+ )
358
+
359
+ with st.expander("Add metrics"):
360
+ with st.form("special_metrics", clear_on_submit=False):
361
+ st.text_area(
362
+ "Add one metric per line, e.g. 'accuracy = 0.9'",
363
+ key="special_metrics_text",
364
+ )
365
+ st.form_submit_button(
366
+ "Update",
367
+ on_click=_update_metrics,
368
+ )
369
+
370
+
371
  def edit_input_form():
372
  if "task_state" not in st.session_state:
373
  st.session_state.task_state = TaskState()
374
 
375
  with st.sidebar:
376
+ # TOP ROW BUTTONS
377
+ display_edit_buttons()
378
+
379
+ # SHOW SPECIAL FIELDS IF SKOPS TEMPLATE WAS USED
380
+ if st.session_state.get("model_card_type", "") == "skops":
381
+ display_skops_special_fields()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
+ # SHOW EDITABLE SECTIONS
384
  if "model_card" in st.session_state:
385
  display_sections(st.session_state.model_card)
386
 
start.py CHANGED
@@ -107,6 +107,7 @@ def create_skops_model_card() -> None:
107
  metadata = card.metadata_from_config(hf_path)
108
  model_card = card.Card(model=st.session_state.model, metadata=metadata)
109
  st.session_state.model_card = model_card
 
110
 
111
 
112
  def create_empty_model_card() -> None:
@@ -117,6 +118,7 @@ def create_empty_model_card() -> None:
117
  )
118
  model_card.add(**{"Untitled": "[More Information Needed]"})
119
  st.session_state.model_card = model_card
 
120
 
121
 
122
  def create_hf_model_card() -> None:
@@ -128,6 +130,7 @@ def create_hf_model_card() -> None:
128
  path = hf_hub_download(repo_id, "README.md")
129
  model_card = card.parse_modelcard(path)
130
  st.session_state.model_card = model_card
 
131
 
132
 
133
  def start_input_form():
 
107
  metadata = card.metadata_from_config(hf_path)
108
  model_card = card.Card(model=st.session_state.model, metadata=metadata)
109
  st.session_state.model_card = model_card
110
+ st.session_state.model_card_type = "skops"
111
 
112
 
113
  def create_empty_model_card() -> None:
 
118
  )
119
  model_card.add(**{"Untitled": "[More Information Needed]"})
120
  st.session_state.model_card = model_card
121
+ st.session_state.model_card_type = "empty"
122
 
123
 
124
  def create_hf_model_card() -> None:
 
130
  path = hf_hub_download(repo_id, "README.md")
131
  model_card = card.parse_modelcard(path)
132
  st.session_state.model_card = model_card
133
+ st.session_state.model_card_type = "loaded"
134
 
135
 
136
  def start_input_form():
tasks.py CHANGED
@@ -208,3 +208,24 @@ class UpdateFigureTask(Task):
208
 
209
  self.path.unlink(missing_ok=True)
210
  section.content = self.old_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
  self.path.unlink(missing_ok=True)
210
  section.content = self.old_data
211
+
212
+
213
+ class AddMetricsTask(Task):
214
+ """Add new metrics"""
215
+
216
+ def __init__(
217
+ self,
218
+ model_card: card.Card,
219
+ metrics: dict[str, str | int | float],
220
+ ) -> None:
221
+ self.model_card = model_card
222
+ self.old_metrics = model_card._metrics.copy()
223
+ self.new_metrics = metrics
224
+
225
+ def do(self) -> None:
226
+ self.model_card._metrics.clear()
227
+ self.model_card.add_metrics(**self.new_metrics)
228
+
229
+ def undo(self) -> None:
230
+ self.model_card._metrics.clear()
231
+ self.model_card.add_metrics(**self.old_metrics)