# Spaces: Running on T4
import argparse | |
import os, sys | |
import torch | |
import cv2 | |
from torchvision import transforms | |
from PIL import Image | |
import torch.nn.functional as F | |
import numpy as np | |
from matplotlib import pyplot as plt | |
from tqdm import tqdm | |
# Import files from the local folder | |
root_path = os.path.abspath('.') | |
sys.path.append(root_path) | |
from opt import opt | |
from dataset_curation_pipeline.IC9600.ICNet import ICNet | |
# Preprocessing pipeline applied to every PIL image before it is handed to the
# network: resize to a fixed 512x512 input, convert to a tensor, then normalise
# each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet statistics --
# confirm they match what the checkpoint was trained with.
inference_transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def blend(ori_img, ic_img, alpha=0.8, cm=plt.get_cmap("magma")):
    """Overlay a colour-mapped version of ``ic_img`` on top of ``ori_img``.

    Args:
        ori_img: Array accepted by ``Image.fromarray``; used as the base image.
        ic_img: Array passed through the colormap ``cm`` to build the heatmap.
        alpha: Interpolation factor forwarded to ``Image.blend``.
        cm: Colormap callable; defaults to matplotlib's ``"magma"``.

    Returns:
        The blended image as a NumPy array.
    """
    colored = cm(ic_img)
    # `-2::-1` starts at the second-to-last channel and walks backwards, i.e. it
    # drops the last channel and reverses the order of the remaining ones.
    # NOTE(review): presumably RGBA -> BGR for a later cv2.imwrite -- confirm.
    reordered = colored[:, :, -2::-1]
    overlay = Image.fromarray((reordered * 255).astype(np.uint8))
    base = Image.fromarray(ori_img)
    return np.array(Image.blend(base, overlay, alpha=alpha))
def infer_one_image(model, img_path):
    """Return the score ``model`` predicts for the image at ``img_path``.

    The image is loaded as RGB, run through the module-level
    ``inference_transform`` and evaluated as a batch of one with gradients
    disabled.

    Args:
        model: Network that returns an ``(ic_score, ic_map)`` pair when called
            on the batched input tensor.
        img_path: Path of the image file to score.

    Returns:
        The predicted score as a plain Python number (``ic_score.item()``).
    """
    # Open through a context manager so the file handle is released as soon as
    # the pixel data has been copied out by convert().
    with Image.open(img_path) as raw_img:
        ori_img = raw_img.convert("RGB")

    # Put the input on the device the model actually lives on. A hard-coded
    # .cuda() always targets the current CUDA device, which breaks when the
    # model was moved elsewhere (another GPU index, or the CPU).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # unsqueeze(0) adds the leading batch dimension.
    batch = inference_transform(ori_img).unsqueeze(0).to(device)

    with torch.no_grad():
        # The second output (the map) is intentionally unused here; only the
        # scalar score is returned.
        ic_score, _ic_map = model(batch)

    return ic_score.item()
def infer_directory(img_dir):
    """Score every file in ``img_dir`` and print the 50 highest-scoring ones.

    Each file's path and score are printed as soon as it is processed; after
    the whole directory is done, the top 50 ``(score, path)`` tuples are
    printed in descending score order.

    Relies on the module-level ``model`` that the ``__main__`` section of this
    script creates before calling this function.

    Args:
        img_dir: Directory whose files are scored.
    """
    img_paths = [os.path.join(img_dir, name) for name in sorted(os.listdir(img_dir))]
    # Sub-directories cannot be opened as images; skip them instead of
    # crashing halfway through the run.
    img_paths = [path for path in img_paths if os.path.isfile(path)]

    scores = []
    for img_path in tqdm(img_paths):
        # infer_one_image takes (model, img_path); the model argument was
        # previously missing here, which raised TypeError on the first image.
        score = infer_one_image(model, img_path)
        scores.append((score, img_path))
        print(img_path, score)

    # Ascending sort followed by a reversal gives the highest scores first.
    ranked = sorted(scores, key=lambda item: item[0])[::-1]
    for entry in ranked[:50]:
        print(entry)
if __name__ == "__main__":
    # Command-line entry point: score a single image or a whole directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='./example',
                        help="image file or directory of images to score")
    parser.add_argument('-o', '--output', type=str, default='./out',
                        help="output directory (not used by the active code paths)")
    parser.add_argument('-d', '--device', type=int, default=0,
                        help="device index passed to torch.device()")
    args = parser.parse_args()

    # Load the weights on the CPU first, then move the model to the requested
    # device.
    # NOTE(review): torch.load unpickles the checkpoint -- only load files from
    # a trusted source (consider weights_only=True if the installed torch
    # version supports it).
    model = ICNet()
    model.load_state_dict(torch.load('./checkpoint/ck.pth', map_location=torch.device('cpu')))
    model.eval()
    device = torch.device(args.device)
    model.to(device)

    # `inference_transform` is already defined at module level with exactly
    # these steps, so it is not rebuilt here.

    if os.path.isfile(args.input):
        # infer_one_image takes (model, img_path); the model was previously
        # omitted, and the resulting score was silently discarded.
        score = infer_one_image(model, args.input)
        print(args.input, score)
    else:
        infer_directory(args.input)