// literacy / plotsjs.js
// Gabriela Nicole Gonzalez Saez
// init
// c2ad8fd
async () => {
// Expose testFn() on globalThis so inline HTML onclick handlers can call it.
globalThis.testFn = () => {
  document.getElementById('demo').innerHTML = "Hello?";
};
// Two d3 builds are loaded side by side: `d3` (v5) is used by d3Fn, TextGrid,
// attViz and most of Tree, while `d37` (v7) drives the force-directed network
// plots and the tree's link curves. jQuery is needed by attViz's widgets.
// NOTE(review): the pinned jQuery version in this copy of the URL was garbled;
// `jquery@3` resolves to the latest 3.x release - pin an exact version if needed.
const [d37, d3] = await Promise.all([
  import("https://cdn.jsdelivr.net/npm/d3@7/+esm"),
  import("https://cdn.jsdelivr.net/npm/d3@5/+esm"),
  import("https://cdn.jsdelivr.net/npm/jquery@3/dist/jquery.min.js"),
]);
// jquery.min.js is a classic script, not an ES module: importing it runs it (which
// registers window.jQuery) but resolves to an empty module namespace. Publish the
// real jQuery function as `$` rather than that namespace object.
globalThis.$ = globalThis.jQuery;
globalThis.d3 = d3;

// Demo helper: append a 50x50 black square to #viz that turns red while hovered.
globalThis.d3Fn = () => {
  d3.select('#viz').append('svg')
    .append('rect')
    .attr('width', 50)
    .attr('height', 50)
    .attr('fill', 'black')
    .on('mouseover', function () { d3.select(this).attr('fill', 'red'); })
    .on('mouseout', function () { d3.select(this).attr('fill', 'black'); });
};

// Echo helper: logs `val` and returns [val, radio_c] unchanged.
globalThis.testFn_out = (val, radio_c) => {
  console.log(val);
  return [val, radio_c];
};

// Render every visualization for one model output. `data` is read as:
//   data[1]    = [beamNodes, probsGrid, htmlInputs, htmlTarget]
//   data[2]    = embeddings: {input|output} -> {tokens|words} -> {tnse, similar_queries}
//   data[3..5] = parameter objects for three attention views (attViz)
// Returns the fixed pair ['string', {}].
globalThis.testFn_out_json = (data) => {
  console.log(data);
  const dataBeam = data[1][0];
  const dataProbs = data[1][1];
  const dataHtmlInputs = data[1][2];
  const dataHtmlTarget = data[1][3];
  const dataEmbds = data[2];

  attViz(data[3]);
  attViz(data[4]);
  attViz(data[5]);

  // Beam search: turn the flat {id, parentId} list into a nested {children} tree.
  // Object keys coerce ids to strings, so numeric and string ids still match.
  console.log(dataBeam);
  const beamById = Object.fromEntries(dataBeam.map((el) => [el.id, el]));
  let root;
  for (const el of dataBeam) {
    if (el.parentId === null) {
      root = el;
      continue;
    }
    const parentEl = beamById[el.parentId];
    parentEl.children = [...(parentEl.children || []), el];
  }
  const beamDiv = d3.select('#d3_beam_search');
  beamDiv.html("");
  if (root !== undefined) {
    beamDiv.append(() => Tree(root));
  } else {
    console.log("no root node (parentId === null) in beam data");
  }

  // Per-step probabilities.
  const gridDiv = d3.select('#d3_text_grid');
  gridDiv.html("");
  gridDiv.append(() => TextGrid(dataProbs));

  // Tokenization (server-generated HTML).
  d3.select('#d3_tok').html(dataHtmlInputs);
  d3.select('#d3_tok_target').html(dataHtmlTarget);

  // Embeddings.
  d3.select("#d3_embeds_source").html("here");

  // Words-or-tokens selector: reveal it and show the panels for its current value.
  const selectType = d3.select("#select_type");
  console.log(selectType.node().value);
  selectType.attr("hidden", null);
  selectType.on("change", change);
  change();

  // One network plot per (input|output) x (tokens|words) combination.
  for (const textType of ['input', 'output']) {
    for (const textKey of ['tokens', 'words']) {
      const embds = dataEmbds[textType][textKey];
      embeddings_network([], embds['tnse'], embds['similar_queries'], `${textType}_${textKey}`);
    }
  }
  return ['string', {}];
};
// Show only the embedding/graph panels for the type currently selected in
// #select_type (e.g. "tokens" or "words"); every other .d3_embed / .d3_graph is hidden.
function change() {
  const showType = d3.select("#select_type").node().value;
  // Hide every panel first...
  d3.selectAll(".d3_embed").attr("hidden", '');
  d3.selectAll(".d3_graph").attr("hidden", '');
  // ...then un-hide the input/output pair for the selected type.
  for (const prefix of ["#d3_embeds_input_", "#d3_embeds_output_", "#d3_graph_input_", "#d3_graph_output_"]) {
    d3.select(prefix + showType).attr("hidden", null);
  }
}
// Build the token graph for one embedding view and render it with networkPlot()
// into '#d3_graph_<type>'.
//   tokens_text           - unused (kept for call compatibility)
//   dict_projected_embds  - projection table indexed by token id; entry [3] is used
//                           as the node label
//   similar_vocab_queries - {tokenKey: {similar_topk: [ids], distance: [numbers]}}
//   type                  - view suffix, e.g. "input_tokens"
// Each key of similar_vocab_queries becomes a 'sentence' node (type_i 0) linked to
// its similar tokens ('similar' nodes, type_i 1, weight = distance). Consecutive
// sentence tokens are also chained with weight-1 edges. A node keeps the type it
// had when its id was first seen.
function embeddings_network(tokens_text, dict_projected_embds, similar_vocab_queries, type = "source") {
  console.log("Each token is a node; distance if in similar list", type);
  console.log(tokens_text, dict_projected_embds, similar_vocab_queries);

  const nodeHash = {}; // id -> node object (keys coerce, so 5 and "5" are one node)
  const nodes = []; // [{id, label, type, type_i}]
  const edges = []; // [{source: node, target: node, weight}]
  const edges_ids = []; // same edges with raw ids (logged for debugging)

  const ensureNode = (id, nodeType, typeIndex) => {
    if (!nodeHash[id]) {
      nodeHash[id] = { id, label: dict_projected_embds[id][3], type: nodeType, type_i: typeIndex };
      nodes.push(nodeHash[id]);
    }
    return nodeHash[id];
  };

  console.log('similar_vocab_queries', similar_vocab_queries);
  let prevSentToken = null;
  for (const [sentToken, value] of Object.entries(similar_vocab_queries)) {
    const sentNode = ensureNode(sentToken, 'sentence', 0);
    const simTokens = value['similar_topk'];
    const distances = value['distance'];
    for (let index = 0; index < simTokens.length; index++) {
      const sim = simTokens[index];
      const dist = distances[index];
      const simNode = ensureNode(sim, 'similar', 1);
      edges.push({ source: sentNode, target: simNode, weight: dist });
      edges_ids.push({ source: sentToken, target: sim, weight: dist });
    }
    // Chain consecutive sentence tokens so the sequence order stays visible.
    if (prevSentToken !== null) {
      edges.push({ source: nodeHash[prevSentToken], target: sentNode, weight: 1 });
      edges_ids.push({ source: prevSentToken, target: sentToken, weight: 1 });
    }
    prevSentToken = sentToken;
  }
  console.log("TYPE", type, edges, nodes, edges_ids, similar_vocab_queries);

  const graphDiv = d3.select('#d3_graph_' + type);
  graphDiv.html("");
  graphDiv.append(() => networkPlot({ nodes: nodes, links: edges }, similar_vocab_queries, dict_projected_embds, type));
}
// Build a force-directed graph of tokens and their nearest neighbours and return
// the (detached) <svg> node. Uses the d3 v7 build (`d37`), so event listeners
// receive (event, datum).
//   data                  - {nodes, links} as built by embeddings_network; nodes carry
//                           {id, label, type_i} (type_i 0 = sentence token, 1 = similar token)
//   similar_vocab_queries - {tokenId: {similar_topk: [...], distance: [...]}}
//   dict_proj             - projection table indexed by node id; entry [3] is shown as the
//                           token text in the legend, entry [4] is used to build CSS classes
//   div_type              - suffix identifying this graph (e.g. "input_tokens"); appended to
//                           class names and selects the '#similar_<div_type>' legend
//   options               - {width, height} of the plotting area in pixels (default 400 x 400);
//                           r and padding are accepted but currently unused
// Interaction: hovering a sentence node enlarges it and its similar tokens; clicking a
// node lists its similar tokens with distances in the legend.
function networkPlot(data, similar_vocab_queries, dict_proj, div_type = "source", {
  width = 400, // plotting-area width, in pixels
  height = 400, // plotting-area height, in pixels
  r = 3, // unused: circles are drawn with a fixed radius of 6
  padding = 1, // unused
} = {}) {
  console.log("data, similar_vocab_queries, div_type");
  console.log(data, similar_vocab_queries, div_type);

  const margin = { top: 10, right: 10, bottom: 30, left: 50 };
  const svg = d37.create("svg")
    .attr("width", width + margin.left + margin.right)
    .attr("height", height + margin.top + margin.bottom);

  // Links. NOTE(review): `fill` has no visible effect on <line>; if the intent was to
  // tint the sequence edges (weight 1) differently, this should probably set `stroke`.
  const link = svg.selectAll("line")
    .data(data.links)
    .enter()
    .append("line")
    .style("fill", (d) => (Number(d.weight) === 1 ? "#dfd5d5" : "#000000"))
    .style("stroke", "#aaa");

  // Token labels.
  const text = svg.selectAll("text")
    .data(data.nodes)
    .enter()
    .append("text")
    .style("text-anchor", "middle")
    .attr("y", 15)
    .attr("class", (d) => 'text_token-' + dict_proj[d.id][4] + div_type)
    .attr("div-type", div_type)
    .text((d) => d.label);

  // Nodes: blue for sentence tokens, red for similar tokens.
  const node = svg.selectAll("circle")
    .data(data.nodes)
    .enter()
    .append("circle")
    .attr("r", 6)
    .attr("class", (d) => 'node_token-' + dict_proj[d.id][4] + div_type)
    .attr("div-type", div_type)
    .style("fill", (d) => (d.type_i ? "#e85252" : "#6689c6"))
    .on('mouseover', highlight_mouseover)
    .on('mouseout', highlight_mouseout)
    .on('click', change_legend);

  // Lay the graph out with a force simulation. Positions are written to the DOM once,
  // when the simulation ends. forceManyBody() takes no constructor argument (a `-400`
  // passed there was silently ignored), so repulsion uses d3's default strength;
  // call .strength(...) on it to tune the spread.
  d37.forceSimulation(data.nodes)
    .force("link", d37.forceLink().id((d) => d.id).links(data.links))
    .force("charge", d37.forceManyBody())
    .force("center", d37.forceCenter(width / 2, height / 2))
    .on("end", applyPositions);

  // Copy the simulated coordinates onto links, circles and labels.
  function applyPositions() {
    link
      .attr("x1", (d) => d.source.x)
      .attr("y1", (d) => d.source.y)
      .attr("x2", (d) => d.target.x)
      .attr("y2", (d) => d.target.y);
    node
      .attr("cx", (d) => d.x + 3)
      .attr("cy", (d) => d.y - 3);
    text
      .attr("transform", (d) => `translate(${d.x},${d.y})`);
  }

  // Hovering a sentence node enlarges it and every node listed as similar to it.
  function highlight_mouseover(event, datum) {
    console.log("highlight_mouseover", event, datum, d37.select(this).attr("div-type"));
    if (datum.type_i !== 0) return;
    const similarIds = similar_vocab_queries[datum.id]['similar_topk'];
    d37.select(this).transition()
      .duration('50')
      .style('opacity', '1')
      .attr("r", 12);
    const type = d37.select(this).attr("div-type");
    similarIds.forEach((similarToken) => {
      const nodeIdName = dict_proj[similarToken][4];
      d37.selectAll('.node_token-' + nodeIdName + type).attr("r", 12).style('opacity', '1');
    });
  }

  // Undo highlight_mouseover and bring circles back above the labels.
  function highlight_mouseout(event, datum) {
    if (datum.type_i !== 0) return;
    console.log("similar_vocab_queries", similar_vocab_queries, "this type:", d37.select(this).attr("div-type"));
    const similarIds = similar_vocab_queries[datum.id]['similar_topk'];
    d37.select(this).transition()
      .duration('50')
      .style('opacity', '.7')
      .attr("r", 6);
    const type = d37.select(this).attr("div-type");
    similarIds.forEach((similarToken) => {
      const nodeIdName = dict_proj[similarToken][4];
      d37.selectAll('.node_token-' + nodeIdName + type).attr("r", 6).style('opacity', '.7');
    });
    // raise() gives the same final order however many times it runs, so do it once.
    if (similarIds.length > 0) d37.selectAll("circle").raise();
  }

  // Clicking a node fills this graph's own legend ('#similar_<div_type>'); the legend
  // id must come from this plot, not from whichever graph was hovered last.
  function change_legend(event, datum) {
    console.log(event, datum, dict_proj);
    if (datum['id'] in dict_proj) {
      show_similar_tokens(datum['id'], '#similar_' + div_type);
      console.log(dict_proj[datum['id']]);
    } else {
      console.log("no sentence");
    }
  }

  // List `token`'s similar tokens with their distances (3 decimals) in the legend div.
  function show_similar_tokens(token, div_name_similar = '#similar_input_tokens') {
    const legend = d37.select(div_name_similar);
    legend.html("");
    console.log("token", token);
    console.log("similar_vocab_queries[token]", similar_vocab_queries[token]);
    const token_data = similar_vocab_queries[token];
    console.log(token, token_data);
    // Only sentence tokens have a neighbour list; other nodes just clear the legend.
    if (!token_data) return;
    const decForm = d37.format(".3f");
    legend.selectAll("p")
      .data(token_data['similar_topk'])
      .enter()
      .append("p")
      .append('text')
      .attr('class_id', (d) => d)
      // Compare as strings: neighbour ids and the clicked id may differ in type
      // (e.g. numeric ids vs. the string keys used for sentence tokens).
      .style("background", (d) => (String(d) === String(token) ? "yellow" : null))
      .text((d, i) => do_text(d, i));

    function do_text(d, i) {
      console.log("do_text d,i");
      console.log(d, i);
      console.log("data_dict[d], data_dict");
      console.log(dict_proj[d], dict_proj);
      return dict_proj[d][3] + " " + decForm(token_data['distance'][i]) + " ";
    }
  }

  return svg.node();
}
// Copyright 2021 Observable, Inc.
// Released under the ISC license.
// https://observablehq.com/@d3/tree
// Render a tree diagram (adapted from Observable's @d3/tree) and return its <svg> node.
// In this file it is called as Tree(root) with a nested {children} object built from
// the beam-search nodes. Labels show `name`; the hover title shows `name` followed by
// `prob`; node labels whose `prob` equals 1 are drawn in red.
// Layout uses the d3 v5 build (`d3`); the link curve and path generator come from the
// v7 build (`d37`).
function Tree(data, { // data is either tabular (array of objects) or hierarchy (nested objects)
  path, // as an alternative to id and parentId, returns an array identifier, imputing internal nodes
  id = Array.isArray(data) ? d => d.id : null, // if tabular data, given a d in data, returns a unique identifier (string)
  parentId = Array.isArray(data) ? d => d.parentId : null, // if tabular data, given a node d, returns its parent's identifier
  children, // if hierarchical data, given a d in data, returns its children
  tree = d3.tree, // layout algorithm (typically d3.tree or d3.cluster)
  sort, // how to sort nodes prior to layout (e.g., (a, b) => d3.descending(a.height, b.height))
  label = d => d.name, // given a node d, returns the display name
  title = d => d.name + d.prob, // given a node d, returns its hover text (name followed by probability)
  link, // given a node d, its link (if any)
  linkTarget = "_blank", // the target attribute for links (if any)
  width = 800, // minimum outer width, in pixels
  height, // outer height, in pixels (derived from the layout when omitted)
  r = 3, // radius of nodes
  padding = 1, // horizontal padding for first and last column
  fill = "#999", // fill for leaf nodes
  fillOpacity, // fill opacity for nodes (currently unused)
  stroke = "#555", // stroke for links, also the fill of internal nodes
  strokeWidth = 2, // stroke width for links
  strokeOpacity = 0.4, // stroke opacity for links
  strokeLinejoin, // stroke line join for links
  strokeLinecap, // stroke line cap for links
  halo = "#fff", // color of label halo
  haloWidth = 3, // padding around the labels
  curve = d37.curveBumpX, // curve for the link
} = {}) {
  // If id and parentId options are specified, or the path option, use d3.stratify
  // to convert tabular data to a hierarchy; otherwise the data is assumed to be a
  // nested {children} object and d3.hierarchy is used.
  const root = path != null ? d3.stratify().path(path)(data)
    : id != null || parentId != null ? d3.stratify().id(id).parentId(parentId)(data)
    : d3.hierarchy(data, children);

  if (sort != null) root.sort(sort);

  // Compute labels.
  const descendants = root.descendants();
  const L = label == null ? null : descendants.map(d => label(d.data, d));

  // Grow the drawing width with the number of nodes (10px each), never below `width`.
  const descWidth = 10;
  const realWidth = descWidth * descendants.length;
  const totalWidth = (realWidth > width) ? realWidth : width;
  const dx = 25; // vertical spacing between sibling nodes
  const dy = totalWidth / (root.height + padding); // horizontal spacing between depths
  tree().nodeSize([dx, dy])(root);

  // Vertical extent of the laid-out tree, used to center it and size the svg.
  let x0 = Infinity;
  let x1 = -x0;
  root.each(d => {
    if (d.x > x1) x1 = d.x;
    if (d.x < x0) x0 = d.x;
  });
  const svgHeight = height === undefined ? x1 - x0 + dx * 2 : height;

  if (typeof curve !== "function") throw new Error(`Unsupported curve`);

  // Only the <svg> is returned, so it is created standalone (no scrolling wrapper:
  // a wrapper div would be discarded when the caller appends the svg).
  const svg = d3.create("svg")
    .attr("viewBox", [-dy * padding / 2, x0 - dx, totalWidth, svgHeight])
    .attr("width", totalWidth)
    .attr("height", svgHeight)
    .attr("style", "max-width: 100%; height: auto; height: intrinsic;")
    .attr("font-family", "sans-serif")
    .attr("font-size", 12);

  // Links (x/y swapped so the tree grows left-to-right).
  svg.append("g")
    .attr("fill", "none")
    .attr("stroke", stroke)
    .attr("stroke-opacity", strokeOpacity)
    .attr("stroke-linecap", strokeLinecap)
    .attr("stroke-linejoin", strokeLinejoin)
    .attr("stroke-width", strokeWidth)
    .selectAll("path")
    .data(root.links())
    .join("path")
    .attr("d", d37.link(curve)
      .x(d => d.y)
      .y(d => d.x));

  const node = svg.append("g")
    .selectAll("a")
    .data(root.descendants())
    .join("a")
    .attr("xlink:href", link == null ? null : d => link(d.data, d))
    .attr("target", link == null ? null : linkTarget)
    .attr("transform", d => `translate(${d.y},${d.x})`);

  node.append("circle")
    .attr("fill", d => d.children ? stroke : fill)
    .attr("r", r);

  if (title != null) node.append("title")
    .text(d => title(d.data, d));

  // Labels: left of internal nodes, right of leaves; fully certain tokens in red.
  if (L) node.append("text")
    .attr("dy", "0.32em")
    .attr("x", d => d.children ? -6 : 6)
    .attr("text-anchor", d => d.children ? "end" : "start")
    .attr("paint-order", "stroke")
    .attr("stroke", halo)
    .attr("fill", d => Number(d.data.prob) === 1 ? 'red' : 'black')
    .attr("stroke-width", haloWidth)
    .text((d, i) => L[i]);

  return svg.node();
}
// Render the per-step top-k probability grid and return a horizontally scrollable
// wrapper <div> containing the <svg>. There is one column per entry of `data` and one
// row per candidate word. Each cell is a gray box whose fill-opacity is the word's
// score, with the word drawn on top; hovering a box shows "word:score".
//   data     - array of steps; each step is read as [_, scores, words]
//              (words_list[1] = scores, words_list[2] = candidate strings)
//   div_name - unused (kept for call compatibility)
//   options  - {width, height}; r and padding are accepted but unused
function TextGrid(data, div_name, {
  width = 640, // minimum outer width, in pixels
  height, // outer height, in pixels (defaults to room for 10 rows)
  r = 3, // unused
  padding = 1, // unused
} = {}) {
  const topk = 10; // rows reserved by the default height
  const wordLength = 20; // vertical pitch of one row, in pixels
  const rectWidth = 60;
  const rectTotal = 70; // column pitch: box width plus spacing
  const realWidth = rectTotal * data.length;
  const totalWidth = (realWidth > width) ? realWidth : width;
  const svgHeight = height === undefined ? topk * wordLength + 10 : height;

  const parent = d3.create("div");
  const body = parent.append("div")
    .style("overflow-x", "scroll")
    .style("-webkit-overflow-scrolling", "touch");
  const svg = body.append("svg")
    .attr("width", totalWidth)
    .attr("height", svgHeight)
    .style("display", "block")
    .attr("font-family", "sans-serif")
    .attr("font-size", 10);

  data.forEach((wordsList, column) => {
    const x = column * rectTotal;
    const words = wordsList[2];
    const scores = wordsList[1];
    const wordsScore = words.map((t, i) => ({ t, p: scores[i] }));
    // One <g> per candidate word in this column.
    const probs = svg.selectAll(null)
      .data(wordsScore)
      .enter()
      .append("g");
    probs.append("rect")
      .attr("x", x)
      .attr("y", (d, i) => 10 + i * 20)
      .attr('width', rectWidth)
      .attr('height', 15)
      .attr("color", 'gray')
      .attr("fill", "gray")
      .attr("fill-opacity", (d) => d.p)
      .attr("stroke-opacity", 0.8)
      .append("svg:title")
      .text((d) => d.t + ":" + d.p);
    probs.append("text")
      .text((d) => d.t)
      .attr("x", x)
      .attr("y", (d, i) => 20 + i * 20)
      .attr("font-weight", 700);
  });

  // Scroll to the right end (latest steps). scrollBy() does nothing while the wrapper
  // is still detached (it has no layout box), so wait until the caller has inserted it.
  const scroller = body.node();
  requestAnimationFrame(() => scroller.scrollBy(totalWidth, 0));

  return parent.node();
}
function attViz(PYTHON_PARAMS) {
// Shared state for this attention view. PYTHON_PARAMS keeps the name of the template
// marker it used to be, but here it is an ordinary argument holding the view parameters.
const $ = jQuery;
const params = PYTHON_PARAMS;
// Layout constants, in pixels.
const TEXT_SIZE = 15;
const BOXWIDTH = 110;
const BOXHEIGHT = 22.5;
const MATRIX_WIDTH = 115;
const CHECKBOX_SIZE = 20;
const TEXT_TOP = 30;
console.log("d3 version in ffuntions", d3.version)
// One categorical color per head; falls back to the d3 v3 API when scaleOrdinal is unavailable.
const headColors = (() => {
  try {
    return d3.scaleOrdinal(d3.schemeCategory10);
  } catch (err) {
    console.log('Older d3 version')
    return d3.scale.category10();
  }
})();
// View state, filled in by initialize() and updated by the widgets.
const config = {};
initialize();
renderVis();
// Fill `config` from the attViz parameters and wire up the layer / filter <select>
// widgets under '#<root_div_id>'.
function initialize() {
  console.log("init")
  config.attention = params['attention'];
  config.filter = params['default_filter'];
  config.rootDivId = params['root_div_id'];
  const filteredAttn = config.attention[config.filter]['attn'];
  config.nLayers = filteredAttn.length;
  config.nHeads = filteredAttn[0].length;
  config.layers = params['include_layers']
  if (params['heads']) {
    // Only the listed heads start visible.
    config.headVis = new Array(config.nHeads).fill(false);
    params['heads'].forEach(x => config.headVis[x] = true);
  } else {
    config.headVis = new Array(config.nHeads).fill(true);
  }
  config.initialTextLength = config.attention[config.filter].right_text.length;
  // Position of the requested layer within include_layers. Fall back to the first included
  // layer when none was requested or it is not listed; otherwise findIndex yields -1 and
  // renderVis would read attn[-1].
  const requestedSeq = params['layer'] == null ? 0 : config.layers.findIndex(layer => params['layer'] === layer);
  config.layer_seq = Math.max(requestedSeq, 0);
  config.layer = config.layers[config.layer_seq]

  const rootSelector = '#' + config.rootDivId;
  // Handlers are namespaced so that a later attViz() call on the same root removes the
  // ones bound by an earlier call. Without this they pile up and re-render stale data on
  // every change.
  const EVENT = 'change.attViz';

  const layerEl = $(rootSelector + ' #layer');
  layerEl.empty().off(EVENT);
  console.log(layerEl)
  for (const layer of config.layers) {
    layerEl.append($("<option />").val(layer).text(layer));
  }
  layerEl.val(config.layer).trigger('change');
  layerEl.on(EVENT, function (e) {
    config.layer = +e.currentTarget.value;
    config.layer_seq = config.layers.findIndex(layer => config.layer === layer);
    renderVis();
  });

  $(rootSelector + ' #filter').off(EVENT).on(EVENT, function (e) {
    config.filter = e.currentTarget.value;
    renderVis();
  });
}
// Redraw the whole attention view for the current filter and layer: both token
// columns, the attention arcs and the head checkboxes, in a freshly created <svg>.
function renderVis() {
  const { left_text: leftText, right_text: rightText, attn } = config.attention[config.filter];
  const layerAttention = attn[config.layer_seq];
  const visSelector = '#' + config.rootDivId + ' #vis';

  // Start from an empty container.
  $(visSelector).empty();

  // Tall enough for the longer of the two token columns.
  const rowCount = Math.max(leftText.length, rightText.length);
  const height = rowCount * BOXHEIGHT + TEXT_TOP;
  const svg = d3.select(visSelector)
    .append('svg')
    .attr("width", "100%")
    .attr("height", height + "px");

  const rightColumnX = MATRIX_WIDTH + BOXWIDTH;
  renderText(svg, leftText, true, layerAttention, 0);
  renderText(svg, rightText, false, layerAttention, rightColumnX);
  renderAttention(svg, layerAttention);
  drawCheckboxes(0, svg, layerAttention);
}
// Draw one column of tokens into `svg` (the left column when isLeft, else the right one)
// at x offset leftPos. Also draw the per-head attention boxes over these tokens; they
// light up when a token in the *other* column is hovered.
//   svg       - d3 (v5) selection of the visualization's <svg>
//   text      - array of token strings, one row each
//   isLeft    - true for the left column
//   attention - attention for the current layer, one entry per head (bound per head below)
//   leftPos   - x offset of this column, in pixels
// Event listeners use the d3 v5 (datum, index) signature.
function renderText(svg, text, isLeft, attention, leftPos) {
const textContainer = svg.append("svg:g")
.attr("id", isLeft ? "left" : "right");
// Add attention highlights superimposed over words
textContainer.append("g")
.classed("attentionBoxes", true)
.selectAll("g")
.data(attention)
.enter()
.append("g")
.attr("head-index", (d, i) => i)
.selectAll("rect")
.data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights
.enter()
.append("rect")
.attr("x", function () {
var headIndex = +this.parentNode.getAttribute("head-index");
return leftPos + boxOffsets(headIndex);
})
// Placeholder y; x/y/width are recomputed when a token in the opposite column is hovered.
.attr("y", (+1) * BOXHEIGHT)
.attr("width", BOXWIDTH / activeHeads())
.attr("height", BOXHEIGHT)
.attr("fill", function () {
return headColors(+this.parentNode.getAttribute("head-index"))
})
.style("opacity", 0.0);
// One <g> per token (row).
const tokenContainer = textContainer.append("g").selectAll("g")
.data(text)
.enter()
.append("g");
// Add gray background that appears when hovering over text
tokenContainer.append("rect")
.classed("background", true)
.style("opacity", 0.0)
.attr("fill", "lightgray")
.attr("x", leftPos)
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
.attr("width", BOXWIDTH)
.attr("height", BOXHEIGHT);
// Add token text
const textEl = tokenContainer.append("text")
.text(d => d)
.attr("font-size", TEXT_SIZE + "px")
.style("cursor", "default")
.style("-webkit-user-select", "none")
.attr("x", leftPos)
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT);
// Left tokens are right-aligned against the arcs; right tokens are left-aligned.
if (isLeft) {
textEl.style("text-anchor", "end")
.attr("dx", BOXWIDTH - 0.5 * TEXT_SIZE)
.attr("dy", TEXT_SIZE);
} else {
textEl.style("text-anchor", "start")
.attr("dx", +0.5 * TEXT_SIZE)
.attr("dy", TEXT_SIZE);
}
// Hover: highlight this token, show only its attention arcs, and shade the
// opposite column's boxes by attention weight for each visible head.
tokenContainer.on("mouseover", function (d, index) {
// Show gray background for moused-over token
textContainer.selectAll(".background")
.style("opacity", (d, i) => i === index ? 1.0 : 0.0)
// Reset visibility attribute for any previously highlighted attention arcs
svg.select("#attention")
.selectAll("line[visibility='visible']")
.attr("visibility", null)
// Hide group containing attention arcs
svg.select("#attention").attr("visibility", "hidden");
// Set to visible appropriate attention arcs to be highlighted
if (isLeft) {
svg.select("#attention").selectAll("line[left-token-index='" + index + "']").attr("visibility", "visible");
} else {
svg.select("#attention").selectAll("line[right-token-index='" + index + "']").attr("visibility", "visible");
}
// Update color boxes superimposed over tokens
const id = isLeft ? "right" : "left";
// Shadows the parameter: this is the x offset of the *opposite* column, whose boxes are updated below.
const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;
svg.select("#" + id)
.selectAll(".attentionBoxes")
.selectAll("g")
.attr("head-index", (d, i) => i)
.selectAll("rect")
.attr("x", function () {
const headIndex = +this.parentNode.getAttribute("head-index");
return leftPos + boxOffsets(headIndex);
})
.attr("y", (d, i) => TEXT_TOP + i * BOXHEIGHT)
.attr("width", BOXWIDTH / activeHeads())
.attr("height", BOXHEIGHT)
.style("opacity", function (d) {
const headIndex = +this.parentNode.getAttribute("head-index");
// Opacity = the rect's bound row `d` indexed by the hovered token (0 for hidden heads
// or missing rows). The final `else` pairs with the outer `if (config.headVis...)`.
if (config.headVis[headIndex])
if (d) {
return d[index];
} else {
return 0.0;
}
else
return 0.0;
});
});
// Leaving the column: clear highlights and make every arc visible again.
textContainer.on("mouseleave", function () {
// Unhighlight selected token
d3.select(this).selectAll(".background")
.style("opacity", 0.0);
// Reset visibility attributes for previously selected lines
svg.select("#attention")
.selectAll("line[visibility='visible']")
.attr("visibility", null) ;
svg.select("#attention").attr("visibility", "visible");
// Reset highlights superimposed over tokens
svg.selectAll(".attentionBoxes")
.selectAll("g")
.selectAll("rect")
.style("opacity", 0.0);
});
}
// (Re)draw every attention arc as a <line> from a left token to a right token.
// Nesting is head <g> -> left-token <g> -> one line per right token. Each line records its
// left/right token indices so hovering can reveal a token's arcs. Opacities are then
// set by updateAttention().
function renderAttention(svg, attention) {
  // Drop arcs from any previous render.
  svg.select("#attention").remove();

  const headGroups = svg.append("g")
    .attr("id", "attention")
    .selectAll(".headAttention")
    .data(attention)
    .enter()
    .append("g")
    .classed("headAttention", true)
    .attr("head-index", (d, i) => i);

  const tokenGroups = headGroups.selectAll(".tokenAttention")
    .data(d => d)
    .enter()
    .append("g")
    .classed("tokenAttention", true)
    .attr("left-token-index", (d, i) => i);

  const halfBox = BOXHEIGHT / 2;
  tokenGroups.selectAll("line")
    .data(d => d)
    .enter()
    .append("line")
    .attr("x1", BOXWIDTH)
    .attr("y1", function () {
      const leftTokenIndex = +this.parentNode.getAttribute("left-token-index");
      return TEXT_TOP + leftTokenIndex * BOXHEIGHT + halfBox;
    })
    .attr("x2", BOXWIDTH + MATRIX_WIDTH)
    .attr("y2", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + halfBox)
    .attr("stroke-width", 2)
    .attr("stroke", function () {
      return headColors(+this.parentNode.parentNode.getAttribute("head-index"));
    })
    .attr("left-token-index", function () {
      return +this.parentNode.getAttribute("left-token-index");
    })
    .attr("right-token-index", (d, i) => i);

  updateAttention(svg);
}
// Set each arc's stroke-opacity. An arc of a visible head gets its attention weight
// divided by the number of visible heads; an arc of a hidden head is fully transparent.
function updateAttention(svg) {
  const arcs = svg.select("#attention").selectAll("line");
  arcs.attr("stroke-opacity", function (weight) {
    const headIndex = +this.parentNode.parentNode.getAttribute("head-index");
    return config.headVis[headIndex] ? weight / activeHeads() : 0.0;
  });
}
// Horizontal offset of head `i`'s attention box within a token row. Each visible head
// before `i` takes one slot of width BOXWIDTH / activeHeads().
function boxOffsets(i) {
  const visibleBefore = config.headVis.filter((visible, head) => visible && head < i).length;
  return visibleBefore * (BOXWIDTH / activeHeads());
}
// Number of heads currently marked visible (truthy entries of config.headVis).
function activeHeads() {
  return config.headVis.filter(Boolean).length;
}
// Draw one clickable color swatch per attention head along the top edge (y = top).
// Click toggles a head, but the last visible head cannot be hidden. Double-click shows
// only that head, or restores all heads when it is already the only visible one.
// Hidden heads are drawn with a lightened color.
function drawCheckboxes(top, svg) {
  const swatchLayer = svg.append("g");
  const swatches = swatchLayer.selectAll("rect")
    .data(config.headVis)
    .enter()
    .append("rect")
    .attr("fill", (d, i) => headColors(i))
    .attr("x", (d, i) => i * CHECKBOX_SIZE)
    .attr("y", top)
    .attr("width", CHECKBOX_SIZE)
    .attr("height", CHECKBOX_SIZE);

  const paintSwatches = () => {
    swatchLayer.selectAll("rect")
      .data(config.headVis)
      .attr("fill", (visible, i) => (visible ? headColors(i) : lighten(headColors(i))));
  };
  const refresh = () => {
    paintSwatches();
    updateAttention(svg);
  };

  paintSwatches();

  swatches.on("click", (d, i) => {
    const isLastVisible = config.headVis[i] && activeHeads() === 1;
    if (isLastVisible) return;
    config.headVis[i] = !config.headVis[i];
    refresh();
  });

  swatches.on("dblclick", (d, i) => {
    if (config.headVis[i] && activeHeads() === 1) {
      config.headVis = new Array(config.nHeads).fill(true);
    } else {
      config.headVis = new Array(config.nHeads).fill(false);
      config.headVis[i] = true;
    }
    refresh();
  });
}
// Return a washed-out HSL version of `color`. Lightness moves 60% of the way towards 1,
// and saturation drops by that same amount.
function lighten(color) {
  const hsl = d3.hsl(color);
  const delta = 0.6 * (1 - hsl.l);
  hsl.l = hsl.l + delta;
  hsl.s = hsl.s - delta;
  return hsl;
}
// Transpose a rectangular matrix given as an array of rows. The column count is
// taken from the first row.
function transpose(mat) {
  const firstRow = mat[0];
  return firstRow.map((_, col) => mat.map((row) => row[col]));
}
}
}