Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# coding: utf-8 | |
import gradio as gr | |
import numpy as np | |
from transformers import ( | |
AutoModelForSequenceClassification, | |
AutoTokenizer, | |
TextClassificationPipeline, | |
pipeline, | |
) | |
from sklearn import preprocessing | |
from langdetect import detect | |
from matplotlib import pyplot as plt | |
import imageio | |
import logging | |
import warnings | |
# Let the request / prediction INFO messages logged below reach the handlers
# by lowering the root logger's threshold.
logging.getLogger(name=None).setLevel(level=logging.INFO)
# ---------------------------------------------------------------------------
# User-facing texts (German / French) shown in the Gradio UI
# ---------------------------------------------------------------------------
# Bilingual disclaimer rendered under the titles (Markdown).
DESCRIPTION = """Diese Anwendung teilt Vorstösse an das federführende Departement zu und
macht einen Vorschlag für das zuständige Amt. Der Vorschlag der Anwendung ist nicht
100% richtig. Der Zuteilungsvorschlag muss von einer Fachperson geprüft und die
effektive Zuteilung muss nach eigenem Ermessen erfolgen. \n\n
Cette application attribue les interventions au département chef de file et fait une
proposition à l'office compétent. La proposition de l'application n'est pas correcte
à 100%. La proposition d'attribution doit être vérifiée par un spécialiste et l'attribution
effective doit être faite à la discrétion de l'utilisateur."""
TITLE_DE = "Automatische Zuteilung von Vorstössen an das federführende Departement bzw. Amt"
TITLE_FR = "Où aller ? Classification des départements & bureaux"
# Returned when the input is neither German nor French.
UNKNOWN_LANG_TEXT = "The language is not recognized, it must be either in German or in French."
# Placeholder of the input textbox (German line, then French line).
PLACEHOLDER_TEXT = (
    "Geben Sie bitte den Titel und den 'Submitted Text' des Vorstoss ein.\n"
    "Veuillez entrer le titre et le 'Submitted Text' de la requête."
)
# Headers placed above the list of candidate categories when the model is unsure.
UNSURE_DE_TEXT = "Das ML-Modell ist nicht sicher. Die Zuteilung könnte sein: \n\n"
UNSURE_FR_TEXT = "Le modèle ML n'est pas sûr. L'allocation pourrait être: \n\n"
# Minimum top-class probability for a prediction to be shown as a single,
# confident answer; below it the "unsure" text with several candidates is shown.
ML_MODEL_SURE = 0.6
# Department abbreviations used as bar / result labels for the department
# model, as (French, German) pairs. Position i labels the i-th score returned
# by the department classifier — assumed to match the model's label order;
# verify against the training label mapping if the model is retrained.
_DEPARTMENT_LABEL_PAIRS = (
    ("DDPS", "VBS"),
    ("DFI", "EDI"),
    ("AS-MPC", "AB-BA"),
    ("DFJP", "EJPD"),
    ("DEFR", "WBF"),
    ("DETEC", "UVEK"),
    ("DFAE", "EDA"),
    ("Parl", "Parl"),
    ("ChF", "BK"),
    ("DFF", "EFD"),
    ("AF", "BV"),
    ("TF", "BGer"),
)
BARS_DEP_FR = tuple(french for french, _ in _DEPARTMENT_LABEL_PAIRS)
BARS_DEP_DE = tuple(german for _, german in _DEPARTMENT_LABEL_PAIRS)
def load_model(modelFolder):
    """Build a text-classification pipeline from a saved model directory.

    Both the sequence-classification model and its tokenizer are loaded from
    ``modelFolder``.

    Returns:
        A ``TextClassificationPipeline`` wrapping that model and tokenizer.
    """
    return TextClassificationPipeline(
        model=AutoModelForSequenceClassification.from_pretrained(modelFolder),
        tokenizer=AutoTokenizer.from_pretrained(modelFolder),
    )
# French -> German translation pipeline, created on first use and then reused
# for every request (building it loads the model, which is expensive).
_fr_de_translator = None


def _get_fr_de_translator():
    """Return the shared French->German translation pipeline, creating it once."""
    global _fr_de_translator
    if _fr_de_translator is None:
        _fr_de_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")
    return _fr_de_translator


def translate_to_de(SubmittedText):
    """Translate French user input to German so the classifiers can score it.

    Only the first 1000 characters of ``SubmittedText`` are translated.

    Returns:
        The German translation as a string.
    """
    translatedText = _get_fr_de_translator()(SubmittedText[:1000])
    return translatedText[0]["translation_text"]
def create_bar_plot(rates, barnames):
    """Render a horizontal bar chart of per-class scores as an image array.

    Args:
        rates: score per class, in the same order as ``barnames``.
        barnames: label for each bar, used as y-axis tick labels.

    Returns:
        ``(image, barnames)``: ``image`` is the chart rendered to PNG and read
        back with ``imageio``; ``barnames`` is passed through unchanged.
    """
    import io

    # Draw on a dedicated figure and keep the PNG in memory. The previous
    # version saved to a shared "rates.png" in the working directory, so two
    # concurrent requests could overwrite each other's chart between the save
    # and the read-back.
    fig, ax = plt.subplots()
    try:
        y_pos = np.arange(len(barnames))
        ax.barh(y_pos, rates)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(barnames)
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Release the figure so repeated requests do not accumulate figures.
        plt.close(fig)
    buffer.seek(0)
    im = imageio.v2.imread(buffer)
    return im, barnames
def _format_percent(rate):
    """Format a probability in [0, 1] as a rounded whole percentage, e.g. 0.62 -> "62%".

    Replaces slicing the float's string form (``str(rate)[2:4]``), which
    rendered 1.0 as "0%", 0.6 as "6%" and 0.05 as "05%".
    """
    return f"{float(rate):.0%}"


def show_chosen_category(barnames, rates, language):
    """Build the human-readable prediction text for one classifier.

    Args:
        barnames: label for each class, aligned with ``rates``.
        rates: probability per class. The sequence is not modified.
        language: detected input language; ``"fr"`` selects the French
            "unsure" header, anything else the German one.

    Returns:
        If the highest probability is at least ``ML_MODEL_SURE``: a single line
        ``"<percent>%<tabs><label>"``. Otherwise the language-specific
        "model is unsure" header followed by one line per class whose
        probability is at least 10 %, highest first.
    """
    distance = "\t\t\t\t\t"
    min_listed_rate = 0.1
    # Work on a copy: the previous version zeroed entries of the caller's list.
    scores = [float(rate) for rate in rates]
    best_index = int(np.argmax(scores))

    # ML model pretty sure: show only the top category.
    if scores[best_index] >= ML_MODEL_SURE:
        return _format_percent(scores[best_index]) + distance + barnames[best_index]

    # ML model not sure: list every category with probability >= 10 %.
    header = UNSURE_FR_TEXT if language == "fr" else UNSURE_DE_TEXT
    # sorted(..., reverse=True) is stable, so ties keep their original order,
    # matching the previous repeated-np.argmax selection.
    ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    lines = [
        "\t" + _format_percent(scores[i]) + distance + barnames[i] + "\n"
        for i in ranked
        if scores[i] >= min_listed_rate
    ]
    return header + "".join(lines)
# Load both classifiers once at import time so every request reuses them:
# the department model first, then the office model.
pipeDep, pipeOffice = map(load_model, ("saved_model_dep", "saved_model_office"))
# Office label names; used as bar labels for the office model's scores.
labelencoderOffice = preprocessing.LabelEncoder()
labelencoderOffice.classes_ = np.load("classes_office.npy")
def _classify(pipe, barnames, text, language):
    """Run one classifier on ``text`` and return ``(summary_text, bar_chart_image)``.

    Only the first 1000 characters of ``text`` are scored.
    """
    plt.clf()
    # `return_all_scores` is deprecated. Per the original author, its
    # replacement sorts the results differently, which would break the
    # alignment between scores and `barnames`. So the warning is silenced
    # rather than switching APIs.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        prediction = pipe(text[:1000], return_all_scores=True)
    rates = [row["score"] for row in prediction[0]]
    im, barnames = create_bar_plot(rates, barnames)
    return show_chosen_category(barnames, rates, language), im


def textclassification(SubmittedText):
    """Predict the responsible department and office for a submitted text.

    French input is translated to German before classification; any other
    language (or undetectable input) is rejected.

    Returns:
        ``(department_text, department_image, office_text, office_image)``,
        matching the four Gradio outputs. For unsupported or undetectable
        input: ``(UNKNOWN_LANG_TEXT, None, None, None)``.
    """
    from langdetect.lang_detect_exception import LangDetectException

    # detect() raises on empty / feature-less input (blank box, only digits
    # or punctuation); report that as an unknown language instead of failing.
    # NOTE(review): langdetect is non-deterministic for short texts unless
    # DetectorFactory.seed is fixed — consider seeding it at startup.
    try:
        language = detect(SubmittedText) if SubmittedText else None
    except LangDetectException:
        language = None
    logging.info(
        "SubmittedText received. Detected language: %s. SubmittedText: %s",
        language,
        SubmittedText,
    )
    if language not in ("de", "fr"):
        return UNKNOWN_LANG_TEXT, None, None, None
    # The classifiers work on German text; translate French input first.
    if language == "fr":
        SubmittedText = translate_to_de(SubmittedText)

    # Department labels follow the user's language; office labels come from
    # the fitted label encoder.
    labelsDep = BARS_DEP_FR if language == "fr" else BARS_DEP_DE
    depText, depImage = _classify(pipeDep, labelsDep, SubmittedText, language)
    officeText, officeImage = _classify(
        pipeOffice, labelencoderOffice.classes_, SubmittedText, language
    )
    logging.info(
        "Prediction Department: %s\n\nPrediction Amt: %s", depText, officeText
    )
    return depText, depImage, officeText, officeImage
# Build the Gradio UI and launch it.
with gr.Blocks(
    # Red monochrome theme matching the BK CH look.
    theme=gr.themes.Monochrome(
        primary_hue="red",
        secondary_hue="red",
        font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
    )
) as demo:
    gr.Markdown(f"# {TITLE_DE}\n # {TITLE_FR}\n\n {DESCRIPTION}")
    # Three equally wide columns: input | department prediction | office prediction.
    with gr.Row():
        with gr.Column(scale=2):
            name = gr.Textbox(label="Vorstosstext:", lines=28, placeholder=PLACEHOLDER_TEXT)
            predict_btn = gr.Button("Submit | Soumettre")
        with gr.Column(scale=2):
            output_text_dep = gr.Textbox(label="Vorschlag Departement:")
            output_image_dep = gr.Image(label="Departement")
        with gr.Column(scale=2):
            output_text_office = gr.Textbox(label="Vorschlag Amt:")
            output_image_office = gr.Image(label="Amt")
    # Output order must match the 4-tuple returned by textclassification.
    predict_btn.click(
        fn=textclassification,
        inputs=[name],
        outputs=[output_text_dep, output_image_dep, output_text_office, output_image_office],
        api_name="predict",
    )

demo.launch()