File size: 4,513 Bytes
7278f27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a420fc
7278f27
 
 
 
 
 
7a420fc
7278f27
 
 
 
 
7a420fc
7278f27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations
import psutil
import pandas as pd
import streamlit as st
import plotly.express as px
from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
from zeroshot_classification.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier

# Log total system memory at startup (helps diagnose model-loading OOM issues).
total_memory = psutil.virtual_memory().total
print(f"Total mem: {total_memory}")

def init_state(key: str):
    """Ensure *key* exists in Streamlit's session state, defaulting to None.

    Keys that are already present are left untouched so values survive reruns.
    """
    if key in st.session_state:
        return
    st.session_state[key] = None


# Session-state slots shared across reruns; create each one with a None default.
_STATE_KEYS = (
    "current_model",
    "current_model_option",
    "current_method_option",
    "current_prediction",
    "current_chart",
)
for state_key in _STATE_KEYS:
    init_state(state_key)


def load_model(model_option: str, method_option: str, random_state: int = 0):
    """Instantiate the selected zero-shot classifier into session state.

    Args:
        model_option: Model name chosen in the UI selectbox.
        method_option: One of the METHOD_OPTIONS values (NLI or NSP).
        random_state: Seed forwarded to the classifier for reproducibility.
    """
    with st.spinner("Loading selected model..."):
        # Dispatch on METHOD_OPTIONS["nli"], matching how the rest of the app
        # compares method_option, instead of a hard-coded literal that would
        # silently break if the label text in models.py ever changed.
        if method_option == METHOD_OPTIONS["nli"]:
            st.session_state.current_model = NLIZeroshotClassifier(
                model_name=model_option, random_state=random_state
            )
        else:
            st.session_state.current_model = NSPZeroshotClassifier(
                model_name=model_option, random_state=random_state
            )
        st.success("Model loaded!")


def visualize_output(labels: list[str], probabilities: list[float]):
    """Build a horizontal bar chart of label probabilities, highest first.

    Returns the plotly figure so the caller decides where to render it.
    """
    frame = pd.DataFrame({"labels": labels, "probability": probabilities})
    frame = frame.sort_values(by="probability", ascending=False)

    layout = {
        "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
        "yaxis": {"title": None, "visible": True, "showticklabels": True},
        # Tight margins all around except the top, which leaves room for a title.
        "margin": {"l": 10, "r": 10, "t": 50, "b": 10},
        "showlegend": False,
    }

    figure = px.bar(
        frame,
        x="probability",
        y="labels",
        color="labels",
        orientation="h",
        height=290,
        width=500,
    )
    # update_layout returns the figure itself, so this stays chain-equivalent.
    return figure.update_layout(layout)


st.title("Zero-shot Turkish Text Classification")
# Offer the two supported zero-shot strategies; the radio returns exactly one.
_method_choices = [METHOD_OPTIONS["nli"], METHOD_OPTIONS["nsp"]]
method_option = st.radio(
    "Select a zero-shot classification method.",
    _method_choices,
)
# The radio widget guarantees method_option is exactly one of the two choices,
# so an if/elif chain selects the matching model list.
if method_option == METHOD_OPTIONS["nli"]:
    model_option = st.selectbox(
        "Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
    )
elif method_option == METHOD_OPTIONS["nsp"]:
    model_option = st.selectbox(
        "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
    )

# Reload when either the model OR the method changed. Checking only the model
# name would skip the reload if both method's option lists ever contained the
# same model id, leaving a classifier of the wrong type in session state.
if (
    model_option != st.session_state.current_model_option
    or method_option != st.session_state.current_method_option
):
    st.session_state.current_model_option = model_option
    st.session_state.current_method_option = method_option
    load_model(
        st.session_state.current_model_option, st.session_state.current_method_option
    )


st.header("Configure prompts and labels")
col1, col2 = st.columns(2)

# Left column: candidate labels and the text to classify.
col1.subheader("Candidate labels")
_labels_help = (
    "These are the labels that the model will try to predict for the given "
    "text input. Your input labels should be comma separated and meaningful."
)
labels = col1.text_area(
    label=_labels_help,
    value="tatli,burger,kebab,diğer,tuzlu",
    key="current_labels",
)

col1.header("Make predictions")
text = col1.text_area(
    label="Enter a sentence or a paragraph to classify.",
    value="baklava",
    key="current_text",
)

# Right column: the prompt template applied to each candidate label.
col2.subheader("Prompt template")
_template_help = (
    "Prompt template is used to transform NLI and NSP tasks into a "
    "general-use zero-shot classifier. Models replace {} with the labels "
    "that you have given."
)
prompt_template = col2.text_area(
    label=_template_help,
    value="{}",
    key="current_template",
)
col2.header("")  # empty header acts as a vertical spacer for column alignment


make_pred = col1.button("Predict")
if make_pred:
    # Labels arrive as one comma-separated string; strip whitespace and drop
    # empty entries so inputs like "a, b," do not produce bogus labels.
    candidate_labels = [
        label.strip()
        for label in st.session_state.current_labels.split(",")
        if label.strip()
    ]
    st.session_state.current_prediction = (
        st.session_state.current_model.predict_on_texts(
            [st.session_state.current_text],
            candidate_labels=candidate_labels,
            prompt_template=st.session_state.current_template,
        )
    )
    prediction = st.session_state.current_prediction[0]
    # Classifiers expose the per-label weights under either "scores" or
    # "probabilities"; accept whichever key is present.
    if "scores" in prediction:
        st.session_state.current_chart = visualize_output(
            prediction["labels"], prediction["scores"]
        )
    elif "probabilities" in prediction:
        st.session_state.current_chart = visualize_output(
            prediction["labels"], prediction["probabilities"]
        )
    # Only render when a chart exists — avoids passing None (or a stale chart
    # from a previous run) to plotly_chart when neither key was present.
    if st.session_state.current_chart is not None:
        col2.plotly_chart(st.session_state.current_chart, use_container_width=True)