|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import numpy as np |
|
import torch |
|
import warnings |
|
import random |
|
import matplotlib.pyplot as plt |
|
import gradio as gr |
|
import torchvision.transforms as standard_transforms |
|
from torch.utils.data import DataLoader |
|
from torch.utils.data import Dataset |
|
from model import SASNet |
|
|
|
# Install a blanket "ignore" filter: with no category or module given, this
# suppresses every warning raised anywhere in the process.
# NOTE(review): presumably meant to keep the app log quiet; it also hides
# deprecation warnings from torch/matplotlib/gradio — consider narrowing it.
warnings.filterwarnings('ignore')
|
|
|
|
|
|
|
|
|
class data(Dataset):
    """Map-style dataset that wraps exactly one image.

    The wrapped object only has to provide ``convert(mode)``; in this app it
    is the image handed over by the UI. An optional ``transform`` is applied
    to the RGB-converted image before it is turned into a tensor.
    """

    def __init__(self, img, transform=None):
        """Remember the image and the optional preprocessing callable.

        Args:
            img: Image object exposing ``convert('RGB')``.
            transform: Optional callable applied to the RGB image. When it is
                ``None`` the converted image is passed to ``torch.Tensor``
                as-is.
        """
        self.image = img
        self.transform = transform

    def __len__(self):
        """Return 1, because the dataset holds a single image."""
        # Reporting more samples than exist would make a full pass over a
        # DataLoader run inference on the very same image again and again.
        return 1

    def __getitem__(self, x):
        """Return the wrapped image as a tensor.

        Args:
            x: Sample index; only ``0`` (or ``-1``) addresses the one sample.

        Returns:
            The RGB-converted, optionally transformed image wrapped by
            ``torch.Tensor``.

        Raises:
            IndexError: If ``x`` is any other index. This also lets plain
                ``for item in dataset`` iteration terminate.
        """
        if x not in (0, -1):
            raise IndexError(f"index {x!r} is out of range for a single-image dataset")

        image = self.image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        return torch.Tensor(image)
|
|
|
def loading_data(img):
    """Wrap ``img`` in a DataLoader that yields it as a normalised tensor batch.

    Args:
        img: Image forwarded to the ``data`` dataset (the UI supplies a PIL
            image — TODO confirm for other callers).

    Returns:
        A ``DataLoader`` over a ``data`` dataset with batch size 1, no worker
        processes, no shuffling and no dropped remainder.
    """
    # Per-channel statistics used for normalisation; the values match the
    # widely used ImageNet mean/std.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    to_tensor = standard_transforms.ToTensor()
    normalize = standard_transforms.Normalize(mean=channel_mean, std=channel_std)
    preprocessing = standard_transforms.Compose([to_tensor, normalize])

    return DataLoader(
        data(img=img, transform=preprocessing),
        batch_size=1,
        num_workers=0,
        shuffle=False,
        drop_last=False,
    )
|
|
|
|
|
# Loaded models keyed by checkpoint path, so the weights are read from disk
# once per process instead of once per request.
_MODEL_CACHE = {}


def _load_model(model_path):
    """Return the CPU SASNet for ``model_path`` in eval mode, loading it at most once."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        model = SASNet().cpu()
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        model.eval()
        print('successfully load model from', model_path)
        _MODEL_CACHE[model_path] = model
    return model


def _render_density_map(den_map):
    """Draw ``den_map`` on a frameless, axis-free figure that fills the canvas."""
    fig = plt.figure(frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.imshow(den_map, aspect='auto')
    # pyplot keeps every figure it creates alive until it is closed; closing
    # drops that global reference while the Figure object stays renderable.
    plt.close(fig)
    return fig


def predict(img):
    """Run SASNet on one image and return the crowd count and density plot.

    Args:
        img: Image selected in the UI, or ``None`` when nothing was uploaded.

    Returns:
        A ``(count, figure)`` tuple: the rounded integer head count and a
        matplotlib figure showing the predicted density map. When ``img`` is
        ``None`` the tuple is ``("No image selected", <empty figure>)``.
    """
    if img is None:
        empty_fig = plt.figure()
        plt.close(empty_fig)  # avoid piling up placeholder figures in pyplot
        return "No image selected", empty_fig

    test_loader = loading_data(img)
    model = _load_model("./SHHA.pth")

    with torch.no_grad():
        # Only the first batch is needed: the loader serves a single image.
        batch = next(iter(test_loader))
        pred_map = model(batch.cpu()).detach().cpu().numpy()

    first_map = pred_map[0]
    # The summed density is divided by 1000 to obtain the count; presumably
    # the scaling factor applied to density maps at training time — TODO confirm.
    pred_cnt = np.sum(first_map) / 1000

    fig = _render_density_map(np.squeeze(first_map))
    return int(np.round(pred_cnt, 0)), fig
|
|
|
|
|
# Declarative Gradio page. Components are created inside nested `with`
# containers (Blocks > Row > Column), and the creation order is the on-page
# order, so statements here must not be reordered.
# NOTE(review): the nesting below is reconstructed from statement order because
# the source indentation was lost — confirm against the deployed layout.
with gr.Blocks() as demo:
    # Page header: title, short description and the paper abstract.
    # NOTE(review): the text contains small typos ("et. al", "some of the
    # example" further down); left untouched because it is runtime UI text.
    gr.Markdown(
        """
        # Crowd Counting based on SASNet
        <p>
        This space implements crowd counting following the paper of Song et. al (2021). The model is a VGG16 base with MultiBranch-Channels. For more details see the official publication on AAAI.
        Training data is the Shanghai-Tech A/B data set with Gaussian augmentation for density map creation. The data set annotates more than 300k people.
        </p>

        ## Abstract
        <p>
        In this paper, we address the large scale variation problem in crowd counting by taking full advantage of the multi-scale feature representations in a multi-level network. We
        implement such an idea by keeping the counting error of a patch as small as possible with a proper feature level selection strategy, since a specific feature level tends to perform
        better for a certain range of scales. However, without scale annotations, it is sub-optimal and error-prone to manually assign the predictions for heads of different scales to
        specific feature levels. Therefore, we propose a Scale-Adaptive Selection Network (SASNet), which automatically learns the internal correspondence between the scales and the feature
        levels. Instead of directly using the predictions from the most appropriate feature level as the final estimation, our SASNet also considers the predictions from other feature
        levels via weighted average, which helps to mitigate the gap between discrete feature levels and continuous scale variation. Since the heads in a local patch share roughly a same
        scale, we conduct the adaptive selection strategy in a patch-wise style. However, pixels within a patch contribute different counting errors due to the various difficulty degrees of
        learning. Thus, we further propose a Pyramid Region Awareness Loss (PRA Loss) to recursively select the most hard sub-regions within a patch until reaching the pixel level. With
        awareness of whether the parent patch is over-estimated or under-estimated, the fine-grained optimization with the PRA Loss for these region-aware hard pixels helps to alleviate the
        inconsistency problem between training target and evaluation metric. The state-of-the-art results on four datasets demonstrate the superiority of our approach.
        </p>

        ## Demo
        """
    )
    # Row 1: usage hint on the left, label that receives predict()'s first
    # return value (the count or the "no image" message) on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                Upload an image or use some of the example to let the model count your crowd. The estimated density map is plotted as well. Have fun!
                Visit my [**github**](https://github.com/MalteLeuschner/CrowdCounting_SASNet) for more!
                """
            )
        with gr.Column():
            text_output = gr.Label()
    # Row 2: image upload (handed to the callback as a PIL image) next to the
    # plot that receives predict()'s second return value (the figure).
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
        with gr.Column():
            image_output = gr.Plot()
    # Row 3: the trigger button in the first of three columns; the two empty
    # Markdown columns presumably act as spacers to keep the button narrow.
    with gr.Row():
        with gr.Column():
            image_button = gr.Button("Count the Crowd!", variant = "primary")
        with gr.Column():
            gr.Markdown("")
        with gr.Column():
            gr.Markdown("")

    # Clickable sample images that populate the upload component; the files
    # are referenced by relative path.
    gr.Examples(["IMG_1.jpg", "IMG_2.jpg", "IMG_3.jpg"], image_input)

    # Footer with credits and the paper citation.
    gr.Markdown(
        """
        ## References
        The code will be available at: https://github.com/TencentYoutuResearch/CrowdCounting-SASNet.
        App Created by: [MalteLeuschner - leuschnm](https://github.com/MalteLeuschner/CrowdCounting_SASNet)
        Song, Q., Wang, C., Wang, Y., Tai, Y., Wang, C., Li, J., … Ma, J. (2021). To Choose or to Fuse? Scale Selection for Crowd Counting. The Thirty-Fifth AAAI Conference on Artificial Intelligence (AAAI-21).
        """)

    # Wire the button: predict(image) -> (label value, plot figure).
    image_button.click(predict, inputs=image_input, outputs=[text_output, image_output])

# Start the web server as soon as this module is executed.
demo.launch()
|
|
|
|