{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "BERT_Difference_Plots",
      "provenance": [],
      "collapsed_sections": [
        "ULz91t5Mfsfh",
        "G5Hp_i-O004m",
        "TWgZQ-kfML_V"
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mGDHOsFEIvKY"
      },
      "source": [
        "Sketches for DisCo-like metric that can be visually inspected. \n",
        "\n",
        "- [Measuring and Reducing Gendered Correlations in Pre-trained Models](https://arxiv.org/abs/2010.06032)\n",
        "- [pair.withgoogle.com/explorables/fill-in-the-blank](https://pair.withgoogle.com/explorables/fill-in-the-blank)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ULz91t5Mfsfh"
      },
      "source": [
        "# Load Packages"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OQvEH3U6Q_OE"
      },
      "source": [
        "%%capture\n",
        "\n",
        "# Install first so the transformers import below works on a fresh runtime.\n",
        "# %pip (rather than !pip) installs into the running kernel's environment.\n",
        "%pip install -q transformers\n",
        "\n",
        "import os\n",
        "\n",
        "import IPython\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "from transformers import BertForMaskedLM, BertTokenizer\n",
        "\n",
        "import google.colab\n",
        "from google.colab import output"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DaP70yrl0W-0"
      },
      "source": [
        "def jsViz(data, settings=None):\n",
        "  \"\"\"Fill HTML_TEMPLATE with `data`, `settings` and a script URL, then display it.\n",
        "\n",
        "  Args:\n",
        "    data: dict handed to the template. If it has no 'vocab' key, the\n",
        "      tokens of the global `tokenizer` vocabulary are added to a copy,\n",
        "      so the caller's dict is not modified.\n",
        "    settings: optional dict handed to the template. 'isDev' == 1 selects\n",
        "      dev_url; 'type' replaces 'paragraph-minimap' in the chosen URL.\n",
        "  \"\"\"\n",
        "  # None sentinel instead of a mutable default that is shared between calls.\n",
        "  settings = {} if settings is None else settings\n",
        "\n",
        "  # NOTE(review): both URLs are currently identical, so 'isDev' does not\n",
        "  # change which script is referenced.\n",
        "  prod_url = 'https://roadtolarissa.com/colab/scatter-plot-colab/paragraph-minimap/watch-files.js?3'\n",
        "  dev_url = 'https://roadtolarissa.com/colab/scatter-plot-colab/paragraph-minimap/watch-files.js?3'\n",
        "  url = dev_url if settings.get('isDev') == 1 else prod_url\n",
        "\n",
        "  if 'type' in settings:\n",
        "    url = url.replace('paragraph-minimap', settings['type'])\n",
        "\n",
        "  if 'vocab' not in data:\n",
        "    # Shallow copy: the vocabulary list is large and should not leak into\n",
        "    # the dict the caller keeps using.\n",
        "    data = {**data, 'vocab': list(tokenizer.vocab)}\n",
        "\n",
        "  HTML_TEMPLATE = '''\n",
        " \n",
        " \n",
        " \n",
        " \n",
        " \n",
        " \n",
        "\n",
        " \n",
        " \n",
        " \n",
        " \n",
        " '''\n",
        "\n",
        "  html = HTML_TEMPLATE.format(data=data, settings=settings, url=url)\n",
        "  IPython.display.display(IPython.display.HTML(html))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G5Hp_i-O004m"
      },
      "source": [
        "# Model Setup "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9bKnXE1DRAvx"
      },
      "source": [
        "%%capture\n",
        "\n",
        "modelpath_default = \"bert-large-uncased-whole-word-masking\"\n",
        "tokenizer = BertTokenizer.from_pretrained(modelpath_default)\n",
        "model_default = BertForMaskedLM.from_pretrained(modelpath_default)\n",
        "# The model is only used for prediction in this notebook.\n",
        "model_default.eval()\n",
        "\n",
        "# Use the GPU when the runtime has one, otherwise fall back to the CPU.\n",
        "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
        "model_default = model_default.to(device)\n",
        "model_large_uncased_whole_word_masking = model_default"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yoRggB_YgVgB"
      },
      "source": [
        "def calc_logits(string, model=model_default):\n",
        "  \"\"\"Return the model output at the first masked position of `string`.\n",
        "\n",
        "  Every '_' in `string` is replaced by the tokenizer's mask token before\n",
        "  the string is encoded with the global `tokenizer`.\n",
        "\n",
        "  Args:\n",
        "    string: sentence with '_' (or the mask token itself) marking the blank.\n",
        "    model: model to run on the encoded sentence; defaults to model_default.\n",
        "\n",
        "  Returns:\n",
        "    numpy array taken from outputs[0] at batch 0 and the index of the first\n",
        "    mask token (for BertForMaskedLM these are the per-token logits).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if the encoded string contains no mask token.\n",
        "  \"\"\"\n",
        "  string = string.replace('_', tokenizer.mask_token)\n",
        "  tokens = tokenizer.encode(string)\n",
        "\n",
        "  # Ask the tokenizer for the mask id instead of hard-coding 103.\n",
        "  mask_id = tokenizer.mask_token_id\n",
        "  if mask_id not in tokens:\n",
        "    raise ValueError(f'no {tokenizer.mask_token} token found in: {string!r}')\n",
        "  index_of_mask = tokens.index(mask_id)\n",
        "\n",
        "  inputs = torch.tensor([tokens]).to(device)\n",
        "\n",
        "  # Inference only, so skip autograd bookkeeping.\n",
        "  with torch.no_grad():\n",
        "    outputs = model(inputs)\n",
        "\n",
        "  # Select the mask position on the device, then copy just that slice to host.\n",
        "  return outputs[0][0, index_of_mask].cpu().numpy()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xaUhqiFFy3Ot"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZX74W-FF1Gt7"
      },
      "source": [
        "# BERT Scatter Plot\n",
        "\n",
        "Logits for [MASK] token completions in two sentences plotted against each other. \n",
        "\n",
        "Basically [pair.withgoogle.com/explorables/fill-in-the-blank](https://pair.withgoogle.com/explorables/fill-in-the-blank) in colab, but with no fancy animations."
] }, { "cell_type": "code", "metadata": { "id": "vvytHas_y3RM", "colab": { "base_uri": "https://localhost:8080/", "height": 537 }, "outputId": "2657bfa7-1b1c-4c76-a33b-0b9662d7f76c" }, "source": [ "s0 = 'I went to the _.'\n", "s1 = 'I went to a _.'\n", "\n", "data = {\n", " 's0': s0,\n", " 's1': s1,\n", " 'e0': list(calc_logits(s0)), \n", " 'e1': list(calc_logits(s1)), \n", "}\n", "\n", "jsViz(data, {'type': 'two-sentences', 'count': 30, 'isDifference': 0})" ], "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " Click to authenticate\n", " \n", "\n", " \n", " \n", " \n", " \n", " " ], "text/plain": [ " \n", "