|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
|
|
import helpers as h |
|
import domains |
|
from domains import * |
|
import math |
|
|
|
|
|
# "Point" domains: every entry returned by h.getMethods(domains) for which
# h.hasMethod reports an "attack" member, followed by the plain tensor types
# (CPU float, generic, CUDA float). Presumably these represent concrete inputs
# (single points / attack-based checks) rather than abstract regions -- confirm
# against the semantics of helpers.getMethods / helpers.hasMethod.
POINT_DOMAINS = [m for m in h.getMethods(domains) if h.hasMethod(m, "attack")] + [ torch.FloatTensor, torch.Tensor, torch.cuda.FloatTensor ]

# domains.Box prepended to the point domains. NOTE(review): the name suggests
# domains whose abstraction is symmetric around the input; this is not
# established here -- confirm with callers. (Spelling "SYMETRIC" is kept as-is
# because other modules may import this name.)
SYMETRIC_DOMAINS = [domains.Box] + POINT_DOMAINS
|
|
|
def domRes(outDom, target, **args):
    """Lower-bound the classification margin of ``outDom`` against ``target``.

    Let ``y`` be the target class and ``out`` the network output. Entry ``k``
    of the result is ``lb(out[y] - out[k])`` for every ``k != y``. At ``k == y``
    the margin expression is identically 0, and the one-hot is added on top,
    so that column becomes ``lb(0) + 1``. Computing the difference through
    ``bmm`` on ``outDom`` itself keeps the subtraction inside the domain,
    instead of subtracting two independently bounded values.

    Args:
        outDom: output abstraction. It must support ``size()``, ``bmm`` and
            ``lb()`` (tensor-like or a domain object; not verifiable here).
        target: integer class labels; only ``target.data`` is read, cast to long.
        **args: accepted for call-site compatibility and otherwise ignored.

    Returns:
        ``margin.lb()`` plus the one-hot target encoding.
    """
    onehot = h.one_hot(target.data.long(), outDom.size()[1]).to_dense()

    # Outer product of the one-hot with itself: a single 1 at position (y, y).
    target_diag = onehot.unsqueeze(2).matmul(onehot.unsqueeze(1))
    batch, classes = target_diag.size()[0], target_diag.size()[1]

    # Identity with (y, y) removed. Right-multiplying by it keeps every logit
    # except the target's, which becomes 0.
    others = h.eye(classes).expand(batch, -1, -1) - target_diag

    # Row y becomes all ones except column y; every other row is zero.
    # Right-multiplying by it copies out[y] into every non-target column.
    pick_target = onehot.unsqueeze(2).expand(-1, -1, classes).bmm(others)

    # out[y] - out[k] for k != y, and 0 in the target column.
    margin = outDom.bmm(pick_target) - outDom.bmm(others)
    return margin.lb() + onehot
|
|
|
def isSafeDom(outDom, target, **args):
    """Report whether every lower-bounded margin from ``domRes`` is positive.

    The minimum over dimension 1 of ``domRes(outDom, target, **args)`` is
    taken, and the result is 1 when that minimum is strictly greater than 0,
    otherwise 0.

    Returns:
        Python ``int`` (0 or 1). Because ``.item()`` is called, the reduced
        result must hold exactly one element. In practice this means a batch
        of size 1; a larger batch raises.
    """
    margins = domRes(outDom, target, **args)
    worst_margin = torch.min(margins, 1)[0]
    is_safe = worst_margin.gt(0.0)
    return is_safe.long().item()
|
|
|
|
|
def isSafeBox(target, net, inp, eps, dom):
    """Check whether ``net`` keeps predicting the target class under ``dom``.

    Only the first batch element's label is used (``target.argmax(1)[0]``).

    If ``dom`` has an ``attack`` attribute, an adversarial input is produced
    with ``dom.attack(net, eps, inp, target)``. The result is True when the
    network's first prediction on that input equals the label (a ``bool``).

    Otherwise the network is run on ``dom.box(inp, eps)``. The result is
    delegated to ``isSafeDom``, which returns an ``int`` 0/1.
    NOTE(review): the return type therefore differs between the two branches.
    ``dom.box`` presumably builds an eps-sized region around ``inp``; confirm.

    Args:
        target: per-class scores or one-hot labels; argmax is taken over dim 1.
        net: callable model mapping inputs (or domain objects) to outputs.
        inp: input batch.
        eps: perturbation size forwarded to ``attack`` / ``box``.
        dom: attack-capable domain, or a domain exposing ``box``.
    """
    true_label = target.argmax(1)[0].unsqueeze(0)

    if not hasattr(dom, "attack"):
        return isSafeDom(net(dom.box(inp, eps)), true_label)

    adversarial = dom.attack(net, eps, inp, target)
    predicted = net(adversarial).argmax(1)[0]
    return predicted.item() == true_label.item()
|
|