bk-departements / app.py
BK-AI's picture
clear plot area after first plot
5f2ec77
raw
history blame
6.5 kB
#!/usr/bin/env python
# coding: utf-8
import io
from functools import lru_cache

import gradio as gr
import imageio
import numpy as np
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from sklearn import preprocessing
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
    pipeline,
)
# TODO: move these constants into a separate module.

# --- User-facing texts (German / French) -------------------------------------
# Bilingual introduction rendered below the titles in the UI header.
DESCRIPTION = """Diese Anwendung klassifiziert Vorstöße in Departements und schlägt auch ein
mögliches Office vor. Bitte bewerten Sie für sich, ob Sie dem Office-Vorschlag
nachkommen wollen, oder Ihren Vorstoß in einem anderen Office sehen, und leiten Sie
nach eigenem Ermessen weiter. \n\n
Cette application classe les requêtes dans les départements et propose également un
office possible. Veuillez évaluer pour vous-même si vous souhaitez suivre la
proposition d'office ou si vous souhaitez voir votre démarche dans un autre office
et transmettez à votre discrétion."""
TITLE_DE = "Automatisierte Einteilung von Vorstößen in Departements & Offices"
TITLE_FR = "Où aller ? Classification des départements & bureaux"

# Shown instead of predictions when the input is neither German nor French.
UNKNOWN_LANG_TEXT = "The language is not recognized, it must be either in German or in French."

# Hint displayed in the empty input box.
PLACEHOLDER_TEXT = (
    "Geben Sie bitte den Titel und den 'Submitted Text' des Vorstoss ein.\n"
    "Veuillez entrer le titre et le 'Submitted Text' de la requête."
)

# Header placed before the candidate list when the top score is below ML_MODEL_SURE.
UNSURE_DE_TEXT = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
UNSURE_FR_TEXT = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"

# Minimum top-class probability for a prediction to count as "sure".
ML_MODEL_SURE = 0.6

# --- Department bar labels ---------------------------------------------------
# Both tuples label the outputs of the same department model (the UI picks one
# by input language), so position i must name the same department in each.
# Presumably the order follows the model's label ids -- verify when retraining.
BARS_DEP_FR = tuple("DDPS DFI AS-MPC DFJP DEFR DETEC DFAE Parl ChF DFF AF TF".split())
BARS_DEP_DE = tuple("VBS EDI AB-BA EJPD WBF UVEK EDA Parl BK EFD BV BGer".split())
def load_model(modelFolder):
    """Build a text-classification pipeline from a locally saved model.

    Args:
        modelFolder: directory from which both the sequence-classification
            model and its tokenizer are loaded.

    Returns:
        A ``TextClassificationPipeline`` wrapping that model and tokenizer.
    """
    # Keyword arguments are evaluated left to right: model first, then tokenizer.
    return TextClassificationPipeline(
        model=AutoModelForSequenceClassification.from_pretrained(modelFolder),
        tokenizer=AutoTokenizer.from_pretrained(modelFolder),
    )
@lru_cache(maxsize=1)
def _get_fr_de_translator():
    """Create the French-to-German translation pipeline once and reuse it."""
    return pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")


def translate_to_de(SubmittedText):
    """Translate French user input to German so the model classifies it better.

    Only the first 1000 characters are translated. The translation pipeline is
    built on the first call and cached. Previously it was rebuilt (model
    reloaded) for every French submission.

    Args:
        SubmittedText: French text entered by the user.

    Returns:
        The German translation as a string.
    """
    translator = _get_fr_de_translator()
    translated = translator(SubmittedText[:1000])
    return translated[0]["translation_text"]
def create_bar_plot(rates, barnames):
    """Render a horizontal bar chart of scores and return it as an image array.

    The chart is drawn on a private ``Figure`` rather than pyplot's global
    current figure. It is encoded to PNG in memory rather than through a
    shared ``rates.png`` file on disk. Concurrent requests therefore cannot
    overwrite each other's chart. The PNG round-trip is kept because the
    original comment says it gives better display in the UI.

    Args:
        rates: one score per bar, aligned with ``barnames`` by position.
        barnames: y-axis labels, one per bar.

    Returns:
        ``(image, barnames)``: the decoded PNG image and the unchanged
        ``barnames``, which are returned as-is for existing callers.
    """
    y_pos = np.arange(len(barnames))

    fig = Figure()  # not registered with pyplot, so nothing to close afterwards
    ax = fig.subplots()
    ax.barh(y_pos, rates)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(barnames)

    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    buffer.seek(0)
    image = imageio.v2.imread(buffer)
    return image, barnames
def show_chosen_category(barnames, rates, language, min_rate=0.1):
    """Build the prediction text shown next to a bar chart.

    If the highest score is at least ``ML_MODEL_SURE``, only that category is
    returned, formatted as ``"<pct>%<tabs><name>"``. Otherwise a
    language-specific "model is not sure" header is followed by one line per
    category whose score is at least ``min_rate``, highest score first.

    Percentages are rounded to whole numbers. The old ``str(rate)[2:4]``
    slicing showed 0.1 as "1%" and 1.0 as "0%". ``rates`` is no longer
    modified.

    Args:
        barnames: category names, aligned with ``rates`` by position.
        rates: one score per category; must be non-empty.
        language: detected input language; ``"fr"`` selects the French
            header, anything else the German one.
        min_rate: lowest score listed when the model is unsure (default 10%).

    Returns:
        The formatted prediction text.
    """
    distance = "\t\t\t\t\t"
    best_index = int(np.argmax(rates))
    best_rate = rates[best_index]

    # Confident prediction: show only the top category.
    if best_rate >= ML_MODEL_SURE:
        return f"{best_rate:.0%}{distance}{barnames[best_index]}"

    # Unsure: list every category scoring at least min_rate, best first.
    # sorted() is stable even with reverse=True, so ties keep their original
    # order (the same order repeated argmax calls produced).
    header = UNSURE_FR_TEXT if language == "fr" else UNSURE_DE_TEXT
    ranked = sorted(zip(rates, barnames), key=lambda pair: pair[0], reverse=True)
    lines = [f"\t{rate:.0%}{distance}{cat}\n" for rate, cat in ranked if rate >= min_rate]
    return header + "".join(lines)
# Load both classifiers once at import time; every request reuses them.
pipeDep, pipeOffice = (load_model(folder) for folder in ("saved_model_dep", "saved_model_office"))

# Only ``classes_`` is read later (as the office bar labels); presumably the
# saved array is ordered like the office model's outputs -- verify on retrain.
labelencoderOffice = preprocessing.LabelEncoder()
labelencoderOffice.classes_ = np.load("classes_office.npy")
def textclassification(SubmittedText):
    """Predict department and office for a German or French submission.

    French input is translated to German before classification. The department
    labels are still shown in French for French input. Input in any other
    language, or input langdetect cannot analyse (e.g. empty text), produces
    an explanatory message instead of predictions.

    Args:
        SubmittedText: title and text of the submission as entered in the UI.

    Returns:
        A 4-tuple ``(dep_text, dep_image, office_text, office_image)``, one
        value per Gradio output component wired to this function. For
        unsupported input both texts are ``UNKNOWN_LANG_TEXT`` and both images
        are ``None``.
    """
    try:
        language = detect(SubmittedText)
    except LangDetectException:
        # langdetect raises on text without detectable features; treat it
        # like an unsupported language instead of failing the request.
        language = None

    if language == "fr":
        SubmittedText = translate_to_de(SubmittedText)
    elif language != "de":
        # Must return four values: the UI wires four output components.
        return UNKNOWN_LANG_TEXT, None, UNKNOWN_LANG_TEXT, None

    labels_dep = BARS_DEP_FR if language == "fr" else BARS_DEP_DE
    labels_office = labelencoderOffice.classes_

    texts = []
    images = []
    for pipe, barnames in ((pipeDep, labels_dep), (pipeOffice, labels_office)):
        plt.clf()  # start each chart from a cleared pyplot figure
        # Classify only the first 1000 characters.
        # NOTE(review): return_all_scores is deprecated upstream. Before
        # migrating to top_k, confirm that scores still come back in label-id
        # order, because barnames are matched to them by position.
        prediction = pipe(SubmittedText[:1000], return_all_scores=True)
        rates = [row["score"] for row in prediction[0]]
        image, barnames = create_bar_plot(rates, barnames)
        images.append(image)
        texts.append(show_chosen_category(barnames, rates, language))

    return texts[0], images[0], texts[1], images[1]
# TODO: set an example picture when the page loads.
# TODO: maybe expose a few settings in the frontend?

# Build the UI: three columns for the input, the department prediction and the
# office prediction.
with gr.Blocks(
    # Set theme matching BK CH
    gr.themes.Monochrome(
        primary_hue="red",
        secondary_hue="red",
        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
    )
) as demo:
    gr.Markdown(f"# {TITLE_DE}\n # {TITLE_FR}\n\n {DESCRIPTION}")
    with gr.Row():
        with gr.Column(scale=2):
            name = gr.Textbox(label="", lines=28, placeholder=PLACEHOLDER_TEXT)
            predict_btn = gr.Button("Submit | Soumettre")
        with gr.Column(scale=2):
            output_text_dep = gr.Textbox(label="Departement prediction:")
            output_image_dep = gr.Image(label="Departement")
        with gr.Column(scale=2):
            output_text_office = gr.Textbox(label="Office prediction:")
            output_image_office = gr.Image(label="Office")
    # textclassification returns one value per output component, in this order.
    predict_btn.click(
        fn=textclassification,
        inputs=name,
        outputs=[
            output_text_dep,
            output_image_dep,
            output_text_office,
            output_image_office,
        ],
        api_name="predict",
    )

# Start the server only when run as a script (e.g. `python app.py`), not when
# this module is imported.
if __name__ == "__main__":
    demo.launch()