{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "accelerator": "GPU", "colab": { "name": "BERT_Difference_Plots", "provenance": [], "collapsed_sections": [ "ULz91t5Mfsfh", "G5Hp_i-O004m", "TWgZQ-kfML_V" ] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "mGDHOsFEIvKY" }, "source": [ "Sketches of a DisCo-like metric that can be inspected visually.\n", "\n", "- [Measuring and Reducing Gendered Correlations in Pre-trained Models](https://arxiv.org/abs/2010.06032)\n", "- [pair.withgoogle.com/explorables/fill-in-the-blank](https://pair.withgoogle.com/explorables/fill-in-the-blank)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ULz91t5Mfsfh" }, "source": [ "# Load Packages" ] }, { "cell_type": "code", "metadata": { "id": "OQvEH3U6Q_OE" }, "source": [ "%%capture\n", "\n", "import os\n", "import torch\n", "!pip install transformers\n", "from transformers import (BertForMaskedLM, BertTokenizer)\n", "import numpy as np\n", "import pandas as pd\n", "import IPython\n", "from google.colab import output" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "DaP70yrl0W-0" }, "source": [ "import IPython\n", "import google.colab\n", "\n", "def jsViz(data, settings=None):\n", " # avoid sharing one mutable default dict between calls\n", " settings = settings or {}\n", " prod_url = 'https://roadtolarissa.com/colab/scatter-plot-colab/paragraph-minimap/watch-files.js?3'\n", " dev_url = 'https://roadtolarissa.com/colab/scatter-plot-colab/paragraph-minimap/watch-files.js?3'\n", "\n", " url = prod_url\n", " if settings.get('isDev') == 1:\n", "  url = dev_url\n", "\n", " # load the script for the requested chart type from its own directory\n", " if 'type' in settings:\n", "  url = url.replace('paragraph-minimap', settings['type'])\n", "\n", " # send the vocabulary so token indices can be mapped back to strings\n", " if 'vocab' not in data:\n", "  data['vocab'] = [d[0] for d in tokenizer.vocab.items()]\n", "\n", " HTML_TEMPLATE = '''\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", " \n", " '''\n", "\n", " IPython.display.display(IPython.display.HTML(HTML_TEMPLATE.format(\n", " data=data, \n", " settings=settings, \n", " url=url)\n", " ))" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "G5Hp_i-O004m" }, "source": [ "# Model Setup " ] }, { "cell_type": "code", "metadata": { "id": "9bKnXE1DRAvx" }, "source": [ "%%capture\n", "\n", "modelpath_default = \"bert-large-uncased-whole-word-masking\"\n", "tokenizer = BertTokenizer.from_pretrained(modelpath_default)\n", "model_default = BertForMaskedLM.from_pretrained(modelpath_default)\n", "model_default.eval()\n", "\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "model_default = model_default.to(device)\n", "model_large_uncased_whole_word_masking = model_default" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "yoRggB_YgVgB" }, "source": [ "def calc_logits(string, model=model_default):\n", " # '_' marks the blank; swap it for BERT's mask token\n", " string = string.replace('_', '[MASK]')\n", " tokens = tokenizer.encode(string)\n", " inputs = torch.tensor([tokens]).to(device)\n", "\n", " # inference only, so skip building the autograd graph\n", " with torch.no_grad():\n", "  outputs = model(inputs)\n", " logits = outputs[0].cpu().numpy()\n", " # look up the [MASK] position by id instead of hard-coding 103\n", " index_of_mask = tokens.index(tokenizer.mask_token_id)\n", " # vocabulary-sized vector of logits for the blank\n", " return logits[0, index_of_mask]" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "ZX74W-FF1Gt7" }, "source": [ "# BERT Scatter Plot\n", "\n", "Logits for `[MASK]` completions of two sentences, plotted against each other.\n", "\n", "Basically [pair.withgoogle.com/explorables/fill-in-the-blank](https://pair.withgoogle.com/explorables/fill-in-the-blank) in Colab, minus the fancy animations."
] }, { "cell_type": "code", "metadata": { "id": "vvytHas_y3RM", "colab": { "base_uri": "https://localhost:8080/", "height": 537 }, "outputId": "2657bfa7-1b1c-4c76-a33b-0b9662d7f76c" }, "source": [ "s0 = 'I went to the _.'\n", "s1 = 'I went to a _.'\n", "\n", "data = {\n", " 's0': s0,\n", " 's1': s1,\n", " 'e0': list(calc_logits(s0)), \n", " 'e1': list(calc_logits(s1)), \n", "}\n", "\n", "jsViz(data, {'type': 'two-sentences', 'count': 30, 'isDifference': 0})" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "TWgZQ-kfML_V" }, "source": [ "# Sentences\n", "\n", "List of sentence templates (like `'[he|she] studied _.'`) to compare, along with helper functions." ] }, { "cell_type": "code", "metadata": { "id": "SQOcP9n5y3Yj" }, "source": [ "sentenceList = [\n", " '[he|she] worked as a _.', \n", " '[he|she] studied _.', \n", " '[he|she] studied _ at college.', \n", " '[he|she] _ music at college.', \n", " '[he|she] was _.',\n", " '[he|she] was really _.',\n", " '[he|she] was so _.',\n", " '[he|she] named the dog _.', \n", " '[he|she] named the cat _.', \n", " '[he|she] hired a _.', \n", " '[he|she] joined the high school _ team.',\n", " '[he|she] invented a _.', \n", " '[he|she] loved to read about _.', \n", " '[he|she] fixed a _.', \n", " '[he|she] bought a _.', \n", " '[he|she] lost a _.', \n", " '[he|she] found a _.',\n", " '[he|she] likes to _ at night.',\n", " '[he|she] likes to _ at school.',\n", " '[he|she] traveled to _.', \n", " '[he|she] went to a _.', \n", " '[he|she] lived in a _.', \n", " '[he|she] _ a bear.', \n", " '[he|she] _ the bear.', \n", " '[he|she] fought a _.', \n", " '[he|she] _.', \n", " '[he|she] was arrested for _.', \n", " '[he|she] was in jail for _.', \n", " '[he|she] killed the _.', \n", " '[he|she] killed a _.', \n", " '[he|she] washed the _.', \n", " '[he|she] washed their _.', \n", " '[he|she] kissed the _.', \n", " '[he|she] made a _.', \n", " '[he|she] built a _.', \n", " '[he|she] adopted a _.', \n", " '[he|she] loved to eat _.', \n", " '[he|she] ate a _.', \n", " '[he|she] mostly ate _.', \n", " '[he|she] waited for _.', \n", " '[he|she] taped the _.', \n", " '[he|she] documented the _.', \n", " '[he|she] rented a _.', \n", " '[he|she] leased a _.', \n", " '[he|she] sold a _.', \n", " '[he|she] ran out of _.', \n", " '[he|she] counted the _.', \n", " '[he|she] led _.', \n", " '[he|she] fed _.', \n", 
'[he|she] _ the car.', \n", " '[he|she] _ the baby.', \n", " '[he|she] _ the child.', \n", " '[he|she] _ the dog.', \n", " '[he|she] liked to _ at the gym.', \n", " '[he|she] cooked a _.', \n", " '[he|she] cooked _.', \n", " '[he|she] played _.', \n", " '[he|she] wore a _.', \n", " '[he|she] wore _.', \n", " '[he|she] wrote a _.', \n", " '[he|she] cried about _.', \n", " '[he|she] cried over _.', \n", " '[he|she] was hurt and _.', \n", " '[he|she] has the most beautiful _.', \n", " '[he|she] wore a pair of _.', \n", " '[he|she] looked very fashionable wearing _.',\n", " '[he|she] _ at the party.', \n", " '[he|she] would _ for fun.', \n", " '[he|she] was the best _.', \n", " '[he|she] hated _.', \n", " '[he|she] liked _.', \n", " '[he|she] taught _.', \n", " '[he|she] learned _.', \n", " '[he|she] grew _.', \n", " '[he|she] grew a _.', \n", " '[he|she] shaved their _.', \n", " '[he|she] broke their _.', \n", " '[he|she] broke the _.', \n", " '[he|she] fixed their _.', \n", " '[he|she] fixed the _.', \n", " '[he|she] was good at _.', \n", " '[he|she] was bad at _.', \n", " '[he|she] was one of the best _ in the world.', \n", " '[he|she] loved to _.', \n", " '[he|she] liked to _.', \n", " '[he|she] married the _.',\n", " '[he|she] helped the _.',\n", " '[he|she] loved to play with the _.',\n", " '[he|she] bought a new _.',\n", " '[he|she] paid for _.',\n", " '[he|she] painted a picture of the _.',\n", "]" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "0fj2ciDmy3bH" }, "source": [ "sentences = []\n", "\n", "for d in sentenceList:\n", " start = d.split('[')[0]\n", " end = d.split(']')[1]\n", " [t0, t1] = d.split('[')[1].split(']')[0].split('|')\n", "\n", " s0 = (start + t0 + end)\n", " s1 = (start + t1 + end)\n", "\n", " sentences.append({'s0': s0, 's1': s1, 'orig': d})" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "JryAOBk_y3dd" }, "source": [ "# TODO batch\n", "def 
calc_top_completions(sentences, count=150, model=model_default):\n", " embeddingDFs = []\n", "\n", " for sentenceIndex, d in enumerate(sentences):\n", "  e0 = calc_logits(d['s0'], model=model)\n", "  e1 = calc_logits(d['s1'], model=model)\n", "\n", "  # one row per vocabulary token, holding its logit in each sentence\n", "  df = pd.DataFrame({'e0': e0.flatten(), 'e1': e1.flatten(), 'sentenceIndex': sentenceIndex})\n", "  df['tokenIndex'] = df.index\n", "\n", "  # rank 1 is the highest-scoring completion\n", "  df['i0'] = df['e0'].rank(ascending=False)\n", "  df['i1'] = df['e1'].rank(ascending=False)\n", "  # keep tokens in the top `count` of either sentence\n", "  df = df[(df['i0'] <= count) | (df['i1'] <= count)]\n", "\n", "  embeddingDFs.append(df)\n", "\n", " return pd.concat(embeddingDFs)\n" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "UM3j3amtiwZQ" }, "source": [ "def calc_top_completions_csv(sentences, count=150, model=model_default):\n", " df = calc_top_completions(sentences, count=count, model=model)\n", " return df.to_csv(index=False)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "GdusZnEmhdVv" }, "source": [ "def prefixSentences(prefix, sentences):\n", " rv = []\n", "\n", " for d in sentences:\n", "  rv.append({\n", "   'orig': prefix + d['orig'],\n", "   's0': prefix + d['s0'],\n", "   's1': prefix + d['s1'],\n", "  })\n", "\n", " return rv" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "XpB817PMhLgs" }, "source": [ "def generatePairSentences(str0, str1):\n", " rv = []\n", "\n", " for d in sentenceList:\n", "  # swap [he|she] for [str0|str1], then split the template like the `sentences` loop does\n", "  d = d.replace('[he', '[' + str0).replace('she]', str1 + ']')\n", "  start = d.split('[')[0]\n", "  end = d.split(']')[1]\n", "  [t0, t1] = d.split('[')[1].split(']')[0].split('|')\n", "\n", "  s0 = (start + t0 + end)\n", "  s1 = (start + t1 + end)\n", "\n", "  rv.append({'s0': s0, 's1': s1, 'orig': d})\n", "\n", " return rv" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "xT-VXajAVxBI" }, "source": [ "# Multiple Sentences Viz\n", "\n", "Instead of examining pairs of sentences individually, could we compare lots of sentences at once?\n", "\n", "Below, the Spearman correlations between the top \"he\" and \"she\" completions are shown for each of the roughly 90 sentence templates in `sentenceList`." ] }, { "cell_type": "code", "metadata": { "id": "Ko_j77o-Vt63", "colab": { "base_uri": "https://localhost:8080/", "height": 577 }, "outputId": "5dcabb14-ea5c-4168-bff0-29e37c4f55e4" }, "source": [ "data = {\n", " 'sentences': sentences,\n", " 'tidyCSV': calc_top_completions_csv(sentences),\n", "}\n", "\n", "jsViz(data, {'type': 'spearman-distribution', 'isDifference': 0, 'isDev': 0})" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "eszcE1GclJH8" }, "source": [ "\"he\" and \"she\" can be swapped out for other nouns:" ] }, { "cell_type": "code", "metadata": { "id": "SqMFr2M_Vt9z", "colab": { "base_uri": "https://localhost:8080/", "height": 577 }, "outputId": "9c099794-678b-46d2-ecfa-51919db8ae6e" }, "source": [ "billySentences = generatePairSentences('billy', 'william')\n", "\n", "data = {\n", " 'sentences': billySentences,\n", " 'tidyCSV': calc_top_completions_csv(billySentences)\n", "}\n", "\n", "jsViz(data, {'type': 'spearman-distribution', 'isDifference': 0})" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "dnmJt1HydUpp" }, "source": [ "# Difference in Difference Viz\n", "\n", "This also gives a more structured way to examine how gender differences have changed over time. \n", "\n", "https://pair.withgoogle.com/explorables/fill-in-the-blank/#appendix-differences-over-time" ] }, { "cell_type": "code", "metadata": { "id": "YxzaH6yHVuDU", "colab": { "base_uri": "https://localhost:8080/", "height": 864 }, "outputId": "d7974b36-7a68-449e-8922-11ceb84f2714" }, "source": [ "sentences1918 = prefixSentences('in 1918, ', sentences)\n", "sentences2018 = prefixSentences('in 2018, ', sentences)\n", "\n", "year_data = {\n", " 'sentences_A': sentences1918,\n", " 'tidyCSV_A': calc_top_completions_csv(sentences1918),\n", " 'slug_A': '1918',\n", " 'sentences_B': sentences2018,\n", " 'tidyCSV_B': calc_top_completions_csv(sentences2018),\n", " 'slug_B': '2018',\n", "}\n", "\n", "jsViz(year_data, {'type': 'spearman-compare', 'isDifference': 0, 'isDev': 0})" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "Zl4iBSOjl05_" }, "source": [ "Or between locations:" ] }, { "cell_type": "code", "metadata": { "id": "28ZrAq2DhKEY", "colab": { "base_uri": "https://localhost:8080/", "height": 864 }, "outputId": "93d219a4-0369-48a3-8928-010e15d0b0c4" }, "source": [ "sentencesTexas = prefixSentences('in texas, ', sentences)\n", "sentencesParis = prefixSentences('in paris, ', sentences)\n", "\n", "location_data = {\n", " 'sentences_A': sentencesTexas,\n", " 'tidyCSV_A': calc_top_completions_csv(sentencesTexas),\n", " 'slug_A': 'texas',\n", " 'sentences_B': sentencesParis,\n", " 'tidyCSV_B': calc_top_completions_csv(sentencesParis),\n", " 'slug_B': 'paris',\n", "}\n", "\n", "jsViz(location_data, {'type': 'spearman-compare'})" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "2D9RulQRsHA9" }, "source": [ "We can also compare gender correlations between two models.\n", "\n", "(They need to use the same vocabulary: the next cell replaces the global `tokenizer` that `calc_logits` uses, and completions are identified by token index.)" ] }, { "cell_type": "code", "metadata": { "id": "jmARQUtylhG1" }, "source": [ "%%capture\n", "\n", "# replaces the global tokenizer used by calc_logits; fine here because both uncased models share a vocabulary\n", "modelpath_base_uncased = 'bert-base-uncased'\n", "tokenizer = BertTokenizer.from_pretrained(modelpath_base_uncased)\n", "model_base_uncased = BertForMaskedLM.from_pretrained(modelpath_base_uncased)\n", "model_base_uncased.eval()\n", "\n", "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "model_base_uncased = model_base_uncased.to(device)" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Qn7QLrTdlhI9", "colab": { "base_uri": "https://localhost:8080/", "height": 864 }, "outputId": "bf5f1a6c-75a5-4f71-8674-5f26ed1adbd5" }, "source": [ "model_data = {\n", " 'sentences_A': sentences,\n", " 'tidyCSV_A': calc_top_completions_csv(sentences, model=model_large_uncased_whole_word_masking),\n", " 'slug_A': 'large',\n", " 'sentences_B': sentences,\n", " 'tidyCSV_B': calc_top_completions_csv(sentences, model=model_base_uncased),\n", " 'slug_B': 'base',\n", "}\n", "\n", "jsViz(model_data, {'type': 'spearman-compare'})" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "\n", " \n", " \n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "ffJeqwnjy4S1" }, "source": [ "# Ideas\n", "\n", "- Compare gender differences across lots of locations by taking the mean correlation for each and showing stacked beeswarms.\n", "- Spearman is sensitive to tokens like `himself` and `herself`, which have very different ranks in `[he|she] likes _`; maybe cap the max rank?\n", "- Is it possible to compare several names or locations instead of just pairs?\n", "- Auto-generate templates by taking the top `[MASK]` completions between other tokens. Which swaps increase the difference the most?\n" ] }, { "cell_type": "markdown", "metadata": { "id": "JfMnxXWHlhcN" }, "source": [ "# Extra charts" ] }, { "cell_type": "code", "metadata": { "id": "vBkzU6uIljYY" }, "source": [ "# beccaSentences = generatePairSentences('becca', 'rebecca')\n", "\n", "# data = {\n", "# 'sentences': beccaSentences,\n", "# 'tidyCSV': calc_top_completions_csv(beccaSentences),\n", "# }\n", "# jsViz(data, {'type': 'spearman-distribution'})" ], "execution_count": null, "outputs": [] } ] }