|
// Shared UI state: training-set size, confidence threshold and the highlighted digit label.
var state = {dataset_size: 15000, threshold: 0.8, label: 8}

// Root container for this chart: cleared, then labelled for assistive technology.
var sel = d3
  .select('.accuracy-v-privacy-class')
  .html('')
  .at({
    role: 'graphics-document',
    'aria-label': `Line chart showing that high accuracy models can still perform poorly on some digit classes.`,
  })
|
|
|
/**
 * Loads per-test-point values for the currently selected training-set size and
 * summarizes them into per-digit accuracy points for the chart.
 *
 * A test point counts toward a model's accuracy when its value in that model's
 * column is strictly greater than `state.threshold`.
 *
 * @returns {Promise<{data: Array, byLabel: Array}>}
 *   data    - one point per (digit class, model column):
 *             {key, colIndex, epsilon, digitClass, label, accuracy},
 *             restricted to epsilons inside util.epsilonExtent.
 *   byLabel - `data` grouped by digit label; each group gets a `.label`.
 * @throws {Error} if a model column in the CSV has no metadata row for this size.
 */
async function loadData(){
  // Read the size once so the CSV and the metadata filter always agree, even if
  // the user picks another size while this load is in flight.
  var datasetSize = state.dataset_size

  // The CSV and the metadata JSON are independent, so fetch them in parallel.
  var [rawData, metadata] = await Promise.all([
    util.getFile(`cns-cache/grid_${datasetSize}trainpoints_test_labels.csv`),
    util.getFile('cns-cache/model_grid_test_accuracy.json'),
  ])

  rawData.forEach(d => {
    // Drop the unnamed column and coerce the id/label columns to numbers.
    delete d['']
    d.i = +d.i
    d.label = +d.label
  })

  // Map each model column name ('aVal_<aVal>') to its metadata row for this size.
  var aVal2Meta = {}
  metadata
    .filter(d => d.dataset_size == datasetSize)
    .forEach(d => aVal2Meta['aVal_' + d.aVal] = d)

  // One descriptor per model column present in the CSV.
  var allCols = d3.keys(rawData[0])
    .filter(d => d.includes('aVal'))
    .map(key => {
      var meta = aVal2Meta[key]
      if (!meta) throw new Error(`No metadata for column ${key} at dataset_size ${datasetSize}`)
      var {epsilon, aVal} = meta
      return {key, epsilon, aVal}
    })

  var byDigit = d3.nestBy(rawData, d => d.label)
  byDigit.forEach(digitClass => {
    digitClass.label = +digitClass.key
    digitClass.cols = allCols.map(({key, epsilon}, colIndex) => ({
      key,
      colIndex,
      epsilon,
      digitClass,
      label: digitClass.label,
      // Fraction of this digit's test points whose value exceeds the threshold.
      accuracy: d3.mean(digitClass, d => d[key] > state.threshold),
    }))
  })

  // Keep only models whose epsilon lies inside the chart's x extent.
  // NOTE(review): this bound check assumes util.epsilonExtent is ordered [max, min] — confirm.
  var data = _.flatten(byDigit.map(d => d.cols))
    .filter(d => util.epsilonExtent[1] <= d.epsilon && d.epsilon <= util.epsilonExtent[0])

  var byLabel = d3.nestBy(data, d => d.label)
  byLabel.forEach(d => {
    d.label = d.key
  })

  return {data, byLabel}
}
|
|
|
|
|
/**
 * Builds the accuracy-vs-epsilon chart inside `sel`: one line per digit class
 * plus a labelled circle for every (digit, model) point.
 *
 * @returns {Promise<{c: Object, updateDatasetSize: Function, updateThreshold: Function}>}
 *   c                 - the d3.conventions chart object
 *   updateDatasetSize - reloads data for state.dataset_size and animates to it
 *   updateThreshold   - recomputes accuracies in place for state.threshold
 */
async function initChart(){
  var {data, byLabel} = await loadData()

  var c = d3.conventions({
    sel: sel.append('div'),
    height: 400,
    margin: {bottom: 75, top: 5},
    layers: 'ds',
  })

  // Log-scaled epsilon axis. Only ticks whose first character other than '0'
  // or '.' is '1' get a label; the rest render empty.
  c.x = d3.scaleLog().domain(util.epsilonExtent).range(c.x.range())
  c.xAxis = d3.axisBottom(c.x).tickFormat(d => {
    var rv = d + ''
    if (rv.split('').filter(d => d !=0 && d != '.')[0] == 1) return rv
  })

  c.yAxis.tickFormat(d => d3.format('.0%')(d))
  d3.drawAxis(c)
  util.addAxisLabel(c, 'Higher Privacy →', '')
  util.ggPlotBg(c, false)

  // Secondary x-axis caption rendered in the HTML layer below the axis.
  c.layers[0].append('div')
    .st({fontSize: 12, color: '#555', width: 120*2, textAlign: 'center', lineHeight: '1.3em', verticalAlign: 'top'})
    .translate([c.width/2 - 120, c.height + 45])
    .html('in ε')

  var line = d3.line().x(d => c.x(d.epsilon)).y(d => c.y(d.accuracy))

  var lineSel = c.svg.append('g').appendMany('path.accuracy-line', byLabel)
    .at({
      d: line,
      fill: 'none',
      stroke: '#000',
    })
    .on('mousemove', setActiveLabel)

  var circleSel = c.svg.append('g')
    .appendMany('g.accuracy-circle', data)
    .translate(d => [c.x(d.epsilon), c.y(d.accuracy)])
    .on('mousemove', setActiveLabel)

  circleSel.append('circle')
    .at({r: 7, stroke: '#fff'})

  circleSel.append('text')
    .text(d => d.label)
    .at({textAnchor: 'middle', fontSize: 10, fill: '#fff', dy: '.33em'})

  setActiveLabel(state)

  // Marks the line and circles for `label` as 'active', raises them above the
  // others, and records the label in state.label.
  function setActiveLabel({label}){
    lineSel
      .classed('active', 0)
      .filter(d => d.label == label)
      .classed('active', 1)
      .raise()

    circleSel
      .classed('active', 0)
      .filter(d => d.label == label)
      .classed('active', 1)
      .raise()

    state.label = label
  }

  // Incremented on every updateDatasetSize call; only the newest load is rendered.
  var latestDatasetRequest = 0

  // Reloads data for state.dataset_size and transitions lines/circles to it.
  async function updateDatasetSize(){
    var requestId = ++latestDatasetRequest
    var newData = await loadData()
    // A newer size was requested while this one loaded. Drop the stale result so
    // quick successive clicks cannot leave the chart showing an older selection.
    if (requestId != latestDatasetRequest) return

    data = newData.data
    byLabel = newData.byLabel

    lineSel.data(byLabel)
      .transition()
      .at({d: line})

    circleSel.data(data)
      .transition()
      .translate(d => [c.x(d.epsilon), c.y(d.accuracy)])

    c.svg.select('text.annotation').remove()
  }

  // Recomputes every point's accuracy for state.threshold and redraws in place.
  function updateThreshold(){
    data.forEach(d => {
      d.accuracy = d3.mean(d.digitClass, e => e[d.key] > state.threshold)
    })

    lineSel.at({d: line})
    circleSel.translate(d => [c.x(d.epsilon), c.y(d.accuracy)])

    c.svg.select('.y .axis-label').text(`Test Points With More Than ${d3.format('.2%')(state.threshold)} Confidence In Label`)

    c.svg.select('text.annotation').remove()
  }

  updateThreshold()

  return {c, updateDatasetSize, updateThreshold}
}
|
|
|
|
|
/**
 * Entry point: renders the title, the chart, the training-set-size buttons and
 * the confidence-threshold slider, then adds the explanatory annotation.
 */
async function init(){
  sel.append('div.chart-title').text('High accuracy models can still perform poorly on some digit classes')

  var chart = await initChart()

  var buttonRowSel = sel.append('div.button-row')
    .st({height: 50})

  // One button per available training-set size; clicking reloads the chart.
  var buttonSel = buttonRowSel.append('div')
    .st({width: 500})
    .append('span.chart-title').text('Training points')
    .parent()
    .append('div').st({display: 'inline-block', width: 300, marginLeft: 10})
    .append('div.digit-button-container.dataset_size')
    .appendMany('div.button', [2000, 3750, 7500, 15000, 30000, 60000])
    .text(d3.format(','))
    .classed('active', d => d == state.dataset_size)
    .on('click', d => {
      buttonSel.classed('active', e => e == d)
      state.dataset_size = d
      chart.updateDatasetSize()
    })

  // Native range input; addSliders() below replaces it with a custom SVG slider.
  buttonRowSel.append('div.conf-slider')
    .append('span.chart-title').text('Confidence threshold')
    .parent()
    .append('input.slider-native')
    .at({
      type: 'range',
      min: .0001,
      max: .9999,
      step: .0001,
      value: state.threshold,
    })
    .on('input', function(){
      // input.value is a string; store a number so threshold comparisons are numeric.
      state.threshold = +this.value
      chart.updateThreshold()
    })

  // Replaces the native input with an SVG slider (track, filled track, draggable
  // handle, value label) that writes state.threshold and redraws the chart.
  function addSliders(){
    var width = 140
    var height = 30
    var color = '#000'

    var sliders = [
      {key: 'threshold', label: 'Confidence threshold', r: [.0001, .9999]},
    ]
    sliders.forEach(d => {
      d.value = state[d.key]
      d.xScale = d3.scaleLinear().range([0, width]).domain(d.r).clamp(1)
    })

    // Scope cleanup to this chart's button row so other charts on the page are
    // untouched. Slider containers live directly in the row, so clear them there;
    // a stale container would otherwise empty appendMany's enter selection.
    buttonRowSel.selectAll('.slider-container').remove()
    buttonRowSel.select('.slider-native').remove()

    var svgSel = buttonRowSel
      .appendMany('div.slider-container', sliders)
      .append('svg').at({width, height})
      .append('g').translate([10, 25])

    // Clicking anywhere on the slider jumps the value to the pointer position.
    var sliderSel = svgSel
      .on('click', function(d){
        d.value = d.xScale.invert(d3.mouse(this)[0])
        renderSliders(d)
      })
      .classed('slider', true)
      .st({cursor: 'pointer'})

    var textSel = sliderSel.append('text.annotation')
      .at({y: -15, fontWeight: 300, textAnchor: 'middle', x: 180/2})

    // Transparent hit area so clicks near the thin track still register.
    sliderSel.append('rect')
      .at({width, height, y: -height/2, fill: 'rgba(0,0,0,0)'})

    // Full-width thin track.
    sliderSel.append('path').at({
      d: `M 0 -.5 H ${width}`,
      stroke: color,
      strokeWidth: 1
    })

    // Thick track segment from the left edge to the current value.
    var leftPathSel = sliderSel.append('path').at({
      d: `M 0 -.5 H ${width}`,
      stroke: color,
      strokeWidth: 3
    })

    var drag = d3.drag()
      .on('drag', function(d){
        var x = d3.mouse(this)[0]
        d.value = d.xScale.invert(x)

        renderSliders(d)
      })

    var circleSel = sliderSel.append('circle').call(drag)
      .at({r: 7, stroke: '#000'})

    // Writes the changed slider (if any) into state, repositions handle, filled
    // track and label, then redraws the chart for the new threshold.
    function renderSliders(d){
      if (d) state[d.key] = d.value

      circleSel.at({cx: d => d.xScale(d.value)})
      leftPathSel.at({d: d => `M 0 -.5 H ${d.xScale(d.value)}`})
      textSel
        .at({x: d => d.xScale(d.value)})
        .text(d => d3.format('.2%')(d.value))
      chart.updateThreshold()
    }
    renderSliders()
  }

  addSliders()

  // Added after addSliders(): updateThreshold() removes text.annotation from the
  // chart, so appending earlier would erase this note on the initial render.
  chart.c.svg.append('text.annotation')
    .translate([505, 212])
    .tspans(d3.wordwrap(`8s are correctly predicted with high confidence much more rarely than other digits`, 25), 12)
    .at({textAnchor: 'end'})
}
|
// Start rendering. init() is async; its promise is deliberately not awaited here.
void init()
|
|
|
|
|
|
|
|
|
|