# crowdcontrolhf / app.py
# Last change: "Update app.py" by promptsai (commit bf6b7e6, verified)
import base64
import os

import cv2
import gdown
import gradio as gr
import numpy as np
import scipy
import torch
from PIL import Image
from torchvision import transforms

from models import vgg19
# Local cache path for the pretrained weights and the Google Drive URL they
# are fetched from (the filename suggests weights trained on UCF-QNRF —
# unverified).
model_path = "pretrained_models/model_qnrf.pth"
url = "https://drive.google.com/uc?id=1nnIHPaV9RGqK8JHL645zmRvkNrahD9ru"

# Download only when the checkpoint is not already on disk, so restarts do not
# re-fetch it; create the target directory first so the write cannot fail on
# a fresh checkout.
if not os.path.isfile(model_path):
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    gdown.download(url, model_path, quiet=False)

# Inference runs on the CPU. (Valid torch device strings include "cpu" and
# "cuda"; "gpu" is not one of them.)
device = torch.device('cpu')
model = vgg19()
model.to(device)
# map_location remaps every tensor in the checkpoint onto `device`, so weights
# saved from a GPU session still load on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location=device))
# Evaluation mode: disables training-only behaviour such as dropout.
model.eval()
def image_to_base64(image_path):
    """Return the raw bytes of the file at *image_path* as a Base64 text string.

    The file is read in binary mode, so any file type works; the encoded
    result is decoded to ``str`` using UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def _colorize_density(density):
    """Render a 2-D density array as an RGB uint8 heatmap (JET colormap).

    Values are min-max scaled into [0, 255]; the 1e-5 term keeps the division
    finite when the map is constant (e.g. all zeros).
    """
    # normalize density map values from 0 to 1, then map it to 0-255.
    scaled = (density - density.min()) / (density.max() - density.min() + 1e-5)
    heat = cv2.applyColorMap((scaled * 255).astype(np.uint8), cv2.COLORMAP_JET)
    # OpenCV colormaps are produced in BGR order; convert for RGB display.
    return cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)


def predict(inp):
    """Estimate the crowd count for one image and visualise its density map.

    Args:
        inp: the uploaded image as an array (presumably a NumPy array, the
            ``gr.Image`` default). Grayscale, RGB and RGBA arrays are all
            accepted and converted to 3-channel RGB before inference.

    Returns:
        ``(heatmap, count)``: an RGB uint8 rendering of the predicted density
        map, and the summed density converted with ``int()`` (truncated).

    Raises:
        gr.Error: when no image was supplied.
    """
    if inp is None:
        # Without this guard, submitting an empty form fails with an opaque
        # AttributeError on ``None.astype``.
        raise gr.Error("Please upload an image of a crowd.")

    # Let PIL infer the mode from the array shape, then force 3 channels.
    # Passing mode='RGB' explicitly breaks on 2-D (grayscale) arrays and
    # mis-reads 4-channel (RGBA) buffers.
    image = Image.fromarray(np.asarray(inp).astype('uint8')).convert('RGB')
    # NOTE(review): only ToTensor() scaling to [0, 1] is applied; confirm the
    # weights do not expect an additional mean/std normalization.
    batch = transforms.ToTensor()(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model returns a pair; only the first element is used here.
        outputs, _ = model(batch)

    count = torch.sum(outputs).item()
    # assumes outputs is (batch, channel, H, W) — take the first map; TODO confirm
    density = outputs[0, 0].cpu().numpy()
    return _colorize_density(density), int(count)
# Input widget: the crowd photo the user uploads.
inputs = gr.Image(label="Image of Crowd")

# Output widgets, in the same order as predict()'s return tuple:
# first the colourised density map, then the estimated count.
_density_view = gr.Image(label="Predicted Density Map")
_count_view = gr.Label(label="Predicted Count")
outputs = [_density_view, _count_view]
# Resolve the logo next to this script instead of against the current working
# directory, so the app still finds it when launched from another directory.
_logo_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "corporate.png")
# Convert your image to Base64 so it can be embedded inline as a data URI.
logo_base64 = image_to_base64(_logo_path)
logo_src = f"data:image/png;base64,{logo_base64}"
# HTML/CSS shown as the interface description. Braces are doubled because this
# is an f-string; only {logo_src} is interpolated.
# NOTE(review): the CSS comments inside this string do not match their rules.
#   - .logo-img sets no `position`, so bottom/left have no effect, and nothing
#     here is absolutely positioned.
#   - Its width is 150px, although the comment says "auto".
#   - That CSS width overrides the tag's width="300px" attribute.
#   - The 120px text padding is smaller than the 150px logo.
# The string is left unchanged because it is shipped to the browser as-is.
# NOTE(review): the tagline mentions "neural language models", but the model is
#   built by vgg19() and fed image tensors in predict(). Confirm the wording.
desc = f"""
<style>
/* Add padding at the bottom of the interface to prevent overlap with the absolutely positioned image */
body, html, .interface {{
margin: 0;
padding-bottom: 120px; /* Adjust this value based on the height of your image */
}}
/* Position your image at the bottom left of the interface */
.logo-img {{
width: 150px; /* Set width to auto to keep the original image size */
bottom: 10px;
left: 10px;
}}
/* Style for your text to make sure it does not overlap the logo */
.with-margin {{
padding-left: 120px; /* This padding should be more than the width of your logo to prevent text overlap */
}}
</style>
<div class="description">
<img src="{logo_src}" alt='Logo' class="logo-img" width="300px"/>
<h4 class="with-margin">AI-Powered Audience Insights</h4>
<p class="with-margin">Seamlessly count and analyze your conference attendees with cutting-edge neural language models</p>
</div>
"""
# Wire predict() into a single-page web UI and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=" ",  # a single space — presumably to keep the visible title empty
    description=desc,
    allow_flagging="never",
    # Hide any <footer> element on the page (e.g. Gradio's own page footer).
    css="footer{display:none !important}",
)
demo.launch()