Spaces: Bamboo ViT-B16 Demo (status: Runtime error)

Davidzhangyuanhan committed
Commit 6ab04f7 · 1 parent: 1414829

Add application file

Files changed:
- .gitignore                    +139 -0
- 142520422_6ad756ddf6_w_d.jpg  +0   -0
- README.md                     +2   -2
- app.py                        +102 -0
- timmvit.py                    +83  -0
- trainid2name.json             +0   -0
.gitignore
ADDED
@@ -0,0 +1,139 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+**/*.pyc
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+#lib/
+#lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Auto generate documentation
+docs/en/_build/
+docs/en/_model_zoo.rst
+docs/en/modelzoo_statistics.md
+docs/en/papers/
+docs/zh_CN/_build/
+docs/zh_CN/_model_zoo.rst
+docs/zh_CN/modelzoo_statistics.md
+docs/zh_CN/papers/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# custom
+.vscode
+.idea
+*.pkl
+*.pkl.json
+*.log.json
+/work_dirs
+/mmcls/.mim
+
+# Pytorch
+*.pth.*
+
+
+# work_dir
+work_dir
+saves
+
+#checkpoint
+weights
+
+#logs
+logs
+
+#DS_Store
+*DS_Store
+
142520422_6ad756ddf6_w_d.jpg
ADDED
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: Bamboo ViT-B16 Demo
-emoji:
+emoji: 🎋
 colorFrom: blue
 colorTo: blue
 sdk: gradio
@@ -10,4 +10,4 @@ pinned: false
 license: cc-by-4.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,102 @@
+import argparse
+import requests
+import gradio as gr
+import numpy as np
+import cv2
+import torch
+import torch.nn as nn
+from PIL import Image
+from torchvision import transforms
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import create_transform
+
+from timmvit import timmvit
+import json
+from timm.models.hub import download_cached_file
+
+def pil_loader(filepath):
+    with Image.open(filepath) as img:
+        img = img.convert('RGB')
+    return img
+
+def build_transforms(input_size):
+    # Resize the short side to 8/7 of the crop size, center-crop,
+    # convert to tensor, and normalize with ImageNet statistics.
+    transform = transforms.Compose([
+        transforms.Resize(input_size * 8 // 7),
+        transforms.CenterCrop(input_size),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    return transform
+
+# Load human-readable labels for Bamboo (added at the repo root in this commit).
+with open('./trainid2name.json') as f:
+    id2name = json.load(f)
+
+
+'''
+build model
+'''
+model = timmvit(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')
+model.eval()
+
+'''
+build data transform
+'''
+eval_transforms = build_transforms(224)
+
+'''
+borrowed from: https://github.com/jacobgil/pytorch-grad-cam/blob/master/pytorch_grad_cam/utils/image.py
+'''
+def show_cam_on_image(img: np.ndarray,
+                      mask: np.ndarray,
+                      use_rgb: bool = False,
+                      colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
+    """ This function overlays the cam mask on the image as a heatmap.
+    By default the heatmap is in BGR format.
+    :param img: The base image in RGB or BGR format.
+    :param mask: The cam mask.
+    :param use_rgb: Whether to use an RGB or BGR heatmap; this should be set to True if 'img' is in RGB format.
+    :param colormap: The OpenCV colormap to be used.
+    :returns: The image with the cam overlay.
+    """
+    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
+    if use_rgb:
+        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
+    heatmap = np.float32(heatmap) / 255
+
+    if np.max(img) > 1:
+        raise Exception(
+            "The input image should be np.float32 in the range [0, 1]")
+
+    cam = 0.7 * heatmap + 0.3 * img
+    # cam = cam / np.max(cam)
+    return np.uint8(255 * cam)
+
+def recognize_image(image):
+    img_t = eval_transforms(image)
+
+    # compute output
+    output = model(img_t.unsqueeze(0))
+    prediction = output.softmax(-1).flatten()
+    _, top5_idx = torch.topk(prediction, 5)
+    return {id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()}
+
+
+image = gr.inputs.Image(type='pil')  # PIL input to match the torchvision transforms
+label = gr.outputs.Label(num_top_classes=5)
+
+gr.Interface(
+    description="Bamboo for Zero-shot Image Recognition Demo (https://github.com/Davidzhangyuanhan/Bamboo)",
+    fn=recognize_image,
+    inputs=[image],
+    outputs=[
+        label,
+    ],
+    # examples=[
+    #     ["./elephants.png", "an elephant; an elephant walking in the river; four elephants walking in the river"],
+    #     ["./apple_with_ipod.jpg", "an ipod; an apple with a write note 'ipod'; an apple"],
+    #     ["./crowd2.jpg", "a street; a street with a woman walking in the middle; a street with a man walking in the middle"],
+    #     ["./zebras.png", "three zebras on the grass; two zebras on the grass; one zebra on the grass; no zebra on the grass; four zebras on the grass"],
+    # ],
+).launch()
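A quick sanity check on build_transforms: 224 * 8 // 7 = 256, so the pipeline is the standard ImageNet evaluation recipe (resize the short side to 256, center-crop to 224, normalize with ImageNet statistics). A minimal sketch verifying the output shape, assuming only torch, torchvision, and PIL are installed:

    import torch
    from PIL import Image
    from torchvision import transforms

    eval_t = transforms.Compose([
        transforms.Resize(224 * 8 // 7),   # 256: resize the short side
        transforms.CenterCrop(224),        # take the central 224x224 patch
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = Image.new('RGB', (640, 480))     # dummy image; any RGB PIL image works
    assert eval_t(img).shape == torch.Size([3, 224, 224])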
timmvit.py
ADDED
@@ -0,0 +1,83 @@
+# ------------------------------------------------------------------------
+# SenseTime VTAB
+# Copyright (c) 2021 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------
+# Modified from DETR (https://github.com/facebookresearch/detr)
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# ------------------------------------------------------------------------
+
+import timm
+import torch
+import copy
+import torch.nn as nn
+import torchvision
+import json
+from timm.models.hub import download_cached_file
+from PIL import Image
+
+
+class MyViT(nn.Module):
+    """Thin wrapper around timm's ViT-B/16 that loads a Bamboo checkpoint."""
+    def __init__(self, num_classes=115217, pretrain_path=None, enable_fc=False):
+        super().__init__()
+        print('initializing ViT model as backbone using ckpt:', pretrain_path)
+        self.model = timm.create_model('vit_base_patch16_224',
+                                       checkpoint_path=pretrain_path,
+                                       num_classes=num_classes)  # pretrained=True
+    # def forward_features(self, x):
+    #     x = self.model.patch_embed(x)
+    #     cls_token = self.model.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+    #     if self.model.dist_token is None:
+    #         x = torch.cat((cls_token, x), dim=1)
+    #     else:
+    #         x = torch.cat((cls_token, self.model.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+    #     x = self.model.pos_drop(x + self.model.pos_embed)
+    #     x = self.model.blocks(x)
+    #     x = self.model.norm(x)
+    #     return self.model.pre_logits(x[:, 0])
+
+    def forward(self, x):
+        return self.model.forward(x)
+
+
+def timmvit(**kwargs):
+    default_kwargs = {}
+    default_kwargs.update(**kwargs)
+    return MyViT(**default_kwargs)
+
+
+def build_transforms(input_size, center_crop=True):
+    transform = torchvision.transforms.Compose([
+        torchvision.transforms.Resize(input_size * 8 // 7),
+        torchvision.transforms.CenterCrop(input_size),
+        torchvision.transforms.ToTensor(),
+        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    ])
+    return transform
+
+def pil_loader(filepath):
+    with Image.open(filepath) as img:
+        img = img.convert('RGB')
+    return img
+
+def test_build():
+    with open('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/trainid2name.json') as f:
+        id2name = json.load(f)
+    img = pil_loader('/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/142520422_6ad756ddf6_w_d.jpg')
+    eval_transforms = build_transforms(224)
+    img_t = eval_transforms(img)
+    img_t = img_t[None, :]  # add batch dimension
+    model = MyViT(pretrain_path='/mnt/lustre/yhzhang/bamboo/Bamboo_ViT-B16_demo/Bamboo_v0-1_ViT-B16.pth.tar.convert')
+    # image = torch.rand(1, 3, 224, 224)
+    output = model(img_t)
+    prediction = output.softmax(-1).flatten()
+    _, top5_idx = torch.topk(prediction, 5)
+    print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})
+
+if __name__ == '__main__':
+    test_build()
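The test_build() paths point at a private /mnt/lustre cluster. A sketch of the same smoke test rewritten against the files this commit adds at the repo root; note the Bamboo_v0-1_ViT-B16.pth.tar.convert checkpoint is not part of the commit and is assumed to be present locally:

    import json
    import torch
    from timmvit import MyViT, build_transforms, pil_loader

    with open('./trainid2name.json') as f:              # added in this commit
        id2name = json.load(f)
    img = pil_loader('./142520422_6ad756ddf6_w_d.jpg')  # sample image added in this commit
    img_t = build_transforms(224)(img)[None, :]
    model = MyViT(pretrain_path='./Bamboo_v0-1_ViT-B16.pth.tar.convert')  # assumed local
    model.eval()
    with torch.no_grad():
        prediction = model(img_t).softmax(-1).flatten()
    _, top5_idx = torch.topk(prediction, 5)
    print({id2name[str(i)][0]: float(prediction[i]) for i in top5_idx.tolist()})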
trainid2name.json
ADDED
The diff for this file is too large to render.
See raw diff
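Although the diff cannot be rendered, the id2name[str(i)][0] lookups in app.py and timmvit.py imply its shape: a JSON object mapping class-id strings (presumably one per each of the model's 115217 classes) to a list whose first element is the display name. A hypothetical two-entry excerpt, with made-up names:

    id2name = {
        "0": ["giant panda"],   # names here are invented for illustration
        "1": ["red panda"],
    }
    print(id2name["0"][0])      # -> giant panda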