yhavinga commited on
Commit
46ffa30
Β·
1 Parent(s): 5314ab7
Files changed (7) hide show
  1. .gitignore +4 -0
  2. README.md +32 -6
  3. app.py +211 -0
  4. babel.png +0 -0
  5. generator.py +124 -0
  6. requirements.txt +13 -0
  7. style.css +42 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ venv
2
+ .idea
3
+ __pycache__
4
+ *~
README.md CHANGED
@@ -1,13 +1,39 @@
1
  ---
2
- title: Babel
3
- emoji: πŸ“Š
4
- colorFrom: indigo
5
- colorTo: gray
6
  sdk: streamlit
7
- sdk_version: 1.10.0
8
  app_file: app.py
9
  pinned: false
 
10
  license: postgresql
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Babel - translate between Dutch and English
3
+ emoji: πŸ§™
4
+ colorFrom: gray
5
+ colorTo: indigo
6
  sdk: streamlit
 
7
  app_file: app.py
8
  pinned: false
9
+ sdk_version: 1.0.0
10
  license: postgresql
11
  ---
12
 
13
+ # Configuration
14
+
15
+ `title`: _string_
16
+ Display title for the Space
17
+
18
+ `emoji`: _string_
19
+ Space emoji (emoji-only character allowed)
20
+
21
+ `colorFrom`: _string_
22
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
23
+
24
+ `colorTo`: _string_
25
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
26
+
27
+ `sdk`: _string_
28
+ Can be either `gradio`, `streamlit`, or `static`
29
+
30
+ `sdk_version` : _string_
31
+ Only applicable for `streamlit` SDK.
32
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
33
+
34
+ `app_file`: _string_
35
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
36
+ Path is relative to the root of the repository.
37
+
38
+ `pinned`: _boolean_
39
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import time
4
+ from random import randint
5
+
6
+ import psutil
7
+ import streamlit as st
8
+ import torch
9
+ from transformers import (
10
+ AutoModelForCausalLM,
11
+ AutoModelForSeq2SeqLM,
12
+ AutoTokenizer,
13
+ pipeline,
14
+ set_seed,
15
+ )
16
+
17
+ from generator import GeneratorFactory
18
+
19
+ device = torch.cuda.device_count() - 1
20
+
21
+ TRANSLATION_NL_TO_EN = "translation_en_to_nl"
22
+
23
+ GENERATOR_LIST = [
24
+ {
25
+ "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
26
+ "desc": "longT5 large nl8 256cc/512beta/512l en->nl",
27
+ "task": TRANSLATION_NL_TO_EN,
28
+ },
29
+ {
30
+ "model_name": "yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
31
+ "desc": "longT5 large nl8 512beta/512l en->nl",
32
+ "task": TRANSLATION_NL_TO_EN,
33
+ },
34
+ {
35
+ "model_name": "yhavinga/t5-small-24L-ccmatrix-multi",
36
+ "desc": "T5 small nl24 ccmatrix en->nl",
37
+ "task": TRANSLATION_NL_TO_EN,
38
+ },
39
+ ]
40
+
41
+
42
+ def main():
43
+ st.set_page_config( # Alternate names: setup_page, page, layout
44
+ page_title="Babel", # String or None. Strings get appended with "β€’ Streamlit".
45
+ layout="wide", # Can be "centered" or "wide". In the future also "dashboard", etc.
46
+ initial_sidebar_state="expanded", # Can be "auto", "expanded", "collapsed"
47
+ page_icon="πŸ“š", # String, anything supported by st.image, or None.
48
+ )
49
+
50
+ if "generators" not in st.session_state:
51
+ st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
52
+
53
+ generators = st.session_state["generators"]
54
+
55
+ with open("style.css") as f:
56
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
57
+
58
+ st.sidebar.image("babel.png", width=200)
59
+ st.sidebar.markdown(
60
+ """# Babel
61
+ Vertaal van en naar Engels"""
62
+ )
63
+ model_desc = st.sidebar.selectbox("Model", generators.gpt_descs(), index=1)
64
+ st.sidebar.title("Parameters:")
65
+ if "prompt_box" not in st.session_state:
66
+ # Text is from https://www.gutenberg.org/files/35091/35091-h/35091-h.html
67
+ st.session_state[
68
+ "prompt_box"
69
+ ] = """It was a wet, gusty night and I had a lonely walk home. By taking the river road, though I hated it, I saved two miles, so I sloshed ahead trying not to think at all. Through the barbed wire fence I could see the racing river. Its black swollen body writhed along with extraordinary swiftness, breathlessly silent, only occasionally making a swishing ripple. I did not enjoy looking at it. I was somehow afraid.
70
+
71
+ And there, at the end of the river road where I swerved off, a figure stood waiting for me, motionless and enigmatic. I had to meet it or turn back.
72
+
73
+ It was a quite young girl, unknown to me, with a hood over her head, and with large unhappy eyes.
74
+
75
+ β€œMy father is very ill,” she said without a word of introduction. β€œThe nurse is frightened. Could you come in and help?”"""
76
+ st.session_state["text"] = st.text_area(
77
+ "Enter text", st.session_state.prompt_box, height=300
78
+ )
79
+ max_length = st.sidebar.number_input(
80
+ "Lengte van de tekst",
81
+ value=200,
82
+ max_value=4096,
83
+ )
84
+ no_repeat_ngram_size = st.sidebar.number_input(
85
+ "No-repeat NGram size", min_value=1, max_value=5, value=3
86
+ )
87
+ repetition_penalty = st.sidebar.number_input(
88
+ "Repetition penalty", min_value=0.0, max_value=5.0, value=1.2, step=0.1
89
+ )
90
+ num_return_sequences = st.sidebar.number_input(
91
+ "Num return sequences", min_value=1, max_value=5, value=1
92
+ )
93
+ seed_placeholder = st.sidebar.empty()
94
+ if "seed" not in st.session_state:
95
+ print(f"Session state does not contain seed")
96
+ st.session_state["seed"] = 4162549114
97
+ print(f"Seed is set to: {st.session_state['seed']}")
98
+
99
+ seed = seed_placeholder.number_input(
100
+ "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
101
+ )
102
+
103
+ def set_random_seed():
104
+ st.session_state["seed"] = randint(0, 2**32 - 1)
105
+ seed = seed_placeholder.number_input(
106
+ "Seed", min_value=0, max_value=2**32 - 1, value=st.session_state["seed"]
107
+ )
108
+ print(f"New random seed set to: {seed}")
109
+
110
+ if st.button("Set new random seed"):
111
+ set_random_seed()
112
+
113
+ if sampling_mode := st.sidebar.selectbox(
114
+ "select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
115
+ ):
116
+ if sampling_mode == "Beam Search":
117
+ num_beams = st.sidebar.number_input(
118
+ "Num beams", min_value=1, max_value=10, value=4
119
+ )
120
+ length_penalty = st.sidebar.number_input(
121
+ "Length penalty", min_value=0.0, max_value=2.0, value=1.0, step=0.1
122
+ )
123
+ params = {
124
+ "max_length": max_length,
125
+ "no_repeat_ngram_size": no_repeat_ngram_size,
126
+ "repetition_penalty": repetition_penalty,
127
+ "num_return_sequences": num_return_sequences,
128
+ "num_beams": num_beams,
129
+ "early_stopping": True,
130
+ "length_penalty": length_penalty,
131
+ }
132
+ else:
133
+ top_k = st.sidebar.number_input(
134
+ "Top K", min_value=0, max_value=100, value=50
135
+ )
136
+ top_p = st.sidebar.number_input(
137
+ "Top P", min_value=0.0, max_value=1.0, value=0.95, step=0.05
138
+ )
139
+ temperature = st.sidebar.number_input(
140
+ "Temperature", min_value=0.05, max_value=1.0, value=1.0, step=0.05
141
+ )
142
+ params = {
143
+ "max_length": max_length,
144
+ "no_repeat_ngram_size": no_repeat_ngram_size,
145
+ "repetition_penalty": repetition_penalty,
146
+ "num_return_sequences": num_return_sequences,
147
+ "do_sample": True,
148
+ "top_k": top_k,
149
+ "top_p": top_p,
150
+ "temperature": temperature,
151
+ }
152
+
153
+ st.sidebar.markdown(
154
+ """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
155
+ and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
156
+ """
157
+ )
158
+
159
+ def estimate_time():
160
+ """Estimate the time it takes to generate the text."""
161
+ estimate = max_length / 18
162
+ if device == -1:
163
+ ## cpu
164
+ estimate = estimate * (1 + 0.7 * (num_return_sequences - 1))
165
+ if sampling_mode == "Beam Search":
166
+ estimate = estimate * (1.1 + 0.3 * (num_beams - 1))
167
+ else:
168
+ ## gpu
169
+ estimate = estimate * (1 + 0.1 * (num_return_sequences - 1))
170
+ estimate = 0.5 + estimate / 5
171
+ if sampling_mode == "Beam Search":
172
+ estimate = estimate * (1.0 + 0.1 * (num_beams - 1))
173
+ return int(estimate)
174
+
175
+ if st.button("Run"):
176
+ estimate = estimate_time()
177
+
178
+ with st.spinner(
179
+ text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
180
+ ):
181
+ memory = psutil.virtual_memory()
182
+
183
+ for generator in generators:
184
+ st.subheader(f"Result from {generator}")
185
+ set_seed(seed)
186
+ time_start = time.time()
187
+ result = generator.generate(text=st.session_state.text, **params)
188
+ time_end = time.time()
189
+ time_diff = time_end - time_start
190
+
191
+ for text in result:
192
+ st.write(text.replace("\n", " \n"))
193
+ st.write(f"--- generated in {time_diff:.2f} seconds ---")
194
+
195
+ info = f"""
196
+ ---
197
+ *Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
198
+ *Text generated using seed {seed}*
199
+ """
200
+ st.write(info)
201
+
202
+ params["seed"] = seed
203
+ params["prompt"] = st.session_state.text
204
+ params["model"] = generator.model_name
205
+ params_text = json.dumps(params)
206
+ print(params_text)
207
+ st.json(params_text)
208
+
209
+
210
+ if __name__ == "__main__":
211
+ main()
babel.png ADDED
generator.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import torch
4
+ from transformers import (
5
+ AutoModelForCausalLM,
6
+ AutoModelForSeq2SeqLM,
7
+ AutoTokenizer,
8
+ )
9
+
10
+ device = torch.cuda.device_count() - 1
11
+
12
+ TRANSLATION_NL_TO_EN = "translation_en_to_nl"
13
+
14
+
15
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
16
+ def load_model(model_name, task):
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+ try:
19
+ if not os.path.exists(".streamlit/secrets.toml"):
20
+ raise FileNotFoundError
21
+ access_token = st.secrets.get("netherator")
22
+ except FileNotFoundError:
23
+ access_token = os.environ.get("HF_ACCESS_TOKEN", None)
24
+ tokenizer = AutoTokenizer.from_pretrained(
25
+ model_name, from_flax=True, use_auth_token=access_token
26
+ )
27
+ if tokenizer.pad_token is None:
28
+ print("Adding pad_token to the tokenizer")
29
+ tokenizer.pad_token = tokenizer.eos_token
30
+ auto_model_class = (
31
+ AutoModelForSeq2SeqLM if "translation" in task else AutoModelForCausalLM
32
+ )
33
+ model = auto_model_class.from_pretrained(
34
+ model_name, from_flax=True, use_auth_token=access_token
35
+ )
36
+ if device != -1:
37
+ model.to(f"cuda:{device}")
38
+ return tokenizer, model
39
+
40
+
41
+ class Generator:
42
+ def __init__(self, model_name, task, desc):
43
+ self.model_name = model_name
44
+ self.task = task
45
+ self.desc = desc
46
+ self.tokenizer = None
47
+ self.model = None
48
+ self.prefix = ""
49
+ self.load()
50
+
51
+ def load(self):
52
+ if not self.model:
53
+ print(f"Loading model {self.model_name}")
54
+ self.tokenizer, self.model = load_model(self.model_name, self.task)
55
+
56
+ try:
57
+ if self.task in self.model.config.task_specific_params:
58
+ task_specific_params = self.model.config.task_specific_params[
59
+ self.task
60
+ ]
61
+ if "prefix" in task_specific_params:
62
+ self.prefix = task_specific_params["prefix"]
63
+ except TypeError:
64
+ pass
65
+
66
+ def generate(self, text: str, **generate_kwargs) -> str:
67
+ #
68
+ # import pydevd_pycharm
69
+ # pydevd_pycharm.settrace('10.1.0.144', port=12345, stdoutToServer=True, stderrToServer=True)
70
+ #
71
+ batch_encoded = self.tokenizer(
72
+ self.prefix + text,
73
+ max_length=generate_kwargs["max_length"],
74
+ padding=False,
75
+ truncation=False,
76
+ return_tensors="pt",
77
+ )
78
+ if device != -1:
79
+ batch_encoded.to(f"cuda:{device}")
80
+ logits = self.model.generate(
81
+ batch_encoded["input_ids"],
82
+ attention_mask=batch_encoded["attention_mask"],
83
+ **generate_kwargs,
84
+ )
85
+ decoded_preds = self.tokenizer.batch_decode(
86
+ logits.cpu().numpy(), skip_special_tokens=False
87
+ )
88
+ decoded_preds = [
89
+ pred.replace("<pad> ", "").replace("<pad>", "").replace("</s>", "")
90
+ for pred in decoded_preds
91
+ ]
92
+ return decoded_preds
93
+
94
+ # return self.pipeline(text, **generate_kwargs)
95
+
96
+ def __str__(self):
97
+ return self.desc
98
+
99
+
100
+ class GeneratorFactory:
101
+ def __init__(self, generator_list):
102
+ self.generators = []
103
+ for g in generator_list:
104
+ with st.spinner(text=f"Loading the model {g['desc']} ..."):
105
+ self.add_generator(**g)
106
+
107
+ def add_generator(self, model_name, task, desc):
108
+ # If the generator is not yet present, add it
109
+ if not self.get_generator(model_name=model_name, task=task, desc=desc):
110
+ g = Generator(model_name, task, desc)
111
+ g.load()
112
+ self.generators.append(g)
113
+
114
+ def get_generator(self, **kwargs):
115
+ for g in self.generators:
116
+ if all([g.__dict__.get(k) == v for k, v in kwargs.items()]):
117
+ return g
118
+ return None
119
+
120
+ def __iter__(self):
121
+ return iter(self.generators)
122
+
123
+ def gpt_descs(self):
124
+ return [g.desc for g in self.generators if g.task == TRANSLATION_NL_TO_EN]
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #-f https://download.pytorch.org/whl/torch_stable.html
2
+ -f https://download.pytorch.org/whl/cu116
3
+ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
4
+ protobuf<3.20
5
+ streamlit>=1.4.0,<=1.10.0
6
+ torch
7
+ transformers>=4.13.0
8
+ mtranslate
9
+ psutil
10
+ jax[cuda]==0.3.16
11
+ chex>=0.1.4
12
+ ##jaxlib==0.1.67
13
+ flax>=0.5.3
style.css ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ body {
2
+ background-color: #eee;
3
+ }
4
+ /*.fullScreenFrame > div {*/
5
+ /* display: flex;*/
6
+ /* justify-content: center;*/
7
+ /*}*/
8
+ /*.stButton>button {*/
9
+ /* color: #4F8BF9;*/
10
+ /* border-radius: 50%;*/
11
+ /* height: 3em;*/
12
+ /* width: 3em;*/
13
+ /*}*/
14
+
15
+ .stTextInput>div>div>input {
16
+ color: #4F8BF9;
17
+ }
18
+ .stTextArea>div>div>input {
19
+ color: #4F8BF9;
20
+ min-height: 300px;
21
+ }
22
+
23
+
24
+ /*.st-cj {*/
25
+ /* min-height: 500px;*/
26
+ /* spellcheck="false";*/
27
+ /* color: #4F8BF9;*/
28
+ /*}*/
29
+ /*.st-ch {*/
30
+ /* min-height: 500px;*/
31
+ /* spellcheck="false";*/
32
+ /* color: #4F8BF9;*/
33
+ /*}*/
34
+ /*.st-bb {*/
35
+ /* min-height: 500px;*/
36
+ /* spellcheck="false";*/
37
+ /* color: #4F8BF9;*/
38
+ /*}*/
39
+
40
+ /*body {*/
41
+ /* background-color: #f1fbff*/
42
+ /*}*/