Benjamin Bossan
Add special fields for skops template
2174ccd
raw
history blame
6.09 kB
"""Functionality around tasks
Tasks are used to implement "undo" and "redo" functionality.
"""
from __future__ import annotations
from pathlib import Path
from uuid import uuid4
from skops import card
from skops.card._model_card import PlotSection, split_subsection_names
from streamlit.runtime.uploaded_file_manager import UploadedFile
class Task:
    """(Abstract) base class for tasks.

    A task bundles a reversible change to the model card: ``do`` applies the
    change and ``undo`` reverts it. Subclasses must override both methods; the
    base implementations only raise ``NotImplementedError``.
    """

    def do(self) -> None:
        """Apply the change (must be overridden)."""
        raise NotImplementedError

    def undo(self) -> None:
        """Revert the change made by ``do`` (must be overridden)."""
        raise NotImplementedError
class TaskState:
    """Track which tasks have been done and which have been undone.

    ``done_list`` holds applied tasks, most recent last; ``undone_list`` holds
    reverted tasks that can be redone, most recently undone last. Adding a new
    task discards the redo history.
    """

    def __init__(self) -> None:
        self.done_list: list[Task] = []
        self.undone_list: list[Task] = []

    def undo(self) -> None:
        """Revert the most recently done task; no-op if nothing was done."""
        if self.done_list:
            latest = self.done_list.pop()
            latest.undo()
            self.undone_list.append(latest)

    def redo(self) -> None:
        """Re-apply the most recently undone task; no-op if nothing was undone."""
        if self.undone_list:
            latest = self.undone_list.pop()
            latest.do()
            self.done_list.append(latest)

    def add(self, task: Task) -> None:
        """Apply ``task``, record it as done, and drop the redo history."""
        task.do()
        self.done_list.append(task)
        # a fresh action invalidates everything that could have been redone
        self.undone_list.clear()

    def reset(self) -> None:
        """Forget all done and undone tasks (without reverting any of them)."""
        for history in (self.done_list, self.undone_list):
            history.clear()
class AddSectionTask(Task):
    """Add a new text section to the model card.

    The section is stored under a key made of ``title`` plus a short random
    suffix, so repeated titles (very likely) still get distinct keys.
    """

    def __init__(
        self,
        model_card: card.Card,
        title: str,
        content: str,
    ) -> None:
        self.model_card = model_card
        self.title = title
        self.content = content
        # first 6 hex characters of a random UUID as a disambiguating suffix
        self.key = f"{title} {uuid4().hex[:6]}"

    def do(self) -> None:
        """Add the section and set its title to the last subsection name."""
        self.model_card.add(**{self.key: self.content})
        added = self.model_card.select(self.key)
        # presumably ``title`` may name a nested subsection; only the last
        # component becomes the section's own title
        added.title = split_subsection_names(self.title)[-1]

    def undo(self) -> None:
        """Delete the section added by ``do``."""
        self.model_card.delete(self.key)
class AddFigureTask(Task):
    """Add a new figure section to the model card.

    ``content`` is handed to ``Card.add_plot`` under a key made of ``title``
    plus a short random suffix (presumably ``content`` is the figure's file
    path; confirm against the caller).
    """

    def __init__(
        self,
        model_card: card.Card,
        title: str,
        content: str,
    ) -> None:
        self.model_card = model_card
        self.title = title
        self.content = content
        # first 6 hex characters of a random UUID as a disambiguating suffix
        self.key = f"{title} {uuid4().hex[:6]}"

    def do(self) -> None:
        """Add the plot section, set its title and flag it as a figure."""
        self.model_card.add_plot(**{self.key: self.content})
        figure = self.model_card.select(self.key)
        figure.title = split_subsection_names(self.title)[-1]
        # NOTE(review): presumably read elsewhere to tell figures from text
        figure.is_fig = True  # type: ignore

    def undo(self) -> None:
        """Delete the figure section added by ``do``."""
        self.model_card.delete(self.key)
class DeleteSectionTask(Task):
    """Delete a section by hiding it.

    The section stays in the underlying data structure; only its ``visible``
    flag is switched off, so undoing just switches it back on.
    """

    def __init__(
        self,
        model_card: card.Card,
        key: str,
    ) -> None:
        self.model_card = model_card
        self.key = key

    def _set_visibility(self, visible: bool) -> None:
        """Set the ``visible`` flag of the section identified by ``self.key``."""
        self.model_card.select(self.key).visible = visible

    def do(self) -> None:
        """Hide the section."""
        self._set_visibility(False)

    def undo(self) -> None:
        """Show the section again."""
        self._set_visibility(True)
class UpdateSectionTask(Task):
    """Change the title and/or content of a text section.

    Both the previous and the new name/content are kept so the edit can be
    applied (``do``) and reverted (``undo``) repeatedly.
    """

    def __init__(
        self,
        model_card: card.Card,
        key: str,
        old_name: str,
        new_name: str,
        old_content: str,
        new_content: str,
    ) -> None:
        self.model_card = model_card
        self.key = key
        self.old_name = old_name
        self.new_name = new_name
        self.old_content = old_content
        self.new_content = new_content

    def _apply(self, name: str, content: str) -> None:
        """Set the section's title (last part of ``name``) and its content."""
        target = self.model_card.select(self.key)
        target.title = split_subsection_names(name)[-1]
        target.content = content

    def do(self) -> None:
        """Switch the section to the new name and content."""
        self._apply(self.new_name, self.new_content)

    def undo(self) -> None:
        """Restore the previous name and content."""
        self._apply(self.old_name, self.old_content)
class UpdateFigureTask(Task):
    """Change the title or image of a figure section.

    If ``data`` (a newly uploaded image) is given, ``do`` writes its bytes to
    ``path`` and replaces the section content with a plot section pointing at
    that file; ``undo`` deletes the file again and restores the previous
    content. Without ``data`` only the title is changed.

    Raises
    ------
    ValueError
        If ``data`` is given but ``path`` is ``None``: there would be nowhere
        to write the new image.
    """

    def __init__(
        self,
        model_card: card.Card,
        key: str,
        old_name: str,
        new_name: str,
        data: UploadedFile | None,
        path: Path | None,
    ) -> None:
        self.model_card = model_card
        self.key = key
        self.old_name = old_name
        self.new_name = new_name
        # snapshot of the section content before the update, used by ``undo``
        self.old_data = self.model_card.select(self.key).content
        self.path = path
        # explicit flag instead of comparing an UploadedFile with a str later
        self._has_new_image = bool(data)
        if self._has_new_image and path is None:
            # fail before ``do`` mutates anything (it used to set the title
            # first and then crash inside ``open(None, ...)``)
            raise ValueError(
                "A path is required to store the new figure data, got None."
            )
        self.new_data = data if self._has_new_image else self.old_data

    def do(self) -> None:
        """Set the new title and, if a new image was uploaded, store it."""
        section = self.model_card.select(self.key)
        section.title = split_subsection_names(self.new_name)[-1]
        if not self._has_new_image:  # image is same
            return

        # write figure
        # note: this can still be the same image if the image is a file, there
        # is no test to check, e.g., the hash of the image
        with open(self.path, "wb") as f:
            f.write(self.new_data.getvalue())
        section.content = PlotSection(
            alt_text=self.new_data.name,
            path=self.path,
        ).format()

    def undo(self) -> None:
        """Restore the old title and, if a new image was stored, remove it."""
        section = self.model_card.select(self.key)
        section.title = split_subsection_names(self.old_name)[-1]
        if not self._has_new_image:  # image is same
            return

        self.path.unlink(missing_ok=True)
        section.content = self.old_data
class AddMetricsTask(Task):
    """Replace the metrics of the model card.

    ``do`` clears the card's metrics and adds ``metrics``; ``undo`` clears them
    again and restores the metrics the card had when the task was created.
    """

    def __init__(
        self,
        model_card: card.Card,
        metrics: dict[str, str | int | float],
    ) -> None:
        self.model_card = model_card
        # snapshot of the current metrics so ``undo`` can restore them
        self.old_metrics = model_card._metrics.copy()
        # defensive copy: later mutation of the caller's dict must not change
        # what a (re)do applies
        self.new_metrics = dict(metrics)

    def _replace_metrics(self, metrics: dict[str, str | int | float]) -> None:
        """Clear the card's metrics, then add ``metrics``."""
        self.model_card._metrics.clear()
        self.model_card.add_metrics(**metrics)

    def do(self) -> None:
        """Install the new metrics."""
        self._replace_metrics(self.new_metrics)

    def undo(self) -> None:
        """Restore the metrics captured at construction time."""
        self._replace_metrics(self.old_metrics)