|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import re |
|
import uuid |
|
import time |
|
import shutil |
|
import zipfile |
|
import threading |
|
import subprocess |
|
import select |
|
from datetime import datetime |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
import dash |
|
from dash import dcc, html |
|
from dash.dependencies import Input, Output, State, ALL |
|
import dash_bootstrap_components as dbc |
|
from dash.exceptions import PreventUpdate |
|
|
|
from flask import Flask, render_template, request, send_file, jsonify |
|
|
|
import yaml |
|
import ruamel.yaml |
|
import pandas as pd |
|
|
|
|
|
import logging |
|
|
|
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Flask server hosting the welcome page and the /download route; the Dash app is mounted on it.
server = Flask(__name__)
# Flask's `secret_key` attribute is backed by config["SECRET_KEY"]. A fresh random key is
# generated on every start, so anything signed with it does not survive a restart.
server.config["SECRET_KEY"] = os.urandom(24)
|
|
|
|
|
def _duplicate_mode_for_host(host):
    """
    Derive the Hugging Face user and the duplicate-mode flag from a request host.

    The user name is the part of the host before the first "-stm32"; when the host
    does not match, "modelzoo_user" is used. Duplicate mode is enabled only for the
    "stmicroelectronics" user.

    Args:
        host (str): The request host, e.g. "someone-stm32-modelzoo-app.hf.space".

    Returns:
        tuple[str, bool]: (hf_user, duplicate_mode).
    """
    usr_match = re.match(r'^(.*?)\-stm32', host)
    hf_user = usr_match.group(1) if usr_match else "modelzoo_user"
    return hf_user, hf_user == "stmicroelectronics"


@server.route('/')
def welcome_page():
    """
    Handle the welcome page route.

    Extracts the user name from the request host, decides whether duplicate mode
    should be enabled and renders 'index.html' with that flag.

    Returns:
        str: The rendered 'index.html' template with the duplicate_mode parameter.
    """
    host = request.host
    hf_user, duplicate_mode = _duplicate_mode_for_host(host)
    # Use the module logger instead of print so the output follows the logging configuration.
    logger.debug("host=%s hf_user=%s duplicate_mode=%s", host, hf_user, duplicate_mode)
    return render_template('index.html', duplicate_mode=duplicate_mode)
|
|
|
|
|
external_stylesheets = [dbc.themes.LITERA]

# The Dash app is mounted on the Flask `server` under /dash_app/.
# Callback exceptions are suppressed because several callback targets (e.g. 'apply-button',
# 'submission-outcome') are only created later, inside the form built by create_yaml().
app = dash.Dash(
    __name__,
    server=server,
    external_stylesheets=external_stylesheets,
    url_base_pathname='/dash_app/',
    suppress_callback_exceptions=True,
)
|
|
|
|
|
# Editable user_config.yaml for each supported use case, keyed by the value used in the
# 'selected-model' dropdown. Paths are relative to the process working directory.
local_yamls = {
    use_case: f"stm32ai-modelzoo-services/{use_case}/src/user_config.yaml"
    for use_case in (
        'image_classification',
        'human_activity_recognition',
        'hand_posture',
        'object_detection',
        'audio_event_detection',
        'pose_estimation',
        'semantic_segmentation',
    )
}
|
|
|
|
|
def banner():
    """
    Build the fixed top bar shown above the dashboard.

    Left to right, the bar contains:
    - a link to the stm32ai-modelzoo-services GitHub repository;
    - the ST logo;
    - a link to ST Edge AI Developer Cloud.

    Images are resolved through the Dash assets folder.

    Returns:
        html.Div: The banner component (id "banner").
    """
    return html.Div(
        id="banner",
        className="top-bar",
        style={
            "display": "flex",
            "align-items": "center",
            "justify-content": "space-between",
            "position": "fixed",
            "top": "0",
            "left": "0",
            "width": "100%",
            "z-index": "1000",
            # ST navy, the same colour used for headings and buttons elsewhere in the app.
            # Must be a valid 6-digit hex value; an invalid one is silently ignored by browsers.
            "background-color": "#03234b",
            "padding": "10px 20px",
            "box-shadow": "0px 2px 4px rgba(0, 0, 0, 0.1)"
        },
        children=[
            html.A(
                id="learn-more-button",
                children=[
                    html.Img(
                        src=app.get_asset_url("github-mark-white.png"),
                        style={"width": "20px", "height": "20px", "margin-right": "10px"}
                    ),
                    "stm32ai-modelzoo",
                ],
                href="https://github.com/STMicroelectronics/stm32ai-modelzoo-services",
                target="_blank",
                style={
                    "display": "flex",
                    "align-items": "center",
                    "color": "#ffffff",
                    "text-decoration": "none",
                    "font-size": "15px",
                    "font-family": "Arial, sans-serif"
                }
            ),
            html.Div(
                html.Img(
                    id="logo",
                    src=app.get_asset_url("ST_logo_2024_white.png"),
                    style={"width": "50px", "height": "auto"}
                ),
                style={"text-align": "center"}
            ),
            html.Div(
                [
                    html.A(
                        [
                            html.H5(
                                "ST Edge AI Developer Cloud",
                                style={
                                    "margin": "0",
                                    "text-align": "right",
                                    "color": "#ffffff",
                                    "font-size": "15px",
                                    "font-weight": "bold",
                                    "font-family": "Arial, sans-serif"
                                }
                            )
                        ],
                        href="https://stm32ai-cs.st.com/home",
                        target="_blank",
                        style={
                            "display": "flex",
                            "align-items": "center",
                            "text-decoration": "none"
                        }
                    )
                ],
                style={"padding-right": "10px"}
            )
        ]
    )
|
|
|
|
|
def read_configs(selected_model):
    """
    Load the user_config.yaml that belongs to the selected use case.

    Args:
        selected_model (str): Key into ``local_yamls`` (the dropdown value).

    Returns:
        dict: The parsed YAML mapping.

    Raises:
        ValueError: If no model is selected, the model is unknown, the file cannot be
            read or parsed, or the file does not contain a mapping at its top level.
    """
    if not selected_model:
        raise ValueError("No model selected. Please select a valid model.")
    if selected_model not in local_yamls:
        raise ValueError(f"Model '{selected_model}' not found in local_yamls")

    yaml_path = local_yamls[selected_model]
    try:
        with open(yaml_path, 'r', encoding='utf-8') as file:
            content = yaml.safe_load(file)
    except Exception as e:
        # Callers report ValueError messages to the user; keep the original cause chained.
        raise ValueError(f"Error reading YAML file at {yaml_path}: {e}") from e

    # An empty file loads as None; the form builder needs a mapping to iterate.
    if not isinstance(content, dict):
        raise ValueError(f"YAML file at {yaml_path} does not contain a mapping at top level")
    return content
|
|
|
|
|
def build_yaml_form(yaml_content, parent_key=''):
    """
    Recursively build accordion items that mirror a YAML mapping.

    Each nested mapping becomes an AccordionItem that contains the items of its
    children. Each leaf value becomes an AccordionItem holding a label and an
    editable control:
    - bool -> Checklist with a single option whose value is True;
    - list -> multi-select Dropdown, pre-selected with every element;
    - anything else -> text Input.

    Every control gets the pattern-matching id
    ``{'type': 'yaml-setting', 'index': '<dotted.key.path>'}``.

    Parameters:
    - yaml_content (dict): The YAML content to build the form from.
    - parent_key (str): Dotted path of the enclosing mapping ('' at the top level).

    Returns:
    - list: dbc.AccordionItem components representing the form fields.
    """
    field_style = {"padding": "10px", "border": "1px solid #ddd", "margin-bottom": "10px"}
    accordion_items = []
    for key, value in yaml_content.items():
        # YAML keys are not always strings (e.g. integer keys); str() keeps
        # capitalize() and the dotted path working.
        key_text = str(key)
        full_key = f"{parent_key}.{key_text}" if parent_key else key_text
        title = key_text.capitalize()

        if isinstance(value, dict):
            accordion_items.append(
                dbc.AccordionItem(build_yaml_form(value, full_key), title=title)
            )
            continue

        component_id = {'type': 'yaml-setting', 'index': full_key}
        if isinstance(value, bool):
            control = dcc.Checklist(
                id=component_id,
                options=[{'label': '', 'value': True}],
                value=[True] if value else [],
                style=field_style
            )
        elif isinstance(value, list):
            control = dcc.Dropdown(
                id=component_id,
                options=[{'label': str(v), 'value': v} for v in value],
                value=value,
                multi=True,
                style=field_style
            )
        else:
            control = dcc.Input(
                id=component_id,
                value=value,
                type='text',
                style=field_style
            )

        field = [
            html.Label(key_text, style={"font-weight": "bold", "margin-bottom": "5px"}),
            control,
        ]
        accordion_items.append(dbc.AccordionItem(field, title=title))

    return accordion_items
|
|
|
|
|
def create_yaml(yaml_content):
    """
    Wrap the generated YAML fields in a submittable form.

    Parameters:
        yaml_content (dict): Parsed YAML configuration to render as editable fields.

    Returns:
        dbc.Form: A collapsed accordion of the fields, a centred 'apply-button'
        submit button, and an empty 'submission-outcome' div for status messages.
    """
    fields = dbc.Accordion(build_yaml_form(yaml_content), start_collapsed=True)

    submit_button = dbc.Button(
        'Submit',
        id='apply-button',
        style={
            'background-color': '#FFD200',
            'color': '#03234b',
            'font-size': '14px',
            'padding': '10px 10px 10px 10px',
            'border-radius': '5px',
            'margin-top': '15px',
            'border': '2px solid #FFD200',
            'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)',
        },
    )
    centered_button = html.Div(
        submit_button,
        style={'display': 'flex', 'justify-content': 'center', 'margin-top': '15px'},
    )
    outcome_message = html.Div(
        id='submission-outcome',
        style={
            'marginTop': '10px',
            'textAlign': 'center',
            'fontStyle': 'italic',
            'color': '#03234b',
            'font-size': '14px',
        },
    )

    return dbc.Form([fields, centered_button, outcome_message])
|
|
|
|
|
def _coerce_numeric_string(text):
    """Return ``text`` as an int or float when it spells a number, otherwise unchanged."""
    # A '.' or an exponent marker means float syntax ("0.5", "1e-3", "2.5E4"). Plain words
    # that merely contain an 'e' ("relu", "true") fail float() and stay strings.
    looks_like_float = '.' in text or 'e' in text.lower()
    try:
        return float(text) if looks_like_float else int(text)
    except ValueError:
        return text


def process_form_configs(form_configs):
    """
    Normalise raw form values before they are written back to YAML.

    For every entry whose value is not None:
    - a single-element list is unwrapped to its element (a ticked Checklist gives [True]);
    - a string that spells an integer, decimal or scientific-notation number is converted
      to int/float.

    Entries whose value is None are dropped.

    Args:
        form_configs (dict): Mapping of dotted YAML key -> raw form value.

    Returns:
        dict: Mapping of dotted YAML key -> processed value.
    """
    updated_yaml = {}
    for key, value in form_configs.items():
        if value is None:
            continue
        if isinstance(value, list) and len(value) == 1:
            value = value[0]
        if isinstance(value, str):
            value = _coerce_numeric_string(value)
        updated_yaml[key] = value
    return updated_yaml
|
|
|
|
|
def create_archive(archive_path, directory_to_compress):
    """
    Create a DEFLATE-compressed ZIP archive of a directory tree.

    Entry names are relative to ``directory_to_compress``. The archive file is
    skipped if it lies inside the directory being compressed.

    Files are written sequentially: ``zipfile.ZipFile`` does not support concurrent
    writes from several threads, and doing so can corrupt the archive.

    Parameters:
        archive_path (str): Path of the ZIP archive to create (overwritten if present).
        directory_to_compress (str): Directory whose contents are archived.

    Returns:
        None

    Raises:
        OSError: If a file cannot be read or the archive cannot be written.
    """
    archive_abspath = os.path.abspath(archive_path)
    with zipfile.ZipFile(archive_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        for root_dir, sub_dirs, files in os.walk(directory_to_compress):
            # Sort in place so the walk order, and thus the entry order, is deterministic.
            sub_dirs.sort()
            for file_name in sorted(files):
                file_path = os.path.join(root_dir, file_name)
                if os.path.abspath(file_path) == archive_abspath:
                    continue
                zipf.write(file_path, arcname=os.path.relpath(file_path, directory_to_compress))
|
|
|
|
|
def create_dashboard_layout():
    """
    Build the page layout of the STM32 Model Zoo dashboard.

    Components, top to bottom:
    - the banner;
    - the use-case dropdown ('selected-model');
    - the YAML editing area ('toggle-yaml' / 'yaml-layout');
    - the Developer Cloud credentials inputs and the 'process-button';
    - the command-output log ('log-reader');
    - two metric graphs ('acc-visualization', 'loss-visualization');
    - a 1 s interval timer ('interval-widget');
    - the download button ('download-action').

    Returns:
        html.Div: The banner followed by a fluid dbc.Container with the dashboard body.
    """
    def metrics_card(graph_id):
        # Both metric graphs share the same card chrome; only the graph id differs.
        return dbc.Col(dbc.Card([
            dbc.CardHeader("Metrics", style={'background-color': '#03234b', 'color': 'white'}),
            dbc.CardBody(
                dcc.Graph(id=graph_id, style={'height': '100%', 'width': '100%'}),
                style={'height': '400px', 'display': 'flex', 'justify-content': 'center', 'align-items': 'center'}
            )
        ]), width=6, style={'padding': '10px'})

    return html.Div([
        banner(),
        dbc.Container([
            dcc.Location(id='url', refresh=False),
            dbc.Row(dbc.Col(html.H3("STM32 Modelzoo", style={'color': '#03234b', 'text-align': 'center', "margin-top": "80px", "font-family": "Arial, sans-serif"}), className="mb-4")),
            dbc.Row([
                dbc.Col(
                    html.H5("Use case selection", style={'color': '#03234b', 'margin-bottom': '10px'}),
                    width=12
                )
            ], id="use-case-section", style={"display": "none"}),
            dbc.Row(dbc.Col(dcc.Dropdown(
                id='selected-model',
                options=[
                    {'label': 'Image Classification (IC)', 'value': 'image_classification'},
                    {'label': 'Human Activity Recognition (HAR)', 'value': 'human_activity_recognition'},
                    {'label': 'Hand Posture', 'value': 'hand_posture'},
                    {'label': 'Audio Event Detection(AED)', 'value': 'audio_event_detection'},
                    {'label': 'Object Detection', 'value': 'object_detection'},
                    {'label': 'Pose estimation', 'value': 'pose_estimation'},
                    {'label': 'Semantic Segmentation', 'value': 'semantic_segmentation'},
                ],
                placeholder="Please select your use case",
                className="mb-4"
            ))),

            dbc.Row(
                dbc.Col(
                    html.Div(
                        id='toggle-yaml',
                        children=[
                            html.P([
                                "Please update the YAML file: Dataset path (example: ../datasets/your_use_case/name_of_dataset) or datasets/your_prepared_dataset. For more details, refer to the ",
                                html.A("README", href="https://huggingface.co/spaces/STMicroelectronics/stm32-modelzoo-app/blob/main/datasets/README.md", target="_blank", style={'color': '#007bff', 'text-decoration': 'underline'}),
                                "."
                            ], style={'font-family': 'Arial, sans-serif', 'color': '#03234b', 'fontSize': '15px'}),
                            dcc.RadioItems(
                                id='modify-yaml-choice',
                                labelStyle={'display': 'inline-block', 'margin-right': '10px'},
                                className="mb-4",
                            ),
                            dcc.Upload(
                                id='load-yaml-file',
                                children=html.Button('Upload YAML File'),
                                style={'display': 'none'}
                            ),
                            html.Div(id='load-state', style={'margin-top': '10px'}),
                            html.Div(id='yaml-layout', style={'display': 'none'})
                        ],
                        style={'font-family': 'Arial, sans-serif', 'display': 'none'}
                    )
                )
            ),
            dbc.Row([
                dbc.Col([
                    # Colour must start with '#'; a bare '03234b' is an invalid CSS colour.
                    html.P("Enter your ST Edge AI Developer Cloud credentials:", style={'color': '#03234b', 'fontSize': '15px', 'fontWeight': 'bold'}, className="credentials-text"),
                    dcc.Input(id='devcloud-username-input', type='text', placeholder='Enter username', className="input-field mb-2"),
                    dcc.Input(id='devcloud-password-input', type='password', placeholder='Enter password', className="input-field mb-4")
                ], width=6),
                dbc.Col([
                    # "#3234b" (5 hex digits) was not a valid colour; use the app's navy.
                    dbc.Button('Launch training', id='process-button', color="#03234b", className="start-button mb-4", style={'display': 'none', 'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)'})
                ], className="credentials-col")
            ], id='credentials-section', style={
                'display': 'none',
                'justify-content': 'center',
                'align-items': 'center',
                'height': '100vh',
            }, className="credentials-section mb-4"),

            dbc.Row([
                dbc.Col(
                    html.H5("Results visualization", style={'color': '#03234b', 'margin-bottom': '10px'}),
                    width=12
                )
            ], id="results-section", style={"display": "none"}),
            dbc.Row([
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Command output"),
                    dbc.CardBody(
                        html.Div(id='log-reader', style={'whiteSpace': 'pre-wrap', 'padding-top': '15px', 'height': '100%', 'overflow': 'auto'}),
                        style={'height': '300px'}
                    )
                ]))
            ], style={'margin-bottom': '30px'}),
            dbc.Row([
                metrics_card('acc-visualization'),
                metrics_card('loss-visualization'),
            ], style={'margin-bottom': '30px'}),

            # Drives log/metrics refresh and download-button visibility every second.
            dcc.Interval(id='interval-widget', interval=1000, n_intervals=0),
            dcc.Download(id="download-resource"),
            dbc.Row(
                dbc.Col(
                    dbc.Button('Download outputs', id='download-action', className="mb-4", style={
                        'background-color': '#ffd200',
                        'color': '#ffffff',
                        'font-size': '14px',
                        'padding': '10px 10px',
                        'border-radius': '5px',
                        'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)',
                        'margin-top': '20px'
                    }),
                    style={
                        'display': 'flex',
                        'justify-content': 'center',
                        'alignItems': 'center',
                    }
                )
            )
        ], fluid=True)
    ])
|
|
|
# State shared between the Dash callbacks and the background training thread.
logs = []              # collected output lines of the current run
lock = threading.Lock()  # held while `logs` is reset or appended to by the worker code
new_training = False   # becomes True once a training run has been launched

# The layout function itself (not its result) is assigned, so Dash builds it per page load.
app.layout = create_dashboard_layout
|
|
|
def fill_logs(message):
    """
    Append a message to the shared ``logs`` list while holding ``lock``.

    Trailing carriage-return/newline characters are stripped. Lines read from a
    subprocess keep their newline, and the log view joins entries with newlines,
    which would otherwise double-space the output.

    Parameters:
        message (str): The message to append.

    Returns:
        None
    """
    with lock:
        logs.append(message.rstrip('\r\n'))
|
|
|
def run_script(script, devcloud_username, devcloud_password):
    """
    Run a Python script in a subprocess and stream its output into ``logs``.

    The shared ``logs`` list is reset first. The credentials are passed to the
    child through its own environment (``stmai_username`` / ``stmai_password``,
    plus ``STATS_TYPE``) rather than written into this server's global
    ``os.environ``, where they would outlive the run and leak between users.

    stderr is merged into stdout and read line by line until EOF. This avoids the
    stalls and busy-looping of select() over buffered text pipes and keeps the two
    streams in order. ``PYTHONUNBUFFERED`` makes the child flush each line so
    logs appear live.

    Parameters:
    - script (str): Path of the Python script to execute with ``python3``.
    - devcloud_username (str): Username for ST Developer Cloud.
    - devcloud_password (str): Password for ST Developer Cloud.

    Returns:
    - None
    """
    global logs

    with lock:
        logs = []

    child_env = dict(os.environ)
    child_env['stmai_username'] = devcloud_username
    child_env['stmai_password'] = devcloud_password
    child_env['STATS_TYPE'] = 'HuggingFace_devcloud'
    child_env['PYTHONUNBUFFERED'] = '1'

    try:
        execution = subprocess.Popen(
            ['python3', script],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            env=child_env,
        )
    except OSError as exc:
        # e.g. python3 not on PATH: report it in the UI instead of killing the thread silently.
        fill_logs(f"Failed to start {script}: {exc}")
        return

    with execution:
        for line in execution.stdout:
            fill_logs(line.rstrip('\n'))
        execution.wait()
|
|
|
def execute_async(script, devcloud_username, devcloud_password):
    """
    Start ``run_script`` on a new background thread and return immediately.

    The thread object is not kept, so callers cannot join it; progress is observed
    through the shared ``logs`` list.

    Parameters:
        script (str): Path of the Python script to execute.
        devcloud_username (str): ST Developer Cloud username forwarded to run_script.
        devcloud_password (str): ST Developer Cloud password forwarded to run_script.

    Returns:
        None
    """
    worker_args = (script, devcloud_username, devcloud_password)
    threading.Thread(target=run_script, args=worker_args).start()
|
|
|
|
|
@app.callback(
    Output("config-section", "style"),
    Input('selected-model', 'value')
)
def toggle_config_section(selected_model):
    """
    Show the 'config-section' element when a use case is selected, hide it otherwise.

    NOTE(review): no component with id 'config-section' is created in this file's
    layout. Because suppress_callback_exceptions=True, Dash does not flag this, so
    the callback may have no visible effect. Confirm whether it is dead code.

    Parameters:
        selected_model (str): Value of the 'selected-model' dropdown.

    Returns:
        dict: {"display": "block"} or {"display": "none"}.
    """
    visibility = "block" if selected_model else "none"
    return {"display": visibility}
|
|
|
|
|
@app.callback(
    Output('toggle-yaml', 'style'),
    Input('selected-model', 'value')
)
def dipslay_yaml_container(selected_model):
    """
    Show the YAML editing container once a use case is picked.

    Args:
        selected_model (str): Value of the 'selected-model' dropdown.

    Returns:
        dict: {'display': 'block'} when a model is selected, else {'display': 'none'}.
    """
    return {'display': 'block' if selected_model else 'none'}
|
|
|
@app.callback(
    [Output('yaml-layout', 'style'),
     Output('yaml-layout', 'children')],
    [Input('modify-yaml-choice', 'value'),
     Input('selected-model', 'value')]
)
def display_yaml_form(selection_update, selected_model):
    """
    Render the editable form for the selected use case's user_config.yaml.

    Args:
        selection_update (str): Value of 'modify-yaml-choice'. It triggers the
            callback but does not influence the result.
        selected_model (str): Value of the 'selected-model' dropdown.

    Returns:
        tuple: (style, children). When no model is selected, or loading/building
        fails, the container is hidden and children is a message string.
        Otherwise it is shown and children is the generated form.
    """
    hidden = {'display': 'none'}
    if not selected_model:
        return hidden, "Please select a model to display its configuration."

    try:
        form_conf = create_yaml(read_configs(selected_model))
    except ValueError as err:
        return hidden, f"Error: {err}"
    except Exception as err:
        return hidden, f"Unexpected Error: {err}"

    return {'display': 'block'}, form_conf
|
|
|
|
|
|
|
@app.callback(
    Output('credentials-section', 'style'),
    [Input('modify-yaml-choice', 'value'),
     Input('selected-model', 'value'),
     Input('apply-button', 'n_clicks')]
)
def display_credentials(selection_update, selected_model, n_clicks):
    """
    Reveal the credentials section after the YAML form has been submitted.

    Args:
        selection_update (str): Value of 'modify-yaml-choice' (trigger only, unused).
        selected_model (str): Value of 'selected-model' (trigger only, unused).
        n_clicks (int | None): Click count of the 'apply-button' submit button.

    Returns:
        dict: {'display': 'block'} once the button has been clicked, else {'display': 'none'}.
    """
    submitted = n_clicks not in (None, 0)
    return {'display': 'block'} if submitted else {'display': 'none'}
|
|
|
|
|
@app.callback(
    Output('process-button', 'style'),
    [Input('apply-button', 'n_clicks')]
)
def display_launch_training(n_clicks):
    """
    Show the "Launch training" button once the YAML form has been submitted.

    Parameters:
        n_clicks (int | None): Click count of the 'apply-button' submit button.

    Returns:
        dict: {'display': 'inline-block'} after at least one click, else {'display': 'none'}.
    """
    clicked = n_clicks is not None and n_clicks > 0
    if not clicked:
        return {'display': 'none'}
    return {'display': 'inline-block'}
|
|
|
@app.callback(
    Output("results-section", "style"),
    Input('process-button', 'n_clicks')
)
def display_results_section(n_clicks):
    """
    Show the "Results visualization" header once training has been launched.

    Parameters:
        n_clicks (int | None): Click count of the 'process-button'.

    Returns:
        dict: {"display": "block"} after at least one click, else {"display": "none"}.
    """
    launched = n_clicks is not None and n_clicks > 0
    return {"display": "block" if launched else "none"}
|
|
|
|
|
|
|
# Name format of the timestamped run folders created under experiments_outputs.
_RUN_DIRECTORY_FORMAT = '%Y_%m_%d_%H_%M_%S'

# Columns plotted together: (train, validation) pairs, or a single validation metric.
_METRIC_GROUPS = (
    ('accuracy', 'val_accuracy'),
    ('loss', 'val_loss'),
    ('oks', 'val_oks'),
    ('val_map',),
)


def _latest_metrics_run(outputs_folder):
    """Return the newest timestamp-named sub-directory of outputs_folder, or None."""
    candidates = []
    for name in os.listdir(outputs_folder):
        if not (name.startswith('20') and os.path.isdir(os.path.join(outputs_folder, name))):
            continue
        try:
            candidates.append((datetime.strptime(name, _RUN_DIRECTORY_FORMAT), name))
        except ValueError:
            # A folder starting with "20" that is not a run timestamp; ignore it
            # instead of crashing the callback every second.
            continue
    return max(candidates)[1] if candidates else None


def _metric_figure(metrics_dataframe, metric_names):
    """Build a Plotly figure dict plotting metric_names against the 'epoch' column."""
    # Train curves are yellow and validation curves blue; a lone metric uses the validation colour.
    colors = ('#FFD200', '#3CB4E6')[-len(metric_names):]
    traces = [
        {
            'x': metrics_dataframe['epoch'],
            'y': metrics_dataframe[name],
            'type': 'line',
            'name': name.capitalize(),
            'line': {'color': color, 'width': 2, 'dash': 'solid'},
            'hoverinfo': 'x+y+name',
            'hoverlabel': {'bgcolor': '#EEEFF1', 'font': {'color': '#525A63'}}
        }
        for name, color in zip(metric_names, colors)
    ]
    return {
        'data': traces,
        'layout': {
            'xaxis': {
                'title': 'Epochs',
                'showgrid': True,
                'gridcolor': '#EEEFF1',
                'tickangle': 45
            },
            'yaxis': {
                'title': metric_names[0].capitalize(),
                'showgrid': True,
                'gridcolor': '#EEEFF1'
            },
            'showlegend': True,
            'legend': {
                'x': 1,
                'y': 1,
                'traceorder': 'normal',
                'font': {'size': 10},
                'bgcolor': '#EEEFF1',
                'bordercolor': '#A6ADB5',
                'borderwidth': 1
            },
            'hovermode': 'closest',
            'plot_bgcolor': '#ffffff'
        }
    }


def _current_log_text():
    """Return the collected log lines joined by newlines, read under the log lock."""
    with lock:
        return "\n".join(logs)


@app.callback(
    [Output('log-reader', 'children'),
     Output('acc-visualization', 'figure'),
     Output('acc-visualization', 'style'),
     Output('loss-visualization', 'figure'),
     Output('loss-visualization', 'style')],
    [Input('interval-widget', 'n_intervals'),
     Input('process-button', 'n_clicks')],
    [State('selected-model', 'value'),
     State('devcloud-username-input', 'value'),
     State('devcloud-password-input', 'value')]
)
def refresh_metrics(n_intervals, nb_clicks, selected_model, devcloud_username, devcloud_password):
    """
    Launch training on request and periodically refresh the log and metric graphs.

    - On a 'process-button' click with both credentials filled in, start the
      use case's stm32ai_main.py in the background and mark a training as started.
      Without credentials, append a prompt to the log instead.
    - On each 'interval-widget' tick after a training has started, read
      train_metrics.csv from the newest run folder in experiments_outputs and
      plot up to two metric groups.

    Args:
        n_intervals (int): Tick count of the interval component.
        nb_clicks (int): Click count of the launch button.
        selected_model (str): The selected use case.
        devcloud_username (str): ST Developer Cloud username.
        devcloud_password (str): ST Developer Cloud password.

    Returns:
        tuple: The five outputs, in order:
        - log text;
        - first figure;
        - first figure style;
        - second figure;
        - second figure style.
        Graphs are hidden ({'display': 'none'}) when there is nothing to plot.

    Raises:
        PreventUpdate: If the callback was not triggered by a relevant input.
    """
    global new_training

    callback_context = dash.callback_context
    if not callback_context.triggered:
        raise PreventUpdate

    # Look at every triggering input, not only the first, so a click batched
    # with an interval tick is not lost.
    triggered_ids = {item['prop_id'].split('.')[0] for item in callback_context.triggered}

    if 'process-button' in triggered_ids and nb_clicks:
        if devcloud_username and devcloud_password:
            st_script = f"stm32ai-modelzoo-services/{selected_model}/src/stm32ai_main.py"
            execute_async(st_script, devcloud_username, devcloud_password)
            new_training = True
            fill_logs("Starting application ...")
        else:
            fill_logs("Please enter both ST Developer Cloud username and password:")
        return _current_log_text(), {}, {'display': 'none'}, {}, {'display': 'none'}

    if 'interval-widget' not in triggered_ids:
        raise PreventUpdate

    log_text = _current_log_text()
    hidden = (log_text, {}, {'display': 'none'}, {}, {'display': 'none'})

    if not new_training:
        return hidden

    outputs_folder = "experiments_outputs"
    if not os.path.exists(outputs_folder):
        os.makedirs(outputs_folder, exist_ok=True)
        return hidden

    recent_directory = _latest_metrics_run(outputs_folder)
    if recent_directory is None:
        return hidden

    train_metrics_file = os.path.join(outputs_folder, recent_directory, 'logs', 'metrics', 'train_metrics.csv')
    logger.debug("Metrics file : %s", train_metrics_file)
    if not os.path.exists(train_metrics_file):
        return hidden

    try:
        metrics_dataframe = pd.read_csv(train_metrics_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        # The training process may be mid-write; try again on the next tick.
        return hidden

    if metrics_dataframe.empty or 'epoch' not in metrics_dataframe.columns:
        return hidden

    figures = [
        _metric_figure(metrics_dataframe, group)
        for group in _METRIC_GROUPS
        if all(column in metrics_dataframe.columns for column in group)
    ]
    if not figures:
        return hidden

    if len(figures) > 1:
        return log_text, figures[0], {'display': 'block'}, figures[1], {'display': 'block'}
    # Only one group available: keep the second graph hidden instead of showing it empty.
    return log_text, figures[0], {'display': 'block'}, {}, {'display': 'none'}
|
def _coerce_like_original(original, raw_value, processed_value):
    """Return the form value converted to match the type of the YAML value it replaces."""
    if isinstance(original, bool) and isinstance(raw_value, list):
        # Checklist gives [True] when ticked and [] when unticked; store a real boolean.
        return bool(raw_value)
    if isinstance(original, list) and isinstance(raw_value, list):
        # Multi-select dropdown: keep a list even when only one item remains selected.
        return raw_value
    return processed_value


@app.callback(
    Output('submission-outcome', 'children'),
    [Input('apply-button', 'n_clicks'),
     Input('process-button', 'n_clicks')],
    [State({'type': 'yaml-setting', 'index': ALL}, 'id'),
     State({'type': 'yaml-setting', 'index': ALL}, 'value'),
     State('selected-model', 'value'),
     State('devcloud-username-input', 'value'),
     State('devcloud-password-input', 'value')]
)
def process_button_actions(submit_clicks, exec_nb_clicks, form_input_ids, form_input_values, selected_model, devcloud_username, devcloud_password):
    """
    Handle the Submit and Launch-training buttons and report the outcome.

    Submit ('apply-button'):
    - merges the form values into the selected use case's user_config.yaml,
      preserving comments and formatting via ruamel.yaml;
    - returns the list of changed dotted keys.
    Booleans and lists keep their YAML type even when the form produces an empty
    or single-element list.

    Launch ('process-button'): only reports that the application is running.
    The training script itself is started by the refresh_metrics callback, which
    listens to the same button. Starting it here too would run it twice.

    Args:
        submit_clicks (int): Click count of the submit button.
        exec_nb_clicks (int): Click count of the launch button.
        form_input_ids (list): Pattern-matching ids of the YAML form controls.
        form_input_values (list): Values of the YAML form controls, same order as the ids.
        selected_model (str): The selected use case.
        devcloud_username (str): ST Developer Cloud username.
        devcloud_password (str): ST Developer Cloud password.

    Returns:
        str: A status message for the 'submission-outcome' element.

    Raises:
        PreventUpdate: If the callback was not triggered by a relevant click.
    """
    import io

    callback_context = dash.callback_context
    if not callback_context.triggered:
        raise PreventUpdate

    triggered_button = callback_context.triggered[0]['prop_id'].split('.')[0]

    if triggered_button == 'apply-button':
        if not submit_clicks:
            raise PreventUpdate
        try:
            form_fields_data = {
                input_id['index']: input_value
                for input_id, input_value in zip(form_input_ids, form_input_values)
            }

            yaml_file_path = local_yamls.get(selected_model)
            if not yaml_file_path:
                return f"ERROR: No user config yaml found for '{selected_model}'."

            yaml_parser = ruamel.yaml.YAML()
            with open(yaml_file_path, 'r', encoding='utf-8') as file:
                current_yaml_data = yaml_parser.load(file)

            new_fields = []
            for key, value in process_form_configs(form_fields_data).items():
                *parent_keys, leaf_key = key.split('.')
                nested_dict = current_yaml_data
                for k in parent_keys:
                    nested_dict = nested_dict.setdefault(k, {})
                current_value = nested_dict.get(leaf_key)
                value = _coerce_like_original(current_value, form_fields_data.get(key), value)
                if current_value != value:
                    nested_dict[leaf_key] = value
                    new_fields.append(key)

            # Serialize fully before opening the file for writing, so a dump error
            # cannot leave a truncated config behind.
            buffer = io.StringIO()
            yaml_parser.dump(current_yaml_data, buffer)
            with open(yaml_file_path, 'w', encoding='utf-8') as file:
                file.write(buffer.getvalue())

            return f"User config yaml file has been updated successfully ! Updated fields are: {', '.join(new_fields)}"
        except Exception as e:
            return f"ERROR: UPDATING USER CONFIG YAML file: {e}"

    if triggered_button == 'process-button':
        # refresh_metrics starts the run (and prompts for missing credentials);
        # only report status here.
        if exec_nb_clicks and devcloud_username and devcloud_password:
            return "Application is running ..."
        raise PreventUpdate

    raise PreventUpdate
|
|
|
|
|
|
|
@app.callback(
    Output('download-action', 'style'),
    [Input('interval-widget', 'n_intervals')],
    [State('selected-model', 'value')]
)
def toggle_download_button(n_intervals, selected_model):
    """
    Show the download button when at least one run folder exists.

    A run folder is a directory directly under <cwd>/experiments_outputs whose
    name starts with "20".

    Args:
        n_intervals (int): Tick count of the interval component (trigger only).
        selected_model (str): The selected use case (unused).

    Returns:
        dict: {'display': 'block'} if a run folder exists, else {'display': 'none'}.
    """
    outputs_root = os.path.join(os.getcwd(), "experiments_outputs")
    if os.path.exists(outputs_root):
        has_run_folder = any(
            entry.startswith('20') and os.path.isdir(os.path.join(outputs_root, entry))
            for entry in os.listdir(outputs_root)
        )
        if has_run_folder:
            return {'display': 'block'}
    return {'display': 'none'}
|
|
|
|
|
# Name format of the timestamped run folders created under experiments_outputs.
_OUTPUT_RUN_FORMAT = "%Y_%m_%d_%H_%M_%S"


def _newest_output_run(output_directory):
    """Return the newest timestamp-named run directory under output_directory, or None."""
    runs = []
    for name in os.listdir(output_directory):
        if not (name.startswith("20") and os.path.isdir(os.path.join(output_directory, name))):
            continue
        try:
            runs.append((datetime.strptime(name, _OUTPUT_RUN_FORMAT), name))
        except ValueError:
            # Starts with "20" but is not a run timestamp; skip rather than crash.
            continue
    return max(runs)[1] if runs else None


def _archive_needs_refresh(zip_file_path, directory):
    """Return True when the archive is missing or any other file in directory is newer."""
    if not os.path.exists(zip_file_path):
        return True
    archive_mtime = os.path.getmtime(zip_file_path)
    archive_abspath = os.path.abspath(zip_file_path)
    for root_dir, _sub_dirs, files in os.walk(directory):
        for file_name in files:
            file_path = os.path.join(root_dir, file_name)
            if os.path.abspath(file_path) == archive_abspath:
                continue
            try:
                if os.path.getmtime(file_path) > archive_mtime:
                    return True
            except OSError:
                # File vanished while walking; it cannot make the archive stale.
                continue
    return False


@app.callback(
    Output('download-resource', 'data'),
    [Input('download-action', 'n_clicks')],
    [State('selected-model', 'value')]
)
def generate_download_link(n_clicks, selected_model):
    """
    Send a ZIP of the most recent run folder under experiments_outputs.

    The archive is stored inside the run folder as "<run>.zip". It is rebuilt
    whenever a file in the folder is newer than the archive, so a download taken
    while training is still writing outputs does not freeze a stale archive.

    Args:
        n_clicks (int | None): Click count of the download button.
        selected_model (str): The selected use case (unused).

    Returns:
        dict: The payload built by dcc.send_file for the archive.

    Raises:
        PreventUpdate: If the button was not clicked, or there is no run folder or archive.
    """
    if n_clicks is None:
        raise PreventUpdate

    output_directory = os.path.join(os.getcwd(), "experiments_outputs")
    if not os.path.isdir(output_directory):
        raise PreventUpdate

    recent_directory = _newest_output_run(output_directory)
    if recent_directory is None:
        raise PreventUpdate

    recent_directory_path = os.path.join(output_directory, recent_directory)
    zip_file_path = os.path.join(recent_directory_path, f"{recent_directory}.zip")

    if _archive_needs_refresh(zip_file_path, recent_directory_path):
        create_archive(zip_file_path, recent_directory_path)

    if os.path.exists(zip_file_path):
        return dcc.send_file(zip_file_path)

    raise PreventUpdate
|
|
|
@server.route('/download/<path:subpath>')
def download_file(subpath):
    """
    Serve a file from the experiments_outputs folder as an attachment.

    The requested path is resolved (including symlinks) and must stay inside
    <cwd>/experiments_outputs. Flask's ``path`` converter accepts "../" segments,
    so without this check any readable file on the server could be downloaded.

    Parameters:
    - subpath (str): Path of the file relative to the experiments_outputs folder.

    Returns:
    - Response: The file sent as an attachment if it is a regular file inside the folder.
    - tuple: ("File not found", 404) otherwise, including for paths outside the folder.
    """
    outputs_root = os.path.realpath(os.path.join(os.getcwd(), 'experiments_outputs'))
    file_path = os.path.realpath(os.path.join(outputs_root, subpath))
    try:
        inside_outputs = os.path.commonpath([outputs_root, file_path]) == outputs_root
    except ValueError:
        # Paths on different drives (Windows) cannot be inside the outputs folder.
        inside_outputs = False

    if inside_outputs and os.path.isfile(file_path):
        return send_file(file_path, as_attachment=True)
    return "File not found", 404
|
|
|
|
|
if __name__ == '__main__':
    # Serve on every interface at port 7860 with threaded request handling.
    app.run_server(
        host='0.0.0.0',
        port=7860,
        dev_tools_ui=True,
        dev_tools_hot_reload=True,
        threaded=True,
    )