Spaces:
Runtime error
Runtime error
'''
ART Gradio Example App [Poisoning]
To run:
- clone the repository
- execute: gradio examples/gradio_app.py or python examples/gradio_app.py
- navigate to local URL e.g. http://127.0.0.1:7860
'''
import gradio as gr | |
import numpy as np | |
from carbon_theme import Carbon | |
import numpy as np | |
import torch | |
import transformers | |
from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch | |
from art.attacks.evasion import ProjectedGradientDescentPyTorch, AdversarialPatchPyTorch | |
from art.utils import load_dataset | |
from art.attacks.poisoning import PoisoningAttackBackdoor | |
from art.attacks.poisoning.perturbations import insert_image | |
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Custom CSS passed to gr.Blocks: larger text sizes plus helper classes
# (e.g. custom-text, plot-padding, cust-width, eval-bt, symbols) used via elem_classes.
css = """
.custom-text {
    --text-md: 20px !important;
    --text-sm: 18px !important;
    --block-info-text-size: var(--text-sm);
    --block-label-text-size: var(--text-sm);
    --block-title-text-size: var(--text-md);
    --body-text-size: var(--text-md);
    --button-small-text-size: var(--text-md);
    --checkbox-label-text-size: var(--text-md);
    --input-text-size: var(--text-md);
    --prose-text-size: var(--text-md);
    --section-header-text-size: var(--text-md);
}
.input-image { margin: auto !important }
.plot-padding { padding: 20px; }
.eta-bar.svelte-1occ011.svelte-1occ011 {
    background: #ccccff !important;
}
.center-text { text-align: center !important }
.larger-gap { gap: 100px !important; }
.symbols { text-align: center !important; margin: auto !important; }
.eval-bt { background-color: #3b74f4 !important; color: white !important; }
.cust-width { min-width: 250px !important;}
"""

# Shared DeiT-tiny (distilled) classifier with a 10-way head for the Imagenette classes.
# ignore_mismatched_sizes=True lets from_pretrained replace the checkpoint's classification
# head (presumably sized for a different label count) with a freshly initialised one.
# clf_poison_evaluate later loads poisoned weights into this same object on each request.
model = transformers.AutoModelForImageClassification.from_pretrained(
    'facebook/deit-tiny-distilled-patch16-224',
    ignore_mismatched_sizes=True,
    num_labels=10
)
def default_clean():
    '''
    Initial contents of the "Clean" gallery, shown before any evaluation runs.

    Returns a fresh list of (image path, displayed label) tuples for the eight
    pre-rendered images under ./data/default/clean/.
    '''
    shown_labels = ['fish', 'fish', 'church', 'fish', 'church', 'fish', 'fish', 'fish']
    return [(f'./data/default/clean/{idx}_fish.png', lbl) for idx, lbl in enumerate(shown_labels)]
def default_poisoned():
    '''
    Initial contents of the "Poisoned" gallery, shown before any evaluation runs.

    Returns a fresh list of (image path, displayed label) tuples: the eight
    pre-rendered triggered images under ./data/default/poisoned/, all labelled 'church'.
    '''
    return [(f'./data/default/poisoned/{idx}_fish.png', 'church') for idx in range(8)]
def sample_imagenette():
    '''
    Build the preview gallery: the first training image of every Imagenette class.

    Images are read from ./data/imagenette2-320/train, resized to 224x224 and converted
    with ToTensor (channels-first), then transposed to height x width x channels for display.

    Returns a list of (image, label_name) tuples, one per class present in the dataset,
    in class-index order.
    '''
    import torchvision
    label_names = [
        'fish',
        'dog',
        'cassette player',
        'chainsaw',
        'church',
        'french horn',
        'garbage truck',
        'gas pump',
        'golf ball',
        'parachutte',
    ]
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ])
    train_dataset = torchvision.datasets.ImageFolder(root="./data/imagenette2-320/train", transform=transform)
    labels = np.asarray(train_dataset.targets)
    samples_per_class = 1
    x_subset = []
    y_subset = []
    for c in np.unique(labels):
        for i in np.where(labels == c)[0][:samples_per_class]:
            # Index the dataset once per sample: every access re-reads and re-transforms
            # the image, so taking [0] and [1] separately would do that work twice.
            image, label = train_dataset[i]
            x_subset.append(image)
            y_subset.append(label)
    x_subset = np.stack(x_subset)
    y_subset = np.asarray(y_subset)
    # Channels-first -> channels-last so the gallery can display the arrays.
    return [(im.transpose(1, 2, 0), label_names[y]) for im, y in zip(x_subset, y_subset)]
def _load_train_subset(samples_per_class):
    '''
    Load the first `samples_per_class` images of every class in the Imagenette train split.

    Returns (x, y): x is the stacked channels-first images (resized to 224x224 and
    converted with ToTensor), y the integer class indices as a numpy array.
    '''
    import torchvision
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ])
    train_dataset = torchvision.datasets.ImageFolder(root="./data/imagenette2-320/train", transform=transform)
    labels = np.asarray(train_dataset.targets)
    x_subset = []
    y_subset = []
    for c in np.unique(labels):
        for i in np.where(labels == c)[0][:samples_per_class]:
            # One access per sample: each dataset index re-reads and re-transforms the image.
            image, label = train_dataset[i]
            x_subset.append(image)
            y_subset.append(label)
    return np.stack(x_subset), np.asarray(y_subset)


def _build_poisoned_classifier(target_idx):
    '''
    Wrap the module-level `model` in an ART classifier and load into it the checkpoint
    './poisoned_models/deit_imagenette_poisoned_model_<target_idx>.pt'.

    NOTE: this overwrites the weights of the shared global `model` on every call.
    '''
    # Presumably only needed by fit(), which is currently disabled in clf_poison_evaluate.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = torch.nn.CrossEntropyLoss()
    classifier = HuggingFaceClassifierPyTorch(
        model=model,
        loss=loss_fn,
        optimizer=optimizer,
        input_shape=(3, 224, 224),
        nb_classes=10,
        clip_values=(0, 1),
    )
    checkpoint_path = './poisoned_models/deit_imagenette_poisoned_model_' + str(target_idx) + '.pt'
    # torch.load unpickles the file; only ever point it at trusted, local checkpoints.
    classifier.model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return classifier


def _poison_source_class(x, y, target_idx, source_class=0, poison_percent=0.5):
    '''
    Stamp the backdoor trigger onto the first `poison_percent` of `source_class` images
    and relabel them as `target_idx`.

    Returns (x_poison, y_poison, is_poison): copies of x and y with the poisoned rows
    modified, and a boolean mask marking those rows.
    '''
    def poison_func(img):
        # Fixed-position 32x32 trigger in the top-left corner, blended at 0.8.
        return insert_image(
            img,
            backdoor_path='./baby-on-board.png',
            channels_first=True,
            random=False,
            x_shift=0,
            y_shift=0,
            size=(32, 32),
            mode='RGB',
            blend=0.8
        )

    backdoor = PoisoningAttackBackdoor(poison_func)
    x_poison = np.copy(x)
    y_poison = np.copy(y)
    is_poison = np.zeros(len(x), dtype=bool)
    indices = np.where(y == source_class)[0]
    num_poison = int(poison_percent * len(indices))
    for i in indices[:num_poison]:
        x_poison[i], _ = backdoor.poison(x_poison[i], [])
        y_poison[i] = target_idx
        is_poison[i] = True
    return x_poison, y_poison, is_poison


def _predict_gallery(classifier, x, y, label_names):
    '''
    Predict `x` with `classifier`.

    Returns (gallery, accuracy): a list of (channels-last image, predicted label name)
    tuples, and the mean of (prediction == y).
    '''
    preds = np.argmax(classifier.predict(x), axis=1)
    accuracy = np.mean(preds == y)
    gallery = [(im.transpose(1, 2, 0), label_names[p]) for im, p in zip(x, preds)]
    return gallery, accuracy


def clf_poison_evaluate(*args):
    '''
    Gradio callback: evaluate the backdoor-poisoned DeiT checkpoint for the chosen target.

    Positional args (in the order wired by eval_btn_patch.click):
        args[0]: attack name; only "Backdoor" is supported.
        args[1]: trigger image from the UI. Display-only: the poison is always stamped
                 with './baby-on-board.png', which presumably is the trigger the
                 checkpoints were trained with.
        args[2]: target class name, one of the label names below.

    Returns (clean_out, poison_out, clean_acc, poison_acc):
        clean_out / poison_out: lists of (HWC image, predicted label) tuples for the
            un-poisoned and poisoned images of a 20-per-class training subset;
        clean_acc: accuracy on the un-poisoned images;
        poison_acc: fraction of poisoned images predicted as the target class.

    Raises gr.Error for an unsupported attack or an unknown target class, so the message
    is shown in the UI rather than returning None, which does not match the four outputs.
    '''
    label_names = [
        'fish',
        'dog',
        'cassette player',
        'chainsaw',
        'church',
        'french horn',
        'garbage truck',
        'gas pump',
        'golf ball',
        'parachutte',
    ]
    attack = args[0]
    target_class = args[2]

    # Validate inputs before any expensive model/dataset loading.
    if attack != "Backdoor":
        raise gr.Error(f"Unsupported attack: {attack!r}. Only 'Backdoor' is available.")
    if target_class not in label_names:
        raise gr.Error(f"Unknown target class: {target_class!r}.")
    target_idx = label_names.index(target_class)

    classifier = _build_poisoned_classifier(target_idx)
    x_subset, y_subset = _load_train_subset(samples_per_class=20)
    x_poison, y_poison, is_poison = _poison_source_class(x_subset, y_subset, target_idx)
    # Fine-tuning on the poisoned data is disabled; the checkpoint is already poisoned:
    # classifier.fit(x_poison, y_poison, nb_epochs=2)

    clean_out, clean_acc = _predict_gallery(
        classifier, x_poison[~is_poison], y_poison[~is_poison], label_names)
    poison_out, poison_acc = _predict_gallery(
        classifier, x_poison[is_poison], y_poison[is_poison], label_names)
    return clean_out, poison_out, clean_acc, poison_acc
def show_params(type):
    '''
    Return a gr.Column update that reveals the parameter column for every model
    type except "Example", for which the column is hidden.
    '''
    # `type` shadows the builtin but is the public parameter name, so it is kept.
    reveal = bool(type != "Example")
    return gr.Column(visible=reveal)
# head = f'''<script async defer src="https://buttons.github.io/buttons.js"></script>'''
# e.g. To use a local alternative theme: carbon_theme = Carbon()
# NOTE(review): `carbon_theme` is constructed but not passed to gr.Blocks below, which
# uses the 'Tshackelton/IBMPlex-DenseReadable' Hub theme instead.
carbon_theme = Carbon()

# Page layout: header, scenario description with a sample Imagenette gallery, then the
# interactive evaluation row (clean gallery + trigger = poisoned gallery).
with gr.Blocks(css=css, theme='Tshackelton/IBMPlex-DenseReadable') as demo:
    import art
    text = art.__version__  # ART version string; not currently displayed on the page.

    with gr.Row(elem_classes="custom-text"):
        with gr.Column(scale=1,):
            gr.Image(value="./art_lfai.png", show_label=False, show_download_button=False, width=100, show_share_button=False)
        with gr.Column(scale=2):
            gr.Markdown("<h1>🧪 Red-teaming HuggingFace with ART [Poisoning]</h1>", elem_classes="plot-padding")

    gr.Markdown('''<p style="font-size: 20px; text-align: justify">ℹ️ Red-teaming in AI is an activity where we masquerade
as evil attackers 😈 and attempt to find vulnerabilities in our AI models. Identifying scenarios where
our AI models do not work as expected, or fail, is important as it helps us better understand
its limitations and vulnerability when deployed in the real world 🧐</p>''')
    gr.Markdown('''<p style="font-size: 20px; text-align: justify">ℹ️ By attacking our AI models ourselves, we can better understand the risks associated with use
in the real world and implement mechanisms which can mitigate and protect our model. The example below demonstrates a
common red-team workflow to assess model vulnerability to data poisoning attacks 🧪</p>''')
    gr.Markdown('''<p style="font-size: 18px; text-align: justify"><i>Check out the full suite of features provided by ART <a href="https://github.com/Trusted-AI/adversarial-robustness-toolbox"
target="_blank">here</a>. To dive further into poisoning attacks with Hugging Face and ART, check out our
<a href="https://github.com/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/hugging_face_poisoning.ipynb"
target="_blank">notebook</a>. Also feel free to contribute and give our repo a ⭐.</i></p>''')

    # Disabled GitHub star/follow buttons (presumably rendered by the buttons.js script
    # referenced in the `head` comment above):
    # gr.Markdown('''<div style="width: 100%; text-align: center;">
    # <a style="margin-right: 20px;" class="github-button"
    # href="https://github.com/Trusted-AI/adversarial-robustness-toolbox"
    # data-color-scheme="no-preference: light; light: light; dark: dark;" data-size="large"
    # data-show-count="true" aria-label="Star Trusted-AI/adversarial-robustness-toolbox on GitHub">Star</a>
    # <!-- Place this tag where you want the button to render. -->
    # <a class="github-button" href="https://github.com/Trusted-AI"
    # data-color-scheme="no-preference: light; light: light; dark: dark;" data-size="large" data-show-count="true"
    # aria-label="Follow @Trusted-AI on GitHub">Follow @Trusted-AI</a></div>''')

    gr.Markdown('''<hr/>''')

    # Scenario: the dataset, the downloaded model, and a preview of the data.
    with gr.Row(elem_classes=["larger-gap", "custom-text"]):
        with gr.Column(scale=1, elem_classes="cust-width"):
            gr.Markdown('''<p style="font-size: 20px; text-align: justify">ℹ️ First, let's set the scene. You have a dataset of images, such as Imagenette.</p>''')
            gr.Markdown('''<p style="font-size: 18px; text-align: justify"><i>Note: Imagenette is a subset of 10 easily classified classes from ImageNet as shown.</i></p>''')
            gr.Markdown('''<p style="font-size: 20px; text-align: justify">ℹ️ Your goal is to have an AI model capable of classifying these images. So you
find a pre-trained model from Hugging Face,
such as Meta's Distilled Data-efficient Image Transformer, which has been trained on this data (or so you think ☠️).</p>''')
        with gr.Column(scale=1, elem_classes="cust-width"):
            gr.Markdown('''
<p style="font-size: 20px;"><b>Hugging Face dataset:</b>
<a href="https://huggingface.co/datasets/frgfm/imagenette" target="_blank">Imagenette</a></p>
<p style="font-size: 18px; padding-left: 20px;"><i>Imagenette labels:</i>
<i>{fish, dog, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute}</i>
</p>
<p style="font-size: 20px;"><b>Hugging Face model:</b><br/>
<a href="https://huggingface.co/facebook/deit-tiny-distilled-patch16-224"
target="_blank">facebook/deit-tiny-distilled-patch16-224</a></p>
<br/>
<p style="font-size: 20px;">👀 take a look at the sample images from the Imagenette dataset and their respective labels.</p>
''')
        with gr.Column(scale=1, elem_classes="cust-width"):
            gr.Gallery(label="Imagenette", preview=False, value=sample_imagenette(), height=420)

    gr.Markdown('''<hr/>''')
    gr.Markdown('''<p style="text-align: justify; font-size: 18px">ℹ️ Now as a responsible AI expert, you wish to assert that your model is not vulnerable to
attacks which might manipulate the prediction. For instance, fish become classified as dogs or golf balls. To do this, you will deploy
a backdoor poisoning attack against your own model and assess its performance. Click the button below 👇 to evaluate a poisoned model.</p>''')

    # Evaluation row: controls | clean gallery | ➕ | trigger | 🟰 | poisoned gallery.
    with gr.Row(elem_classes="custom-text"):
        with gr.Column(scale=6):
            attack = gr.Textbox(visible=True, value="Backdoor", label="Attack", interactive=False)
            # 'fish' is not offered: it is the source class (index 0) whose images get the trigger.
            target_class = gr.Radio(label="Target class", info="The class you wish to force the model to predict.",
                                    choices=['church',
                                             'cassette player',
                                             'chainsaw',
                                             'dog',
                                             'french horn',
                                             'garbage truck',
                                             'gas pump',
                                             'golf ball',
                                             'parachutte',], value='church')
            eval_btn_patch = gr.Button("Evaluate ✨", elem_classes="eval-bt")
        with gr.Column(scale=10):
            clean_gallery = gr.Gallery(default_clean(), label="Clean", preview=False, show_download_button=True, height=600)
            clean_accuracy = gr.Number(0.97, label="Clean Accuracy", precision=2, info="The percent of correctly classified images without trigger.")
        with gr.Column(scale=1, min_width=0, elem_classes='symbols'):
            gr.Markdown('''➕''')
        with gr.Column(scale=3, elem_classes='symbols'):
            trigger_image = gr.Image(label="Trigger", value="./baby-on-board.png", interactive=False)
        with gr.Column(scale=1, min_width=0):
            gr.Markdown('''🟰''', elem_classes='symbols')
        with gr.Column(scale=10):
            poison_gallery = gr.Gallery(default_poisoned(), label="Poisoned", preview=False, show_download_button=True, height=600)
            poison_success = gr.Number(1.0, label="Poison Success", precision=2, info="The percent of images with trigger classified as the target.")

    # Input order must match the positional args read by clf_poison_evaluate.
    eval_btn_patch.click(clf_poison_evaluate, inputs=[attack, trigger_image, target_class],
                         outputs=[clean_gallery, poison_gallery, clean_accuracy, poison_success])

    gr.Markdown('''<br/>''')
    gr.Markdown('''<p style="font-size: 18px; text-align: center;"><i>☠️ Want to try out a poisoning attack with your own model and data?
Run our <a href="https://github.com/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/hugging_face_poisoning.ipynb"
target="_blank">notebook</a>!</i></p>''')
    gr.Markdown('''<br/>''')
if __name__ == "__main__":
    # Development alternative with explicit server settings:
    # demo.launch(show_api=False, debug=True, share=False,
    #             server_name="0.0.0.0",
    #             server_port=7777,
    #             ssl_verify=False,
    #             max_threads=20)

    # Deployment: launch with Gradio's default settings.
    demo.launch()