memojja emrecan commited on
Commit
7278f27
0 Parent(s):

Duplicate from emrecan/zero-shot-turkish

Browse files

Co-authored-by: Emrecan Çelik <[email protected]>

Files changed (6) hide show
  1. .gitattributes +27 -0
  2. .gitignore +2 -0
  3. README.md +39 -0
  4. app.py +135 -0
  5. models.py +26 -0
  6. requirements.txt +215 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv
2
+ __pycache__
README.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Zero-shot Turkish
3
+ emoji: 🐨
4
+ colorFrom: yellow
5
+ colorTo: red
6
+ sdk: streamlit
7
+ sdk_version: 1.2.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: emrecan/zero-shot-turkish
11
+ ---
12
+
13
+ # Configuration
14
+
15
+ `title`: _string_
16
+ Display title for the Space
17
+
18
+ `emoji`: _string_
19
+ Space emoji (emoji-only character allowed)
20
+
21
+ `colorFrom`: _string_
22
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
23
+
24
+ `colorTo`: _string_
25
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
26
+
27
+ `sdk`: _string_
28
+ Can be either `gradio` or `streamlit`
29
+
30
+ `sdk_version` : _string_
31
+ Only applicable for `streamlit` SDK.
32
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
33
+
34
+ `app_file`: _string_
35
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
36
+ Path is relative to the root of the repository.
37
+
38
+ `pinned`: _boolean_
39
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import psutil
3
+ import pandas as pd
4
+ import streamlit as st
5
+ import plotly.express as px
6
+ from models import NLI_MODEL_OPTIONS, NSP_MODEL_OPTIONS, METHOD_OPTIONS
7
+ from zeroshot_classification.classifiers import NSPZeroshotClassifier, NLIZeroshotClassifier
8
+
9
+ print(f"Total mem: {psutil.virtual_memory().total}")
10
+
11
+ def init_state(key: str):
12
+ if key not in st.session_state:
13
+ st.session_state[key] = None
14
+
15
+
16
+ for k in [
17
+ "current_model",
18
+ "current_model_option",
19
+ "current_method_option",
20
+ "current_prediction",
21
+ "current_chart",
22
+ ]:
23
+ init_state(k)
24
+
25
+
26
+ def load_model(model_option: str, method_option: str, random_state: int = 0):
27
+ with st.spinner("Loading selected model..."):
28
+ if method_option == "Natural Language Inference":
29
+ st.session_state.current_model = NLIZeroshotClassifier(
30
+ model_name=model_option, random_state=random_state
31
+ )
32
+ else:
33
+ st.session_state.current_model = NSPZeroshotClassifier(
34
+ model_name=model_option, random_state=random_state
35
+ )
36
+ st.success("Model loaded!")
37
+
38
+
39
+ def visualize_output(labels: list[str], probabilities: list[float]):
40
+ data = pd.DataFrame({"labels": labels, "probability": probabilities}).sort_values(
41
+ by="probability", ascending=False
42
+ )
43
+ chart = px.bar(
44
+ data,
45
+ x="probability",
46
+ y="labels",
47
+ color="labels",
48
+ orientation="h",
49
+ height=290,
50
+ width=500,
51
+ ).update_layout(
52
+ {
53
+ "xaxis": {"title": "probability", "visible": True, "showticklabels": True},
54
+ "yaxis": {"title": None, "visible": True, "showticklabels": True},
55
+ "margin": dict(
56
+ l=10, # left
57
+ r=10, # right
58
+ t=50, # top
59
+ b=10, # bottom
60
+ ),
61
+ "showlegend": False,
62
+ }
63
+ )
64
+ return chart
65
+
66
+
67
+ st.title("Zero-shot Turkish Text Classification")
68
+ method_option = st.radio(
69
+ "Select a zero-shot classification method.",
70
+ [
71
+ METHOD_OPTIONS["nli"],
72
+ METHOD_OPTIONS["nsp"],
73
+ ],
74
+ )
75
+ if method_option == METHOD_OPTIONS["nli"]:
76
+ model_option = st.selectbox(
77
+ "Select a natural language inference model.", NLI_MODEL_OPTIONS, index=3
78
+ )
79
+ if method_option == METHOD_OPTIONS["nsp"]:
80
+ model_option = st.selectbox(
81
+ "Select a BERT model for next sentence prediction.", NSP_MODEL_OPTIONS, index=0
82
+ )
83
+
84
+ if model_option != st.session_state.current_model_option:
85
+ st.session_state.current_model_option = model_option
86
+ st.session_state.current_method_option = method_option
87
+ load_model(
88
+ st.session_state.current_model_option, st.session_state.current_method_option
89
+ )
90
+
91
+
92
+ st.header("Configure prompts and labels")
93
+ col1, col2 = st.columns(2)
94
+ col1.subheader("Candidate labels")
95
+ labels = col1.text_area(
96
+ label="These are the labels that the model will try to predict for the given text input. Your input labels should be comma separated and meaningful.",
97
+ value="spor,dünya,siyaset,ekonomi,sanat",
98
+ key="current_labels",
99
+ )
100
+
101
+ col1.header("Make predictions")
102
+ text = col1.text_area(
103
+ "Enter a sentence or a paragraph to classify.",
104
+ value="Ian Anderson, Jethro Tull konserinde yan flüt çalarak zeybek oynadı.",
105
+ key="current_text",
106
+ )
107
+ col2.subheader("Prompt template")
108
+ prompt_template = col2.text_area(
109
+ label="Prompt template is used to transform NLI and NSP tasks into a general-use zero-shot classifier. Models replace {} with the labels that you have given.",
110
+ value="Bu metin {} kategorisine aittir",
111
+ key="current_template",
112
+ )
113
+ col2.header("")
114
+
115
+
116
+ make_pred = col1.button("Predict")
117
+ if make_pred:
118
+ st.session_state.current_prediction = (
119
+ st.session_state.current_model.predict_on_texts(
120
+ [st.session_state.current_text],
121
+ candidate_labels=st.session_state.current_labels.split(","),
122
+ prompt_template=st.session_state.current_template,
123
+ )
124
+ )
125
+ if "scores" in st.session_state.current_prediction[0]:
126
+ st.session_state.current_chart = visualize_output(
127
+ st.session_state.current_prediction[0]["labels"],
128
+ st.session_state.current_prediction[0]["scores"],
129
+ )
130
+ elif "probabilities" in st.session_state.current_prediction[0]:
131
+ st.session_state.current_chart = visualize_output(
132
+ st.session_state.current_prediction[0]["labels"],
133
+ st.session_state.current_prediction[0]["probabilities"],
134
+ )
135
+ col2.plotly_chart(st.session_state.current_chart, use_container_width=True)
models.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ METHOD_OPTIONS = {
2
+ "nli": "Natural Language Inference",
3
+ "nsp": "Next Sentence Prediction",
4
+ }
5
+
6
+ NLI_MODEL_OPTIONS = [
7
+ "emrecan/distilbert-base-turkish-cased-allnli_tr",
8
+ "emrecan/distilbert-base-turkish-cased-multinli_tr",
9
+ "emrecan/distilbert-base-turkish-cased-snli_tr",
10
+ "emrecan/bert-base-turkish-cased-allnli_tr",
11
+ "emrecan/bert-base-turkish-cased-multinli_tr",
12
+ "emrecan/bert-base-turkish-cased-snli_tr",
13
+ "emrecan/convbert-base-turkish-mc4-cased-allnli_tr",
14
+ "emrecan/convbert-base-turkish-mc4-cased-multinli_tr",
15
+ "emrecan/convbert-base-turkish-mc4-cased-snli_tr",
16
+ "emrecan/bert-base-multilingual-cased-allnli_tr",
17
+ "emrecan/bert-base-multilingual-cased-multinli_tr",
18
+ "emrecan/bert-base-multilingual-cased-snli_tr",
19
+ ]
20
+
21
+ NSP_MODEL_OPTIONS = [
22
+ "dbmdz/bert-base-turkish-cased",
23
+ "dbmdz/bert-base-turkish-uncased",
24
+ "dbmdz/bert-base-turkish-128k-cased",
25
+ "dbmdz/bert-base-turkish-128k-uncased",
26
+ ]
requirements.txt ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/emres/turkish-deasciifier.git
2
+ git+https://github.com/emrecncelik/zeroshot-turkish.git
3
+ absl-py==1.0.0
4
+ aiohttp==3.8.0
5
+ aiosignal==1.2.0
6
+ altair==4.1.0
7
+ APScheduler==3.8.1
8
+ argon2-cffi==21.3.0
9
+ argon2-cffi-bindings==21.2.0
10
+ astor==0.8.1
11
+ astunparse==1.6.3
12
+ async-timeout==4.0.1
13
+ attrs==21.2.0
14
+ backcall==0.2.0
15
+ backports.zoneinfo==0.2.1
16
+ base58==2.1.1
17
+ beautifulsoup4==4.11.1
18
+ black==21.10b0
19
+ bleach==5.0.0
20
+ blinker==1.4
21
+ blis==0.7.5
22
+ Brotli==1.0.9
23
+ cachetools==4.2.4
24
+ catalogue==2.0.6
25
+ certifi==2021.10.8
26
+ cffi==1.15.0
27
+ charset-normalizer==2.0.7
28
+ click
29
+ codecarbon==1.2.0
30
+ commonmark==0.9.1
31
+ configparser==5.1.0
32
+ cycler==0.11.0
33
+ cymem==2.0.6
34
+ cytoolz==0.11.2
35
+ dash==2.0.0
36
+ dash-bootstrap-components==1.0.0
37
+ dash-core-components==2.0.0
38
+ dash-html-components==2.0.0
39
+ dash-table==5.0.0
40
+ datasets==2.3.2
41
+ debugpy==1.5.1
42
+ decorator==5.1.0
43
+ defusedxml==0.7.1
44
+ dill==0.3.4
45
+ docker-pycreds==0.4.0
46
+ entrypoints==0.3
47
+ et-xmlfile==1.1.0
48
+ fastjsonschema==2.15.3
49
+ fasttext==0.9.2
50
+ filelock==3.3.2
51
+ fire==0.4.0
52
+ Flask==2.0.2
53
+ Flask-Compress==1.10.1
54
+ flatbuffers==2.0
55
+ fonttools==4.28.5
56
+ frozenlist==1.2.0
57
+ fsspec==2021.11.0
58
+ gast==0.4.0
59
+ gitdb==4.0.9
60
+ GitPython==3.1.24
61
+ google-auth==2.3.3
62
+ google-auth-oauthlib==0.4.6
63
+ google-pasta==0.2.0
64
+ grpcio==1.41.1
65
+ h5py==3.5.0
66
+ huggingface-hub==0.1.2
67
+ idna==3.3
68
+ importlib-metadata==4.12.0
69
+ importlib-resources==5.7.1
70
+ ipykernel==6.6.0
71
+ ipython==7.30.1
72
+ ipython-genutils==0.2.0
73
+ ipywidgets==7.6.5
74
+ itsdangerous==2.0.1
75
+ jedi==0.18.1
76
+ jellyfish==0.8.9
77
+ Jinja2==3.0.3
78
+ joblib==1.1.0
79
+ jsonschema==4.5.1
80
+ jupyter-client==7.1.0
81
+ jupyter-core==4.9.1
82
+ jupyterlab-pygments==0.2.2
83
+ jupyterlab-widgets==1.0.2
84
+ keras==2.7.0
85
+ Keras-Preprocessing==1.1.2
86
+ kiwisolver==1.3.2
87
+ langcodes==3.3.0
88
+ libclang==12.0.0
89
+ loguru==0.6.0
90
+ lxml==4.6.5
91
+ Markdown==3.3.4
92
+ MarkupSafe==2.0.1
93
+ matplotlib==3.5.1
94
+ matplotlib-inline==0.1.3
95
+ mistune==0.8.4
96
+ multidict==5.2.0
97
+ multiprocess==0.70.12.2
98
+ murmurhash==1.0.6
99
+ mypy-extensions==0.4.3
100
+ nbclient==0.6.3
101
+ nbconvert==6.5.0
102
+ nbformat==5.4.0
103
+ nest-asyncio==1.5.4
104
+ networkx==2.6.3
105
+ nltk==3.6.7
106
+ notebook==6.4.11
107
+ numpy==1.21.4
108
+ oauthlib==3.1.1
109
+ openpyxl==3.0.9
110
+ opt-einsum==3.3.0
111
+ packaging==21.2
112
+ pandas==1.4.2
113
+ pandocfilters==1.5.0
114
+ parso==0.8.3
115
+ pathspec==0.9.0
116
+ pathtools==0.1.2
117
+ pathy==0.6.1
118
+ pexpect==4.8.0
119
+ pickleshare==0.7.5
120
+ Pillow==8.4.0
121
+ platformdirs==2.4.0
122
+ plotly==5.4.0
123
+ preshed==3.0.6
124
+ prometheus-client==0.14.1
125
+ promise==2.3
126
+ prompt-toolkit==3.0.24
127
+ protobuf==3.19.1
128
+ psutil==5.8.0
129
+ ptyprocess==0.7.0
130
+ py-cpuinfo==8.0.0
131
+ pyarrow==6.0.0
132
+ pyasn1==0.4.8
133
+ pyasn1-modules==0.2.8
134
+ pybind11==2.9.2
135
+ pycparser==2.21
136
+ pydantic==1.8.2
137
+ pydeck==0.7.1
138
+ Pygments==2.10.0
139
+ Pympler==0.9
140
+ pynvml==11.0.0
141
+ pyparsing==2.4.7
142
+ pyphen==0.11.0
143
+ pyrsistent==0.18.1
144
+ python-dateutil==2.8.2
145
+ pytz==2021.3
146
+ pytz-deprecation-shim==0.1.0.post0
147
+ PyYAML==6.0
148
+ pyzmq==22.3.0
149
+ regex==2021.11.10
150
+ requests==2.26.0
151
+ requests-oauthlib==1.3.0
152
+ responses==0.18.0
153
+ rich==12.4.4
154
+ rsa==4.7.2
155
+ sacremoses==0.0.46
156
+ scikit-learn==1.0.1
157
+ scipy==1.7.2
158
+ semver==2.13.0
159
+ Send2Trash==1.8.0
160
+ sentencepiece==0.1.96
161
+ sentry-sdk==1.4.3
162
+ setuptools-scm==6.3.2
163
+ shortuuid==1.0.8
164
+ six==1.16.0
165
+ sklearn==0.0
166
+ smart-open==5.2.1
167
+ smmap==5.0.0
168
+ soupsieve==2.3.2.post1
169
+ spacy==3.2.1
170
+ spacy-legacy==3.0.8
171
+ spacy-loggers==1.0.1
172
+ srsly==2.4.2
173
+ streamlit==1.2.0
174
+ subprocess32==3.5.4
175
+ tenacity==8.0.1
176
+ tensorboard==2.7.0
177
+ tensorboard-data-server==0.6.1
178
+ tensorboard-plugin-wit==1.8.0
179
+ tensorflow==2.7.0
180
+ tensorflow-estimator==2.7.0
181
+ tensorflow-io-gcs-filesystem==0.22.0
182
+ termcolor==1.1.0
183
+ terminado==0.15.0
184
+ testpath==0.5.0
185
+ textacy==0.12.0
186
+ thinc==8.0.13
187
+ threadpoolctl==3.0.0
188
+ tinycss2==1.1.1
189
+ tokenizers==0.12.1
190
+ toml==0.10.2
191
+ tomli==1.2.2
192
+ toolz==0.11.2
193
+ torch==1.11.0
194
+ tornado==6.1
195
+ tqdm==4.62.3
196
+ traitlets==5.1.1
197
+ transformers==4.20.0
198
+ typer==0.4.0
199
+ typing-extensions
200
+ tzdata==2021.5
201
+ tzlocal==4.1
202
+ urllib3==1.26.7
203
+ validators==0.18.2
204
+ wandb==0.12.6
205
+ wasabi==0.9.0
206
+ watchdog==2.1.6
207
+ wcwidth==0.2.5
208
+ webencodings==0.5.1
209
+ Werkzeug==2.0.2
210
+ widgetsnbextension==3.5.2
211
+ wrapt==1.13.3
212
+ xxhash==2.0.2
213
+ yarl==1.7.2
214
+ yaspin==2.1.0
215
+ zipp==3.8.0