Commit 68fec63 · Emily McMilin committed
1 parent: 9a5cfb0

first commit, not describing text

Files changed:
- app.py +488 -0
- requirements.txt +5 -0
app.py
ADDED
@@ -0,0 +1,488 @@
from typing import Optional

import gradio as gr
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


# DATASETS
REDDIT = 'reddit_finetuned'
WIKIBIO = 'wikibio_finetuned'
BASE = 'BERT_base'

# Play-with-me consts
SUBREDDIT_CONDITIONING_VARIABLES = ["none", "subreddit"]
WIKIBIO_CONDITIONING_VARIABLES = ['none', 'birth_date', 'birth_place']

BERT_LIKE_MODELS = ["bert", "distilbert"]


## Internal constants
GENDER_OPTIONS = ['female', 'male']
DECIMAL_PLACES = 1

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_TOKEN_LENGTH = 32
NON_LOSS_TOKEN_ID = -100

# Picked ints that will pop out visually during debug
NON_GENDERED_TOKEN_ID = 30
LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
CLASSES = list(LABEL_DICT.keys())

MULTITOKEN_WOMAN_WORD = 'policewoman'
MULTITOKEN_MAN_WORD = 'spiderman'

# Wikibio consts
START_YEAR = 1800
STOP_YEAR = 1999
SPLIT_KEY = "DATE"

# Reddit consts
# List of randomly selected subreddits (tending towards those with seemingly
# more gender-neutral words), in order of increasing self-identified female
# participation. See http://bburky.com/subredditgenderratios/ ,
# minimum subreddit size: 100000. Update: 400000.
SUBREDDITS = [
    "GlobalOffensive",
    "pcmasterrace",
    "nfl",
    "sports",
    "The_Donald",
    "leagueoflegends",
    "Overwatch",
    "gonewild",
    "Futurology",
    "space",
    "technology",
    "gaming",
    "Jokes",
    "dataisbeautiful",
    "woahdude",
    "askscience",
    "wow",
    "anime",
    "BlackPeopleTwitter",
    "politics",
    "pokemon",
    "worldnews",
    "reddit.com",
    "interestingasfuck",
    "videos",
    "nottheonion",
    "television",
    "science",
    "atheism",
    "movies",
    "gifs",
    "Music",
    "trees",
    "EarthPorn",
    "GetMotivated",
    "pokemongo",
    "news",
    "fffffffuuuuuuuuuuuu",
    "Fitness",
    "Showerthoughts",
    "OldSchoolCool",
    "explainlikeimfive",
    "todayilearned",
    "gameofthrones",
    "AdviceAnimals",
    "DIY",
    "WTF",
    "IAmA",
    "cringepics",
    "tifu",
    "mildlyinteresting",
    "funny",
    "pics",
    "LifeProTips",
    "creepy",
    "personalfinance",
    "food",
    "AskReddit",
    "books",
    "aww",
    "sex",
    "relationships",
]


# Fire up the models
models_paths = dict()
models = dict()

base_path = "emilylearning/"

# reddit fine-tuned models:
for var in SUBREDDIT_CONDITIONING_VARIABLES:
    models_paths[(REDDIT, var)] = base_path + f'cond_ft_{var}_on_reddit__prcnt_100__test_run_False'
    models[(REDDIT, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(REDDIT, var)]
    )

# wikibio fine-tuned models:
for var in WIKIBIO_CONDITIONING_VARIABLES:
    models_paths[(WIKIBIO, var)] = base_path + f"cond_ft_{var}_on_wiki_bio__prcnt_100__test_run_False"
    models[(WIKIBIO, var)] = AutoModelForTokenClassification.from_pretrained(
        models_paths[(WIKIBIO, var)]
    )

# BERT-like models:
for bert_like in BERT_LIKE_MODELS:
    models_paths[(BASE, bert_like)] = f"{bert_like}-base-uncased"
    models[(BASE, bert_like)] = pipeline(
        "fill-mask", model=models_paths[(BASE, bert_like)])

# The tokenizer is the same for every model, so just grab one of them
tokenizer = AutoTokenizer.from_pretrained(
    models_paths[(BASE, BERT_LIKE_MODELS[0])], add_prefix_space=True
)
MASK_TOKEN_ID = tokenizer.mask_token_id
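# Illustrative sanity check (a sketch, not part of the original app; assumes
# the bert-base-uncased vocabulary): the fine-tuned and base models above all
# share this WordPiece tokenizer, so one mask id works for every model, e.g.
#   tokenizer.mask_token     # -> '[MASK]'
#   tokenizer.mask_token_id  # -> 103
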
def get_gendered_token_ids(tokenizer):
    ## Set up gendered token constants
    gendered_lists = [
        ['he', 'she'],
        ['him', 'her'],
        ['his', 'hers'],
        ["himself", "herself"],
        ['male', 'female'],
        ['man', 'woman'],
        ['men', 'women'],
        ["husband", "wife"],
        ['father', 'mother'],
        ['boyfriend', 'girlfriend'],
        ['brother', 'sister'],
        ["actor", "actress"],
    ]
    # Generating dicts here for potential later token reconstruction of predictions.
    # ('pair' as the loop variable, rather than shadowing the builtin 'list'.)
    male_gendered_dict = {pair[0]: pair for pair in gendered_lists}
    female_gendered_dict = {pair[1]: pair for pair in gendered_lists}

    male_gendered_token_ids = tokenizer.convert_tokens_to_ids(
        list(male_gendered_dict.keys()))
    female_gendered_token_ids = tokenizer.convert_tokens_to_ids(
        list(female_gendered_dict.keys())
    )

    # The below technique is used to grab the second token in a multi-token word.
    # There must be a better way...
    multiword_woman_token_ids = tokenizer.encode(
        MULTITOKEN_WOMAN_WORD, add_special_tokens=False)
    assert len(multiword_woman_token_ids) == 2
    subword_woman_token_id = multiword_woman_token_ids[1]

    multiword_man_token_ids = tokenizer.encode(
        MULTITOKEN_MAN_WORD, add_special_tokens=False)
    assert len(multiword_man_token_ids) == 2
    subword_man_token_id = multiword_man_token_ids[1]

    male_gendered_token_ids.append(subword_man_token_id)
    female_gendered_token_ids.append(subword_woman_token_id)

    assert tokenizer.unk_token_id not in male_gendered_token_ids
    assert tokenizer.unk_token_id not in female_gendered_token_ids

    return male_gendered_token_ids, female_gendered_token_ids
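# Illustrative usage (a sketch, not part of the original app): with the
# tokenizer loaded above,
#   male_ids, female_ids = get_gendered_token_ids(tokenizer)
#   tokenizer.convert_ids_to_tokens(male_ids[:3])    # -> ['he', 'him', 'his']
#   tokenizer.convert_ids_to_tokens(female_ids[:3])  # -> ['she', 'her', 'hers']
# Each list also ends with one subword id (the second piece of 'spiderman' /
# 'policewoman'), so gendered subwords of multi-token words are counted too.
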
def tokenize_and_append_metadata(text, tokenizer, female_gendered_token_ids, male_gendered_token_ids):
    """Tokenize text and mask/flag 'gendered_tokens_ids' in token_ids and labels."""

    label_list = list(LABEL_DICT.values())
    assert label_list[0] == LABEL_DICT["female"], "LABEL_DICT not an ordered dict"
    label2id = {label: idx for idx, label in enumerate(label_list)}

    tokenized = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=MAX_TOKEN_LENGTH,
    )

    # Finding the gender pronouns in the tokens
    token_ids = tokenized["input_ids"]
    female_tags = torch.tensor(
        [
            LABEL_DICT["female"]
            if id in female_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for id in token_ids
        ]
    )
    male_tags = torch.tensor(
        [
            LABEL_DICT["male"]
            if id in male_gendered_token_ids
            else NON_GENDERED_TOKEN_ID
            for id in token_ids
        ]
    )

    # Labeling and masking out occurrences of gendered pronouns
    labels = torch.tensor([NON_LOSS_TOKEN_ID] * len(token_ids))
    labels = torch.where(
        female_tags == LABEL_DICT["female"],
        label2id[LABEL_DICT["female"]],
        NON_LOSS_TOKEN_ID,
    )
    labels = torch.where(
        male_tags == LABEL_DICT["male"], label2id[LABEL_DICT["male"]], labels
    )
    masked_token_ids = torch.where(
        female_tags == LABEL_DICT["female"], MASK_TOKEN_ID, torch.tensor(
            token_ids)
    )
    masked_token_ids = torch.where(
        male_tags == LABEL_DICT["male"], MASK_TOKEN_ID, masked_token_ids
    )

    tokenized["input_ids"] = masked_token_ids
    tokenized["labels"] = labels

    return tokenized
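# Illustrative behavior (a sketch, not part of the original app): for
#   out = tokenize_and_append_metadata(
#       "She loves her dog", tokenizer, female_ids, male_ids)
# the ids for 'she' and 'her' become MASK_TOKEN_ID, so
#   tokenizer.convert_ids_to_tokens(out["input_ids"])[:6]
#   # -> ['[CLS]', '[MASK]', 'loves', '[MASK]', 'dog', '[SEP]']
# and out["labels"] holds the class index (0 = female, 1 = male) at masked
# positions and NON_LOSS_TOKEN_ID (-100) everywhere else.
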
def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gendered_token_ids, female_gendered_token_ids):
    """Construct dict of tokenized texts with each independent variable (date or subreddit) injected into the text."""
    if dataset == WIKIBIO:
        text_portions = input_text.split(SPLIT_KEY)
        # If no SPLIT_KEY found in text, add a prefix for metadata and whitespace
        if len(text_portions) == 1:
            text_portions = ['Born in ', f" {text_portions[0]}"]

    tokenized_w_metadata = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
    for indie_var in indie_vars:
        if dataset == WIKIBIO:
            target_text = f"{indie_var}".join(text_portions)
        else:
            target_text = f"r/{indie_var}: {input_text}"

        # Note: female ids come first, matching tokenize_and_append_metadata's signature
        tokenized_sample = tokenize_and_append_metadata(
            target_text,
            tokenizer,
            female_gendered_token_ids,
            male_gendered_token_ids,
        )

        tokenized_w_metadata['ids'].append(tokenized_sample["input_ids"])
        tokenized_w_metadata['atten_mask'].append(
            torch.tensor(tokenized_sample["attention_mask"]))
        tokenized_w_metadata['toks'].append(
            tokenizer.convert_ids_to_tokens(tokenized_sample["input_ids"]))
        tokenized_w_metadata['labels'].append(tokenized_sample["labels"])

    return tokenized_w_metadata
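# Illustrative expansion (a sketch, not part of the original app): for WIKIBIO,
#   get_tokenized_text_with_metadata(
#       "Born in DATE, she was a teacher.", [1850, 1950], WIKIBIO,
#       male_ids, female_ids)
# tokenizes "Born in 1850, she was a teacher." and "Born in 1950, she was a
# teacher."; for REDDIT each subreddit is prepended instead, e.g.
# "r/AskReddit: <input_text>".
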
def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
    preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
    pronoun_preds = torch.where(is_masked, preds[:, CLASSES.index(gender)], 0.0)
    return round(torch.sum(pronoun_preds).item() / num_preds * 100, DECIMAL_PLACES)
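# Illustrative reading (hypothetical numbers, not from a real run): with two
# [MASK] positions (num_preds = 2), per-mask female probabilities of 0.62 and
# 0.58 average to (0.62 + 0.58) / 2 * 100 = 60.0 percent.
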
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
    pronoun_preds = [
        sum(
            pronoun["score"] if pronoun["token"] in gendered_token_ids else 0.0
            for pronoun in top_preds
        )
        for top_preds in mask_filled_text
    ]
    return round(sum(pronoun_preds) / num_preds * 100, DECIMAL_PLACES)
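# Note on the input shape (an observation, not part of the original app): the
# fill-mask pipeline returns, per [MASK], a list of top-k candidate dicts with
# 'token' and 'score' keys; summing the scores whose token id is gendered
# gives the probability mass the base model assigns to that gender per mask.
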
def get_figure(results, dataset, gender, indie_var_name):
    fig, ax = plt.subplots()
    ax.plot(results)

    if dataset == REDDIT:
        ax.set_xlabel("Subreddit prepended to input text")
        ax.xaxis.set_major_locator(MaxNLocator(6))
    else:
        ax.set_xlabel("Date injected into input text")
    ax.set_title(f"Softmax probability of pronouns predicted {gender}\n by model type vs {indie_var_name}.")
    ax.set_ylabel(f"Avg softmax prob for {gender} pronouns")
    ax.legend(list(results.columns))
    return fig
def predict_gender_pronouns(
    dataset,
    bert_like_models,
    normalizing,
    input_text,
):
    """Run inference on input_text for each model type, returning df and plots of the percentage
    of gender pronouns predicted as female and male in each target text.
    """

    male_gendered_token_ids, female_gendered_token_ids = get_gendered_token_ids(tokenizer)
    if dataset == REDDIT:
        indie_vars = SUBREDDITS
        conditioning_variables = SUBREDDIT_CONDITIONING_VARIABLES
        indie_var_name = 'subreddit'
    else:
        indie_vars = np.linspace(START_YEAR, STOP_YEAR, 20).astype(int)
        conditioning_variables = WIKIBIO_CONDITIONING_VARIABLES
        indie_var_name = 'date'

    tokenized = get_tokenized_text_with_metadata(
        input_text,
        indie_vars,
        dataset,
        male_gendered_token_ids,
        female_gendered_token_ids
    )
    num_preds = torch.sum(tokenized['ids'][0] == MASK_TOKEN_ID).item()

    female_dfs = []
    male_dfs = []
    female_dfs.append(pd.DataFrame({indie_var_name: indie_vars}))
    male_dfs.append(pd.DataFrame({indie_var_name: indie_vars}))
    for var in conditioning_variables:
        prefix = f"{var}_metadata"
        model = models[(dataset, var)]

        female_pronoun_preds = []
        male_pronoun_preds = []
        for indie_var_idx in range(len(tokenized['ids'])):
            is_masked = tokenized['ids'][indie_var_idx] == MASK_TOKEN_ID

            ids = tokenized["ids"][indie_var_idx]
            atten_mask = tokenized["atten_mask"][indie_var_idx]
            labels = tokenized["labels"][indie_var_idx]

            with torch.no_grad():
                outputs = model(ids.unsqueeze(dim=0),
                                atten_mask.unsqueeze(dim=0))

            female_pronoun_preds.append(
                get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, "female")
            )
            male_pronoun_preds.append(
                get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, "male")
            )

        female_dfs.append(pd.DataFrame({prefix: female_pronoun_preds}))
        male_dfs.append(pd.DataFrame({prefix: male_pronoun_preds}))

    for bert_like in bert_like_models:
        prefix = f"base_{bert_like}"
        model = models[(BASE, bert_like)]

        female_pronoun_preds = []
        male_pronoun_preds = []
        for indie_var_idx in range(len(tokenized['ids'])):
            toks = tokenized["toks"][indie_var_idx]
            target_text_for_bert = ' '.join(
                toks[1:-1])  # Removing [CLS] and [SEP]

            mask_filled_text = model(target_text_for_bert)
            # Quick hack: the pipeline's return type depends on how many
            # [MASK]s are in the text, so wrap the single-mask case in a list.
            if type(mask_filled_text[0]) is not list:
                mask_filled_text = [mask_filled_text]

            female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                mask_filled_text,
                female_gendered_token_ids,
                num_preds
            ))
            male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                mask_filled_text,
                male_gendered_token_ids,
                num_preds
            ))

        if normalizing:
            total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds)
            female_pronoun_preds = np.around(
                np.divide(female_pronoun_preds, total_gendered_probs) * 100,
                decimals=DECIMAL_PLACES
            )
            male_pronoun_preds = np.around(
                np.divide(male_pronoun_preds, total_gendered_probs) * 100,
                decimals=DECIMAL_PLACES
            )
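        # Illustrative (hypothetical numbers, not from a real run): raw preds
        # of 2.0 (female) and 6.0 (male) normalize to 25.0 and 75.0, so the
        # two curves sum to 100 whenever normalizing is on.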

        female_dfs.append(pd.DataFrame({prefix: female_pronoun_preds}))
        male_dfs.append(pd.DataFrame({prefix: male_pronoun_preds}))

    # To display to the user as an example
    toks = tokenized["toks"][0]
    target_text_w_masks = ' '.join(toks[1:-1])

    # Plots / dataframes for display to users
    female_results = pd.concat(female_dfs, axis=1).set_index(indie_var_name)
    male_results = pd.concat(male_dfs, axis=1).set_index(indie_var_name)

    female_fig = get_figure(female_results, dataset, "female", indie_var_name)
    male_fig = get_figure(male_results, dataset, "male", indie_var_name)
    female_results.reset_index(inplace=True)  # Gradio Dataframe doesn't 'see' the index
    male_results.reset_index(inplace=True)  # Gradio Dataframe doesn't 'see' the index

    return (
        target_text_w_masks,
        female_fig,
        male_fig,
        female_results,
        male_results,
    )
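# Illustrative headless call (a sketch mirroring the Gradio defaults below,
# not part of the original app). The Dropdown below uses type="index", so
# normalizing arrives as 1 for "True" and 0 for "False":
#   text_w_masks, f_fig, m_fig, f_df, m_df = predict_gender_pronouns(
#       WIKIBIO, ["bert"], 1,
#       "She always walked past the building built in DATE on her way to "
#       "her job as an elementary school teacher.")
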
gr.Interface(
    fn=predict_gender_pronouns,
    inputs=[
        gr.inputs.Radio(
            [REDDIT, WIKIBIO],
            default=WIKIBIO,
            type="value",
            label="Pick 'conditionally' fine-tuned model.",
            optional=False,
        ),
        gr.inputs.CheckboxGroup(
            BERT_LIKE_MODELS,
            default=[BERT_LIKE_MODELS[0]],
            type="value",
            label="Pick optional BERT base uncased model.",
        ),
        gr.inputs.Dropdown(
            ["False", "True"],
            label="Normalize BERT-like model's predictions to gendered-only?",
            default="True",
            type="index",
        ),
        gr.inputs.Textbox(
            lines=5,
            label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
            default="She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
        ),
    ],
    outputs=[
        gr.outputs.Textbox(
            type="auto", label="Sample target text fed to model"),
        gr.outputs.Plot(type="auto", label="Plot of softmax probability pronouns predicted female."),
        gr.outputs.Plot(type="auto", label="Plot of softmax probability pronouns predicted male."),
        gr.outputs.Dataframe(
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability pronouns predicted female",
        ),
        gr.outputs.Dataframe(
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability pronouns predicted male",
        ),
    ],
).launch(debug=True)
requirements.txt
ADDED
@@ -0,0 +1,5 @@
transformers
torch
pandas
numpy
matplotlib
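# Note: gradio is intentionally absent; on Hugging Face Spaces it is supplied
# via the Space's sdk setting. The requirements above are unpinned, so later
# API drift (e.g. the gr.inputs / gr.outputs namespaces deprecated in Gradio 3)
# can break the app; pinning known-good versions (for example
# transformers==4.17.0, a hypothetical pin) would guard against that.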