import cv2
import gradio as gr
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn  # needed for the nn.BatchNorm2d check in build_model
from torchvision import transforms
import torch.nn.functional as F
import gdown
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
os.system("git clone https://github.com/xuebinqin/DIS")
os.system("mv DIS/IS-Net/* .")
# project imports
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *
# Helpers
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the official isnet.pth weights into place
if not os.path.exists("saved_models"):
    os.mkdir("saved_models")
    os.system("mv isnet.pth saved_models/")
class GOSNormalize(object):
    '''
    Normalize the image with the given mean/std, torchvision-style.
    '''
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        self.mean = mean
        self.std = std

    def __call__(self, image):
        image = normalize(image, self.mean, self.std)
        return image
transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
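# With mean 0.5 and std 1.0, inputs in [0, 1] are shifted to [-0.5, 0.5]
# (assuming normalize applies (x - mean) / std per channel, as in torchvision).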
def load_image(im_path, hypar):
    im = im_reader(im_path)
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    return transform(im).unsqueeze(0), shape.unsqueeze(0)  # make a batch of image, shape
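# Illustrative usage (hypothetical path): load_image is a file-based alternative
# to the array-based predict() path used below.
# image_batch, orig_shape = load_image("example.jpg", hypar)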
def build_model(hypar, device):
    net = hypar["model"]  # the network object is passed in via hypar, e.g. ISNetDIS()

    # convert to half precision if requested, keeping BatchNorm2d in float32
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"] != "":
        net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"],
                                       map_location=device))
        net.to(device)
    net.eval()
    return net
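# Note: in half-precision mode the BatchNorm2d layers are deliberately kept in
# float32, since their running statistics are numerically fragile in float16;
# this is standard mixed-precision practice.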
def resize_image(image, size=1024):
    height, width = image.shape[:2]
    # Check if either dimension is greater than `size`
    if height > size or width > size:
        # Scale so the longer side becomes `size`, preserving aspect ratio
        if height > width:
            scale_factor = size / height
        else:
            scale_factor = size / width
        # Resize the image
        new_dimensions = (int(width * scale_factor), int(height * scale_factor))
        resized_image = cv2.resize(image, new_dimensions, interpolation=cv2.INTER_AREA)
        print(f"Image resized to {new_dimensions}")
        return resized_image
    else:
        print("Image is already within the desired size.")
        return image
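# Sanity check (illustrative): for a 2048x1536 input, height > width, so
# scale_factor = 1024 / 2048 = 0.5 and new_dimensions == (768, 1024)
# (OpenCV expects (width, height) order).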
def predict(net, im):
    im = resize_image(im)
    if len(im.shape) < 3:
        im = np.stack([im] * 3, axis=-1)  # convert grayscale to 3-channel before padding
    # Pad the image onto a fixed 1024x1024 canvas
    temp = np.ones((1024, 1024, 3))
    h, w = im.shape[0], im.shape[1]
    temp[:h, :w] = im
    im = temp

    input_size = [1024, 1024]
    im_shp = im.shape[0:2]
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = F.interpolate(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    image = image.to(device)  # move the input to the same device as the model

    result = net(image)
    result = torch.squeeze(F.interpolate(result[0][0], im_shp, mode='bilinear'), 0)
    # Min-max normalize the predicted saliency map to [0, 1]
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / (ma - mi)
    result = result.unsqueeze(0) if result.dim() == 2 else result  # ensure a channel dimension
    result = result.repeat(3, 1, 1) if result.shape[0] == 1 else result  # ensure 3 channels
    result = 1 - result  # invert the mask here

    # Resize the image to match result dimensions (computed but not used below)
    image_resized = F.interpolate(image, size=result.shape[1:], mode='bilinear')
    # Ensure both tensors are 3D
    image_resized = image_resized.squeeze(0) if image_resized.dim() == 4 else image_resized
    result = result.squeeze(0) if result.dim() == 4 else result

    # Apply a threshold so the mask contains only pure black or white pixels
    threshold = 0.50  # adjust as needed
    result[result < threshold] = 0
    result[result >= threshold] = 1

    # Euclidean distance of each input pixel from pure white
    distance = np.sqrt(np.sum((im - [255, 255, 255]) ** 2, axis=-1))
    # Create a mask where the distance to white is below the threshold
    mask = distance < 200
    # Convert mask to uint8 and replicate to 3 channels
    mask = mask.astype(np.uint8) * 255
    mask = np.stack([mask] * 3, axis=-1)

    result = (result.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)
    white = np.ones_like(im) * 255
    cropped = np.where(result == 0, white, mask)
    return cropped[:h, :w]  # crop the padding away
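# Standalone sketch (hypothetical file names), mirroring what inference() does below:
# img = cv2.imread("input.jpg")
# with torch.no_grad():
#     out = predict(net, img)
# cv2.imwrite("output.png", out)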
# Set Parameters
hypar = {} # paramters for inferencing
hypar["model_path"] ="./saved_models" ## load trained weights from this path
hypar["restore_model"] = "isnet.pth" ## name of the to-be-loaded weights
hypar["interm_sup"] = False ## indicate if activate intermediate feature supervision
## choose floating point accuracy --
hypar["model_digit"] = "full" ## indicates "half" or "full" accuracy of float number
hypar["seed"] = 0
hypar["cache_size"] = [1024, 1024] ## cached input spatial resolution, can be configured into different size
## data augmentation parameters ---
hypar["input_size"] = [1024, 1024] ## mdoel input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
hypar["crop_size"] = [1024, 1024] ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation
hypar["model"] = ISNetDIS()
# Build Model
net = build_model(hypar, device)

def inference(image):
    image_path = image
    image_tensor = cv2.imread(image_path)  # note: OpenCV loads images in BGR channel order
    with torch.no_grad():
        mask = predict(net, image_tensor)
    return [mask, mask]
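# Example call (hypothetical path); the mask is returned twice, presumably to
# fill two Gradio output components:
# mask_a, mask_b = inference("examples/photo.jpg")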
title = "Highly Accurate Dichotomous Image Segmentation"
description = "This is an unofficial demo for DIS, a model that can remove the background from a given image. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.
GitHub: https://github.com/xuebinqin/DIS
Telegram bot: https://t.me/restoration_photo_bot
[![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)"
article = "