Emily McMilin commited on
Commit
bae4168
·
1 Parent(s): 38542cb

adding baseline to plots and some clean up

Browse files
Files changed (1) hide show
  1. app.py +47 -30
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  from typing import Optional
2
  import gradio as gr
3
  import torch
@@ -16,41 +19,35 @@ BASE = 'BERT_base'
16
 
17
  # Play with me, consts
18
  SUBREDDIT_CONDITIONING_VARIABLES = ["none", "subreddit"]
19
- WIKIBIO_CONDITIONING_VARIABLES = ['none', 'birth_date', 'birth_place'] # EMILY!!
20
-
21
  BERT_LIKE_MODELS = ["bert", "distilbert"]
 
22
 
 
 
 
 
23
 
24
-
25
- ## Internal constants
26
  GENDER_OPTIONS = ['female', 'male']
27
  DECIMAL_PLACES = 1
28
-
29
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
-
31
- MAX_TOKEN_LENGTH = 32
32
- NON_LOSS_TOKEN_ID = -100
33
-
34
  # Picked ints that will pop out visually during debug
35
  NON_GENDERED_TOKEN_ID = 30
36
  LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
37
  CLASSES = list(LABEL_DICT.keys())
38
-
39
- MULTITOKEN_WOMAN_WORD = 'policewoman'
40
- MULTITOKEN_MAN_WORD = 'spiderman'
41
 
42
  # Wikibio conts
43
-
44
  START_YEAR = 1800
45
  STOP_YEAR = 1999
46
  SPLIT_KEY = "DATE"
47
 
48
  # Reddit consts
49
-
50
  # List of randomly selected (tending towards those with seemingly more gender-neutral words)
51
  # in order of increasing self-identified female participation.
52
- # See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 100000
53
- # Update: 400000
54
  SUBREDDITS = [
55
  "GlobalOffensive",
56
  "pcmasterrace",
@@ -263,12 +260,16 @@ def get_tokenized_text_with_metadata(input_text, indie_vars, dataset, male_gende
263
  if len(text_portions) == 1:
264
  text_portions = ['Born in ', f" {text_portions[0]}"]
265
 
266
-
267
  tokenized_w_metadata = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
268
  for indie_var in indie_vars:
 
269
  if dataset == WIKIBIO:
 
 
270
  target_text = f"{indie_var}".join(text_portions)
271
  else:
 
 
272
  target_text = f"r/{indie_var}: {input_text}"
273
 
274
  tokenized_sample = tokenize_and_append_metadata(
@@ -302,10 +303,22 @@ def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token_ids, num
302
  ]
303
  return round(sum(pronoun_preds) / num_preds * 100, DECIMAL_PLACES)
304
 
 
 
305
 
306
- def get_figure(results, dataset, gender, indie_var_name):
 
 
 
307
  fig, ax = plt.subplots()
308
- ax.plot(results)
 
 
 
 
 
 
 
309
 
310
  if dataset == REDDIT:
311
  ax.set_xlabel("Subreddit prepended to input text")
@@ -322,6 +335,7 @@ def predict_gender_pronouns(
322
  dataset,
323
  bert_like_models,
324
  normalizing,
 
325
  input_text,
326
  ):
327
  """Run inference on input_text for each model type, returning df and plots of precentage
@@ -330,15 +344,14 @@ def predict_gender_pronouns(
330
 
331
  male_gendered_token_ids, female_gendered_token_ids = get_gendered_token_ids(tokenizer)
332
  if dataset == REDDIT:
333
- indie_vars = SUBREDDITS
334
  conditioning_variables = SUBREDDIT_CONDITIONING_VARIABLES
335
  indie_var_name = 'subreddit'
336
  else:
337
- indie_vars = np.linspace(START_YEAR, STOP_YEAR, 20).astype(int)
338
  conditioning_variables = WIKIBIO_CONDITIONING_VARIABLES
339
  indie_var_name = 'date'
340
 
341
-
342
  tokenized = get_tokenized_text_with_metadata(
343
  input_text,
344
  indie_vars,
@@ -424,16 +437,15 @@ def predict_gender_pronouns(
424
  female_dfs.append(pd.DataFrame({prefix : female_pronoun_preds}))
425
  male_dfs.append(pd.DataFrame({prefix : male_pronoun_preds}))
426
 
427
- # To display to user as an example
428
- toks = tokenized["toks"][0]
429
- target_text_w_masks = ' '.join(toks[1:-1])
430
 
431
  # Plots / dataframe for display to users
432
  female_results = pd.concat(female_dfs, axis=1).set_index(indie_var_name)
433
  male_results = pd.concat(male_dfs, axis=1).set_index(indie_var_name)
434
-
435
- female_fig = get_figure(female_results, dataset, "female", indie_var_name)
436
- male_fig = get_figure(male_results, dataset, "male", indie_var_name)
437
  female_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index?
438
  male_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index?
439
 
@@ -446,7 +458,6 @@ def predict_gender_pronouns(
446
  )
447
 
448
 
449
-
450
  gr.Interface(
451
  fn=predict_gender_pronouns,
452
  inputs=[
@@ -469,6 +480,12 @@ gr.Interface(
469
  default = "True",
470
  type="index",
471
  ),
 
 
 
 
 
 
472
  gr.inputs.Textbox(
473
  lines=5,
474
  label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",
 
1
+
2
+
3
+
4
  from typing import Optional
5
  import gradio as gr
6
  import torch
 
19
 
20
  # Play with me, consts
21
  SUBREDDIT_CONDITIONING_VARIABLES = ["none", "subreddit"]
22
+ WIKIBIO_CONDITIONING_VARIABLES = ['none', 'birth_date']
 
23
  BERT_LIKE_MODELS = ["bert", "distilbert"]
24
+ MAX_TOKEN_LENGTH = 32
25
 
26
+ # Internal markers for rendering
27
+ BASELINE_MARKER = 'baseline'
28
+ REDDIT_BASELINE_TEXT = ' '
29
+ WIKIBIO_BASELINE_TEXT = 'date'
30
 
31
+ ## Internal constants from training
 
32
  GENDER_OPTIONS = ['female', 'male']
33
  DECIMAL_PLACES = 1
34
+ MULTITOKEN_WOMAN_WORD = 'policewoman'
35
+ MULTITOKEN_MAN_WORD = 'spiderman'
 
 
 
 
36
  # Picked ints that will pop out visually during debug
37
  NON_GENDERED_TOKEN_ID = 30
38
  LABEL_DICT = {GENDER_OPTIONS[0]: 9, GENDER_OPTIONS[1]: -9}
39
  CLASSES = list(LABEL_DICT.keys())
40
+ NON_LOSS_TOKEN_ID = -100
 
 
41
 
42
  # Wikibio conts
 
43
  START_YEAR = 1800
44
  STOP_YEAR = 1999
45
  SPLIT_KEY = "DATE"
46
 
47
  # Reddit consts
 
48
  # List of randomly selected (tending towards those with seemingly more gender-neutral words)
49
  # in order of increasing self-identified female participation.
50
+ # See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 400000
 
51
  SUBREDDITS = [
52
  "GlobalOffensive",
53
  "pcmasterrace",
 
260
  if len(text_portions) == 1:
261
  text_portions = ['Born in ', f" {text_portions[0]}"]
262
 
 
263
  tokenized_w_metadata = {'ids': [], 'atten_mask': [], 'toks': [], 'labels': []}
264
  for indie_var in indie_vars:
265
+
266
  if dataset == WIKIBIO:
267
+ if indie_var == BASELINE_MARKER:
268
+ indie_var = WIKIBIO_BASELINE_TEXT
269
  target_text = f"{indie_var}".join(text_portions)
270
  else:
271
+ if indie_var == BASELINE_MARKER:
272
+ indie_var = REDDIT_BASELINE_TEXT
273
  target_text = f"r/{indie_var}: {input_text}"
274
 
275
  tokenized_sample = tokenize_and_append_metadata(
 
303
  ]
304
  return round(sum(pronoun_preds) / num_preds * 100, DECIMAL_PLACES)
305
 
306
+ def get_figure(results, dataset, gender, indie_var_name, include_baseline=True):
307
+ colors = ['b', 'g', 'c', 'm', 'y', 'r', 'k'] # assert no
308
 
309
+ # Grab then remove baselines from df
310
+ baseline = results.loc[BASELINE_MARKER]
311
+ results.drop(index=BASELINE_MARKER, axis=1, inplace=True)
312
+
313
  fig, ax = plt.subplots()
314
+ for i, col in enumerate(results.columns):
315
+ ax.plot(results[col], color=colors[i])#, color=colors)
316
+
317
+ if include_baseline == True:
318
+ for i, (name, value) in enumerate(baseline.items()):
319
+ if name == indie_var_name:
320
+ continue
321
+ ax.axhline(value, ls='--', color=colors[i])
322
 
323
  if dataset == REDDIT:
324
  ax.set_xlabel("Subreddit prepended to input text")
 
335
  dataset,
336
  bert_like_models,
337
  normalizing,
338
+ include_baseline,
339
  input_text,
340
  ):
341
  """Run inference on input_text for each model type, returning df and plots of precentage
 
344
 
345
  male_gendered_token_ids, female_gendered_token_ids = get_gendered_token_ids(tokenizer)
346
  if dataset == REDDIT:
347
+ indie_vars = [BASELINE_MARKER] + SUBREDDITS
348
  conditioning_variables = SUBREDDIT_CONDITIONING_VARIABLES
349
  indie_var_name = 'subreddit'
350
  else:
351
+ indie_vars = [BASELINE_MARKER] + np.linspace(START_YEAR, STOP_YEAR, 20).astype(int).tolist()
352
  conditioning_variables = WIKIBIO_CONDITIONING_VARIABLES
353
  indie_var_name = 'date'
354
 
 
355
  tokenized = get_tokenized_text_with_metadata(
356
  input_text,
357
  indie_vars,
 
437
  female_dfs.append(pd.DataFrame({prefix : female_pronoun_preds}))
438
  male_dfs.append(pd.DataFrame({prefix : male_pronoun_preds}))
439
 
440
+ # Pick a sample to display to user as an example
441
+ toks = tokenized["toks"][3]
442
+ target_text_w_masks = ' '.join(toks[1:-1]) # Removing [CLS] and [SEP]
443
 
444
  # Plots / dataframe for display to users
445
  female_results = pd.concat(female_dfs, axis=1).set_index(indie_var_name)
446
  male_results = pd.concat(male_dfs, axis=1).set_index(indie_var_name)
447
+ female_fig = get_figure(female_results, dataset, "female", indie_var_name, include_baseline)
448
+ male_fig = get_figure(male_results, dataset, "male", indie_var_name, include_baseline)
 
449
  female_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index?
450
  male_results.reset_index(inplace=True) # Gradio Dataframe doesn't 'see' index?
451
 
 
458
  )
459
 
460
 
 
461
  gr.Interface(
462
  fn=predict_gender_pronouns,
463
  inputs=[
 
480
  default = "True",
481
  type="index",
482
  ),
483
+ gr.inputs.Dropdown(
484
+ ["False", "True"],
485
+ label="Include baseline predictions (dashed-lines)?",
486
+ default = "True",
487
+ type="index",
488
+ ),
489
  gr.inputs.Textbox(
490
  lines=5,
491
  label="Input Text: Sentence about a single person using some gendered pronouns to refer to them.",