Emily McMilin committed on
Commit
68fec63
·
1 Parent(s): 9a5cfb0

first commit, not describing text

Browse files
Files changed (2) hide show
  1. app.py +488 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
5
+ from transformers import pipeline
6
+ import pandas as pd
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.ticker import MaxNLocator
10
+
11
+
12
+ # DATASETS
13
+ REDDIT = 'reddit_finetuned'
14
+ WIKIBIO = 'wikibio_finetuned'
15
+ BASE = 'BERT_base'
16
+
17
+ # Play with me, consts
18
+ SUBREDDIT_CONDITIONING_VARIABLES = ["none", "subreddit"]
19
+ WIKIBIO_CONDITIONING_VARIABLES = ['none', 'birth_date', 'birth_place'] # EMILY!!
20
+
21
+ BERT_LIKE_MODELS = ["bert", "distilbert"]
22
+
23
+
24
+
25
+ ## Internal constants
26
+ GENDER_OPTIONS = ['female', 'male']
27
+ DECIMAL_PLACES = 1
28
+
29
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+
31
+ MAX_TOKEN_LENGTH = 32
32
+ NON_LOSS_TOKEN_ID = -100
33
+
34
+ # Picked ints that will pop out visually during debug
35
+ NON_GENDERED_TOKEN_ID = 30
36
+ LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
37
+ CLASSES = list(LABEL_DICT.keys())
38
+
39
+ MULTITOKEN_WOMAN_WORD = 'policewoman'
40
+ MULTITOKEN_MAN_WORD = 'spiderman'
41
+
42
+ # Wikibio conts
43
+
44
+ START_YEAR = 1800
45
+ STOP_YEAR = 1999
46
+ SPLIT_KEY = "DATE"
47
+
48
+ # Reddit consts
49
+
50
+ # List of randomly selected (tending towards those with seemingly more gender-neutral words)
51
+ # in order of increasing self-identified female participation.
52
+ # See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 100000
53
+ # Update: 400000
54
+ SUBREDDITS = [
55
+ "GlobalOffensive",
56
+ "pcmasterrace",
57
+ "nfl",
58
+ "sports",
59
+ "The_Donald",
60
+ "leagueoflegends",
61
+ "Overwatch",
62
+ "gonewild",
63
+ "Futurology",
64
+ "space",
65
+ "technology",
66
+ "gaming",
67
+ "Jokes",
68
+ "dataisbeautiful",
69
+ "woahdude",
70
+ "askscience",
71
+ "wow",
72
+ "anime",
73
+ "BlackPeopleTwitter",
74
+ "politics",
75
+ "pokemon",
76
+ "worldnews",
77
+ "reddit.com",
78
+ "interestingasfuck",
79
+ "videos",
80
+ "nottheonion",
81
+ "television",
82
+ "science",
83
+ "atheism",
84
+ "movies",
85
+ "gifs",
86
+ "Music",
87
+ "trees",
88
+ "EarthPorn",
89
+ "GetMotivated",
90
+ "pokemongo",
91
+ "news",
92
+ "fffffffuuuuuuuuuuuu",
93
+ "Fitness",
94
+ "Showerthoughts",
95
+ "OldSchoolCool",
96
+ "explainlikeimfive",
97
+ "todayilearned",
98
+ "gameofthrones",
99
+ "AdviceAnimals",
100
+ "DIY",
101
+ "WTF",
102
+ "IAmA",
103
+ "cringepics",
104
+ "tifu",
105
+ "mildlyinteresting",
106
+ "funny",
107
+ "pics",
108
+ "LifeProTips",
109
+ "creepy",
110
+ "personalfinance",
111
+ "food",
112
+ "AskReddit",
113
+ "books",
114
+ "aww",
115
+ "sex",
116
+ "relationships",
117
+ ]
118
+
119
+
120
+ # Fire up the models
121
+ models_paths = dict()
122
+ models = dict()
123
+
124
+ base_path = "emilylearning/"
125
+
126
+ # reddit finetuned models:
127
+ for var in SUBREDDIT_CONDITIONING_VARIABLES:
128
+ models_paths[(REDDIT, var)] = base_path + f'cond_ft_{var}_on_reddit__prcnt_100__test_run_False'
129
+ models[(REDDIT, var)] = AutoModelForTokenClassification.from_pretrained(
130
+ models_paths[(REDDIT, var)]
131
+ )
132
+
133
+ # wikibio finetuned models:
134
+ for var in WIKIBIO_CONDITIONING_VARIABLES:
135
+ models_paths[(WIKIBIO, var)] = base_path + f"cond_ft_{var}_on_wiki_bio__prcnt_100__test_run_False"
136
+ models[(WIKIBIO, var)] = AutoModelForTokenClassification.from_pretrained(
137
+ models_paths[(WIKIBIO, var)]
138
+ )
139
+
140
+ # BERT-like models:
141
+ for bert_like in BERT_LIKE_MODELS:
142
+ models_paths[(BASE, bert_like)] = f"{bert_like}-base-uncased"
143
+ models[(BASE, bert_like)] = pipeline(
144
+ "fill-mask", model=models_paths[(BASE, bert_like)])
145
+
146
+ # Tokenizers same for each model, so just grabbing one of them
147
+ tokenizer = AutoTokenizer.from_pretrained(
148
+ models_paths[(BASE, BERT_LIKE_MODELS[0])], add_prefix_space=True
149
+ )
150
+ MASK_TOKEN_ID = tokenizer.mask_token_id
151
+
152
+
153
+ def get_gendered_token_ids(tokenizer):
154
+
155
+ ## Set up gendered token constants
156
+ gendered_lists = [
157
+ ['he', 'she'],
158
+ ['him', 'her'],
159
+ ['his', 'hers'],
160
+ ["himself", "herself"],
161
+ ['male', 'female'],
162
+ ['man', 'woman'],
163
+ ['men', 'women'],
164
+ ["husband", "wife"],
165
+ ['father', 'mother'],
166
+ ['boyfriend', 'girlfriend'],
167
+ ['brother', 'sister'],
168
+ ["actor", "actress"],
169
+ ]
170
+ # Generating dicts here for potential later token reconstruction of predictions
171
+ male_gendered_dict = {list[0]: list for list in gendered_lists}
172
+ female_gendered_dict = {list[1]: list for list in gendered_lists}
173
+
174
+ male_gendered_token_ids = tokenizer.convert_tokens_to_ids(
175
+ list(male_gendered_dict.keys()))
176
+ female_gendered_token_ids = tokenizer.convert_tokens_to_ids(
177
+ list(female_gendered_dict.keys())
178
+ )
179
+
180
+ # Below technique is used to grab second token in a multi-token word
181
+ # There must be a better way...
182
+ multiword_woman_token_ids = tokenizer.encode(
183
+ MULTITOKEN_WOMAN_WORD, add_special_tokens=False)
184
+ assert len(multiword_woman_token_ids) == 2
185
+ subword_woman_token_id = multiword_woman_token_ids[1]
186
+
187
+ multiword_man_token_ids = tokenizer.encode(
188
+ MULTITOKEN_MAN_WORD, add_special_tokens=False)
189
+ assert len(multiword_man_token_ids) == 2
190
+ subword_man_token_id = multiword_man_token_ids[1]
191
+
192
+ male_gendered_token_ids.append(subword_man_token_id)
193
+ female_gendered_token_ids.append(subword_woman_token_id)
194
+
195
+ assert tokenizer.unk_token_id not in male_gendered_token_ids
196
+ assert tokenizer.unk_token_id not in female_gendered_token_ids
197
+
198
+ return male_gendered_token_ids, female_gendered_token_ids
199
+
200
+
201
+ def tokenize_and_append_metadata(text, tokenizer, female_gendered_token_ids, male_gendered_token_ids):
202
+ """Tokenize text and mask/flag 'gendered_tokens_ids' in token_ids and labels."""
203
+
204
+ label_list = list(LABEL_DICT.values())
205
+ assert label_list[0] == LABEL_DICT["female"], "LABEL_DICT not an ordered dict"
206
+ label2id = {label: idx for idx, label in enumerate(label_list)}
207
+
208
+ tokenized = tokenizer(
209
+ text,
210
+ truncation=True,
211
+ padding='max_length',
212
+ max_length=MAX_TOKEN_LENGTH,
213
+ )
214
+
215
+ # Finding the gender pronouns in the tokens
216
+ token_ids = tokenized["input_ids"]
217
+ female_tags = torch.tensor(
218
+ [
219
+ LABEL_DICT["female"]
220
+ if id in female_gendered_token_ids
221
+ else NON_GENDERED_TOKEN_ID
222
+ for id in token_ids
223
+ ]
224
+ )
225
+ male_tags = torch.tensor(
226
+ [
227
+ LABEL_DICT["male"]
228
+ if id in male_gendered_token_ids
229
+ else NON_GENDERED_TOKEN_ID
230
+ for id in token_ids
231
+ ]
232
+ )
233
+
234
+ # Labeling and masking out occurrences of gendered pronouns
235
+ labels = torch.tensor([NON_LOSS_TOKEN_ID] * len(token_ids))
236
+ labels = torch.where(
237
+ female_tags == LABEL_DICT["female"],
238
+ label2id[LABEL_DICT["female"]],
239
+ NON_LOSS_TOKEN_ID,
240
+ )
241
+ labels = torch.where(
242
+ male_tags == LABEL_DICT["male"], label2id[LABEL_DICT["male"]], labels
243
+ )
244
+ masked_token_ids = torch.where(
245
+ female_tags == LABEL_DICT["female"], MASK_TOKEN_ID, torch.tensor(
246
+ token_ids)
247
+ )
248
+ masked_token_ids = torch.where(
249
+ male_tags == LABEL_DICT["male"], MASK_TOKEN_ID, masked_token_ids
250
+ )
251
+
252
+ tokenized["input_ids"] = masked_token_ids
253
+ tokenized["labels"] = labels
254
+
255
+ return tokenized
256
+
257
+
258
+ def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gendered_token_ids, female_gendered_token_ids):
259
+ """Construct dict of tokenized texts with each year injected into the text."""
260
+ if dataset == WIKIBIO:
261
+ text_portions = input_text.split(SPLIT_KEY)
262
+ # If no SPLIT_KEY found in text, add space for metadata and whitespaces
263
+ if len(text_portions) == 1:
264
+ text_portions = ['Born in ', f" {text_portions[0]}"]
265
+
266
+
267
+ tokenized_w_metadata = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
268
+ for indie_var in indie_vars:
269
+ if dataset == WIKIBIO:
270
+ target_text = f"{indie_var}".join(text_portions)
271
+ else:
272
+ target_text = f"r/{indie_var}: {input_text}"
273
+
274
+ tokenized_sample = tokenize_and_append_metadata(
275
+ target_text,
276
+ tokenizer,
277
+ male_gendered_token_ids,
278
+ female_gendered_token_ids
279
+ )
280
+
281
+ tokenized_w_metadata['ids'].append(tokenized_sample["input_ids"])
282
+ tokenized_w_metadata['atten_mask'].append(
283
+ torch.tensor(tokenized_sample["attention_mask"]))
284
+ tokenized_w_metadata['toks'].append(
285
+ tokenizer.convert_ids_to_tokens(tokenized_sample["input_ids"]))
286
+ tokenized_w_metadata['labels'].append(tokenized_sample["labels"])
287
+
288
+ return tokenized_w_metadata
289
+
290
+
291
+ def get_avg_prob_from_finetuned_outputs(outputs, is_masked, num_preds, gender):
292
+ preds = torch.softmax(outputs[0][0].cpu(), dim=1, dtype=torch.double)
293
+ pronoun_preds = torch.where(is_masked, preds[:,CLASSES.index(gender)], 0.0)
294
+ return round(torch.sum(pronoun_preds).item() / num_preds * 100, DECIMAL_PLACES)
295
+
296
+
297
+ def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num_preds):
298
+ pronoun_preds = [sum([
299
+ pronoun["score"] if pronoun["token"] in gendered_token_ids else 0.0
300
+ for pronoun in top_preds])
301
+ for top_preds in mask_filled_text
302
+ ]
303
+ return round(sum(pronoun_preds) / num_preds * 100, DECIMAL_PLACES)
304
+
305
+
306
+ def get_figure(results, dataset, gender, indie_var_name):
307
+ fig, ax = plt.subplots()
308
+ ax.plot(results)
309
+
310
+ if dataset == REDDIT:
311
+ ax.set_xlabel("Subreddit prepended to input text")
312
+ ax.xaxis.set_major_locator(MaxNLocator(6))
313
+ else:
314
+ ax.set_xlabel("Date injected into input text")
315
+ ax.set_title(f"Softmax probability of pronouns predicted {gender}\n by model type vs {indie_var_name}.")
316
+ ax.set_ylabel(f"Avg softmax prob for {gender} pronouns")
317
+ ax.legend(list(results.columns))
318
+ return fig
319
+
320
+
321
+ def predict_gender_pronouns(
322
+ dataset,
323
+ bert_like_models,
324
+ normalizing,
325
+ input_text,
326
+ ):
327
+ """Run inference on input_text for each model type, returning df and plots of precentage
328
+ of gender pronouns predicted as female and male in each target text.
329
+ """
330
+
331
+ male_gendered_token_ids, female_gendered_token_ids = get_gendered_token_ids(tokenizer)
332
+ if dataset == REDDIT:
333
+ indie_vars = SUBREDDITS
334
+ conditioning_variables = SUBREDDIT_CONDITIONING_VARIABLES
335
+ indie_var_name = 'subreddit'
336
+ else:
337
+ indie_vars = np.linspace(START_YEAR, STOP_YEAR, 20).astype(int)
338
+ conditioning_variables = WIKIBIO_CONDITIONING_VARIABLES
339
+ indie_var_name = 'date'
340
+
341
+
342
+ tokenized = get_tokenized_text_with_metadata(
343
+ input_text,
344
+ indie_vars,
345
+ dataset,
346
+ male_gendered_token_ids,
347
+ female_gendered_token_ids
348
+ )
349
+ num_preds = torch.sum(tokenized['ids'][0] == MASK_TOKEN_ID).item()
350
+
351
+ female_dfs = []
352
+ male_dfs = []
353
+ female_dfs.append(pd.DataFrame({indie_var_name: indie_vars}))
354
+ male_dfs.append(pd.DataFrame({indie_var_name: indie_vars}))
355
+ for var in conditioning_variables:
356
+ prefix = f"{var}_metadata"
357
+ model = models[(dataset, var)]
358
+
359
+ female_pronoun_preds = []
360
+ male_pronoun_preds = []
361
+ for indie_var_idx in range(len(tokenized['ids'])):
362
+ is_masked = tokenized['ids'][indie_var_idx] == MASK_TOKEN_ID
363
+
364
+ ids = tokenized["ids"][indie_var_idx]
365
+ atten_mask = tokenized["atten_mask"][indie_var_idx]
366
+ labels = tokenized["labels"][indie_var_idx]
367
+
368
+ with torch.no_grad():
369
+ outputs = model(ids.unsqueeze(dim=0),
370
+ atten_mask.unsqueeze(dim=0))
371
+
372
+ female_pronoun_preds.append(
373
+ get_avg_prob_from_finetuned_outputs(outputs,is_masked, num_preds, "female")
374
+ )
375
+ male_pronoun_preds.append(
376
+ get_avg_prob_from_finetuned_outputs(outputs,is_masked, num_preds, "male")
377
+ )
378
+
379
+ female_dfs.append(pd.DataFrame({prefix : female_pronoun_preds}))
380
+ male_dfs.append(pd.DataFrame({prefix : male_pronoun_preds}))
381
+
382
+ for bert_like in bert_like_models:
383
+ prefix = f"base_{bert_like}"
384
+ model = models[(BASE, bert_like)]
385
+
386
+ female_pronoun_preds = []
387
+ male_pronoun_preds = []
388
+ for indie_var_idx in range(len(tokenized['ids'])):
389
+ toks = tokenized["toks"][indie_var_idx]
390
+ target_text_for_bert = ' '.join(
391
+ toks[1:-1]) # Removing [CLS] and [SEP]
392
+
393
+ mask_filled_text = model(target_text_for_bert)
394
+ # Quick hack as realized return type based on how many MASKs in text.
395
+ if type(mask_filled_text[0]) is not list:
396
+ mask_filled_text = [mask_filled_text]
397
+
398
+ female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
399
+ mask_filled_text,
400
+ female_gendered_token_ids,
401
+ num_preds
402
+ ))
403
+ male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
404
+ mask_filled_text,
405
+ male_gendered_token_ids,
406
+ num_preds
407
+ ))
408
+
409
+ if normalizing:
410
+ total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds)
411
+ female_pronoun_preds = np.around(
412
+ np.divide(female_pronoun_preds, total_gendered_probs)*100,
413
+ decimals=DECIMAL_PLACES
414
+ )
415
+ male_pronoun_preds = np.around(
416
+ np.divide(male_pronoun_preds, total_gendered_probs)*100,
417
+ decimals=DECIMAL_PLACES
418
+ )
419
+
420
+ female_dfs.append(pd.DataFrame({prefix : female_pronoun_preds}))
421
+ male_dfs.append(pd.DataFrame({prefix : male_pronoun_preds}))
422
+
423
+ # To display to user as an example
424
+ toks = tokenized["toks"][0]
425
+ target_text_w_masks = ' '.join(toks[1:-1])
426
+
427
+ # Plots / dataframe for display to users
428
+ female_results = pd.concat(female_dfs, axis=1).set_index(indie_var_name)
429
+ male_results = pd.concat(male_dfs, axis=1).set_index(indie_var_name)
430
+
431
+ female_fig = get_figure(female_results, dataset, "female", indie_var_name)
432
+ male_fig = get_figure(male_results, dataset, "male", indie_var_name)
433
+ female_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index?
434
+ male_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index?
435
+
436
+ return (
437
+ target_text_w_masks,
438
+ female_fig,
439
+ male_fig,
440
+ female_results,
441
+ male_results,
442
+ )
443
+
444
+
445
+
446
+ gr.Interface(
447
+ fn=predict_gender_pronouns,
448
+ inputs=[
449
+ gr.inputs.Radio(
450
+ [REDDIT, WIKIBIO],
451
+ default=WIKIBIO,
452
+ type="value",
453
+ label="Pick 'conditionally' fine-tuned model.",
454
+ optional=False,
455
+ ),
456
+ gr.inputs.CheckboxGroup(
457
+ BERT_LIKE_MODELS,
458
+ default=[BERT_LIKE_MODELS[0]],
459
+ type="value",
460
+ label="Pick optional BERT base uncased model.",
461
+ ),
462
+ gr.inputs.Dropdown(
463
+ ["False", "True"],
464
+ label="Normalize BERT-like model's predictions to gendered-only?",
465
+ default = "True",
466
+ type="index",
467
+ ),
468
+ gr.inputs.Textbox(
469
+ lines=5,
470
+ label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
471
+ default="She always walked past the building built in DATE on her way to her job as an elementary school teacher.",
472
+ ),
473
+ ],
474
+ outputs=[
475
+ gr.outputs.Textbox(
476
+ type="auto", label="Sample target text fed to model"),
477
+ gr.outputs.Plot(type="auto", label="Plot of softmax probability pronouns predicted female."),
478
+ gr.outputs.Plot(type="auto", label="Plot of softmax probability pronouns predicted male."),
479
+ gr.outputs.Dataframe(
480
+ overflow_row_behaviour="show_ends",
481
+ label="Table of softmax probability pronouns predicted female",
482
+ ),
483
+ gr.outputs.Dataframe(
484
+ overflow_row_behaviour="show_ends",
485
+ label="Table of softmax probability pronouns predicted male",
486
+ ),
487
+ ],
488
+ ).launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
transformers
torch
pandas
numpy
matplotlib
gradio