Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# coding: utf-8 | |
import gradio as gr | |
import numpy as np | |
from transformers import ( | |
AutoModelForSequenceClassification, | |
AutoTokenizer, | |
TextClassificationPipeline, | |
pipeline, | |
) | |
from sklearn import preprocessing | |
from langdetect import detect | |
from matplotlib import pyplot as plt | |
import imageio | |
# User-facing text constants (candidates for moving into a separate module).

# Page titles, German first and French second, as rendered in the Markdown header.
TITLE_DE = "Automatisierte Einteilung von Vorstößen in Departements & Offices"
TITLE_FR = "Où aller ? Classification des départements & bureaux"

# Bilingual explanation shown under the titles.
DESCRIPTION = """Diese Anwendung klassifiziert Vorstöße in Departements und schlägt auch ein
mögliches Office vor. Bitte bewerten Sie für sich, ob Sie dem Office-Vorschlag
nachkommen wollen, oder Ihren Vorstoß in einem anderen Office sehen, und leiten Sie
nach eigenem Ermessen weiter. \n\n
Cette application classe les requêtes dans les départements et propose également un
office possible. Veuillez évaluer pour vous-même si vous souhaitez suivre la
proposition d'office ou si vous souhaitez voir votre démarche dans un autre office
et transmettez à votre discrétion."""

# Shown when the input is neither German nor French.
UNKNOWN_LANG_TEXT = "The language is not recognized, it must be either in German or in French."

# Placeholder for the input textbox (German line, then French line).
PLACEHOLDER_TEXT = (
    "Geben Sie bitte den Titel und den 'Submitted Text' des Vorstoss ein.\n"
    "Veuillez entrer le titre et le 'Submitted Text' de la requête."
)

# Disclaimers prefixed to the candidate list when the model is unsure.
UNSURE_DE_TEXT = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
UNSURE_FR_TEXT = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"

# Minimum top-class probability for reporting a single category; below it,
# several candidate categories are listed together with a disclaimer.
ML_MODEL_SURE = 0.6
# (French, German) abbreviation of each department-level class. The
# textclassification code pairs these labels with the department model's scores
# purely by position, so this order must match the model's output order
# (presumably its id2label mapping -- TODO confirm).
_DEPARTMENT_ABBREVIATIONS = (
    ("DDPS", "VBS"),
    ("DFI", "EDI"),
    ("AS-MPC", "AB-BA"),
    ("DFJP", "EJPD"),
    ("DEFR", "WBF"),
    ("DETEC", "UVEK"),
    ("DFAE", "EDA"),
    ("Parl", "Parl"),
    ("ChF", "BK"),
    ("DFF", "EFD"),
    ("AF", "BV"),
    ("TF", "BGer"),
)
# Per-language bar labels, derived from the single table above so the two
# tuples can never get out of step with each other.
BARS_DEP_FR = tuple(french for french, _german in _DEPARTMENT_ABBREVIATIONS)
BARS_DEP_DE = tuple(german for _french, german in _DEPARTMENT_ABBREVIATIONS)
def load_model(modelFolder):
    """Build a text-classification pipeline from a locally saved model.

    Args:
        modelFolder: directory (or model id) accepted by ``from_pretrained``; both
            the sequence-classification model and its tokenizer are read from it.

    Returns:
        A ``TextClassificationPipeline`` wrapping that model and tokenizer.
    """
    # Keyword arguments are evaluated left to right: the model loads first, then the tokenizer.
    return TextClassificationPipeline(
        model=AutoModelForSequenceClassification.from_pretrained(modelFolder),
        tokenizer=AutoTokenizer.from_pretrained(modelFolder),
    )
# Lazily created French -> German translation pipeline, shared by all requests.
_FR_DE_TRANSLATOR = None


def _get_fr_de_translator():
    """Return the shared French -> German translation pipeline, building it on first use."""
    global _FR_DE_TRANSLATOR
    if _FR_DE_TRANSLATOR is None:
        # Building the pipeline loads the model weights, which is slow; doing it
        # once avoids reloading the model on every French request.
        _FR_DE_TRANSLATOR = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")
    return _FR_DE_TRANSLATOR


def translate_to_de(SubmittedText):
    """Translate French user input to German so the classifiers can score it better.

    Only the first 1000 characters of the input are translated.

    Args:
        SubmittedText: the French text entered by the user.

    Returns:
        The German translation as a string.
    """
    translated = _get_fr_de_translator()(SubmittedText[:1000])
    return translated[0]["translation_text"]
def create_bar_plot(rates, barnames):
    """Render ``rates`` as a horizontal bar chart and return it as an image array.

    Args:
        rates: one score per bar, paired with ``barnames`` by position.
        barnames: y-axis tick labels, one per bar.

    Returns:
        ``(image, barnames)``: the chart decoded from PNG by ``imageio``, and the
        unchanged ``barnames``.
    """
    # Local imports: io is stdlib and matplotlib is already a dependency of this file.
    import io

    from matplotlib.figure import Figure

    # A private Figure (rather than pyplot's global "current figure") and an
    # in-memory PNG (rather than a shared "rates.png" on disk) keep concurrent
    # requests from drawing into, or reading back, each other's chart.
    fig = Figure()
    ax = fig.subplots()
    y_pos = np.arange(len(barnames))
    ax.barh(y_pos, rates)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(barnames)

    # Encode as PNG and decode again: the component then receives a plain pixel array.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png")
    im = imageio.v2.imread(buffer.getvalue())
    return im, barnames
def show_chosen_category(barnames, rates, language): | |
"""Creates the output text | |
- adds disclaimer if ML model is not sure | |
- when unsure, adds all categories with prob. > 10% to output""" | |
maxRate = np.max(rates) | |
maxIndex = np.argmax(rates) | |
distance = "\t\t\t\t\t" | |
# ML model not sure if highest probability < 60% | |
if maxRate < ML_MODEL_SURE: | |
name = UNSURE_FR_TEXT if language == "fr" else UNSURE_DE_TEXT | |
# Show each department that has a probability > 10% | |
i = 0 | |
while i == 0: | |
if rates[maxIndex] >= 0.1: | |
chosenScore = str(rates[maxIndex])[2:4] | |
chosenCat = barnames[maxIndex] | |
name = name + "\t" + chosenScore + "%" + distance + chosenCat + "\n" | |
rates[maxIndex] = 0 | |
maxIndex = np.argmax(rates) | |
else: | |
i = 1 | |
# ML model pretty sure, show only one department | |
else: | |
name = str(maxRate)[2:4] + "%" + distance + barnames[maxIndex] | |
return name | |
# Load both classifiers once at import time so every request reuses them.
# Left-to-right evaluation keeps the original order: department model, then office model.
pipeDep, pipeOffice = load_model("saved_model_dep"), load_model("saved_model_office")

# Rebuild the office LabelEncoder from its saved class list (no refitting needed).
labelencoderOffice = preprocessing.LabelEncoder()
labelencoderOffice.classes_ = np.load("classes_office.npy")
def textclassification(SubmittedText):
    """Classify a submission into a department and an office.

    French input is translated to German before classification; input detected as
    any other language (or not detectable at all) is rejected.

    Args:
        SubmittedText: the title and text of the submission as typed into the UI.

    Returns:
        A 4-tuple matching the click handler's outputs:
        ``(department_text, department_image, office_text, office_image)``.
        For rejected input, both texts are ``UNKNOWN_LANG_TEXT`` and both images ``None``.
    """
    # Local import: langdetect is already a dependency of this file; this is the
    # exception detect() raises when it finds nothing to detect (e.g. empty text).
    from langdetect.lang_detect_exception import LangDetectException

    # The UI wires four outputs, so every return path must yield four values.
    unknown_language = (UNKNOWN_LANG_TEXT, None, UNKNOWN_LANG_TEXT, None)
    if not SubmittedText or not SubmittedText.strip():
        return unknown_language
    try:
        language = detect(SubmittedText)
    except LangDetectException:
        return unknown_language

    # Translate the input to German if necessary.
    if language == "fr":
        SubmittedText = translate_to_de(SubmittedText)
    elif language != "de":
        return unknown_language

    images = []
    chosenCategoryTexts = []
    labelsDep = BARS_DEP_FR if language == "fr" else BARS_DEP_DE
    labelsOffice = labelencoderOffice.classes_
    for pipe, barnames in zip((pipeDep, pipeOffice), (labelsDep, labelsOffice)):
        # Clear pyplot's current figure before drawing the next chart.
        plt.clf()
        # Make the prediction with the first 1000 characters; scores come back in
        # the model's label order, which is how they are paired with barnames.
        prediction = pipe(SubmittedText[:1000], return_all_scores=True)
        rates = [row["score"] for row in prediction[0]]
        # Create bar plot & output text.
        im, barnames = create_bar_plot(rates, barnames)
        images.append(im)
        chosenCategoryTexts.append(show_chosen_category(barnames, rates, language))
    # Text & image for the department prediction, then for the office prediction.
    return chosenCategoryTexts[0], images[0], chosenCategoryTexts[1], images[1]
# TODO: set an example picture upon loading.
# TODO: maybe expose a few settings in the frontend?

# Monochrome theme with red accents, matching the BK CH look.
_bk_theme = gr.themes.Monochrome(
    primary_hue="red",
    secondary_hue="red",
    font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
)

# Launch UI
with gr.Blocks(_bk_theme) as demo:
    gr.Markdown(f"# {TITLE_DE}\n # {TITLE_FR}\n\n {DESCRIPTION}")

    # Three equal columns: the input, the department prediction and the office prediction.
    with gr.Row():
        with gr.Column(scale=2):
            name = gr.Textbox(label="", lines=28, placeholder=PLACEHOLDER_TEXT)
            predict_btn = gr.Button("Submit | Soumettre")
        with gr.Column(scale=2):
            output_text_dep = gr.Textbox(label="Departement prediction:")
            output_image_dep = gr.Image(label="Departement")
        with gr.Column(scale=2):
            output_text_office = gr.Textbox(label="Office prediction:")
            output_image_office = gr.Image(label="Office")

    # The order must match the 4-tuple returned by textclassification.
    prediction_outputs = [
        output_text_dep,
        output_image_dep,
        output_text_office,
        output_image_office,
    ]
    predict_btn.click(
        fn=textclassification,
        inputs=name,
        outputs=prediction_outputs,
        api_name="predict",
    )

demo.launch()