"""Gradio demo for DIS (Dichotomous Image Segmentation) using the IS-Net model.

On start-up the script clones the upstream DIS repository, moves the IS-Net
sources into the working directory so they can be imported, builds the
network from ``saved_models/isnet.pth`` and serves a Gradio interface.
"""
import os
import warnings

import cv2
import gdown
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

warnings.filterwarnings("ignore")

# The project imports below live in the upstream repository, so it has to be
# fetched and unpacked into the working directory *before* they are executed.
# The shell is needed for the glob in the `mv` command.
os.system("git clone https://github.com/xuebinqin/DIS")
os.system("mv DIS/IS-Net/* .")

# project imports
from data_loader_cache import normalize, im_reader, im_preprocess
from models import *

# Helpers
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Side length (in pixels) of the square input fed to the network.
MODEL_INPUT_SIZE = 1024

# Move the official weights into place the first time the app starts.
if not os.path.exists("saved_models"):
    os.mkdir("saved_models")
    os.system("mv isnet.pth saved_models/")


class GOSNormalize(object):
    """Callable transform that normalizes an image with a fixed mean and std.

    The work is delegated to ``normalize`` imported from ``data_loader_cache``.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Store private copies so a caller's sequence is never shared.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        return normalize(image, self.mean, self.std)


transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])


def load_image(im_path, hypar):
    """Read an image from ``im_path`` and return ``(image_batch, shape_batch)``.

    The image is read with ``im_reader``, preprocessed with ``im_preprocess``
    using ``hypar["cache_size"]``, scaled by 1/255 and passed through
    ``transform``. Both return values get a leading batch dimension of 1.
    """
    im = im_reader(im_path)
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    # make a batch of image, shape
    return transform(im).unsqueeze(0), shape.unsqueeze(0)


def build_model(hypar, device):
    """Prepare ``hypar["model"]`` for inference on ``device``.

    Optionally converts the network to half precision (keeping batch-norm
    layers in float), loads the weights named by ``hypar["model_path"]`` and
    ``hypar["restore_model"]`` when the latter is non-empty, and switches the
    network to eval mode.

    Returns:
        The same network object, moved to ``device``.
    """
    net = hypar["model"]  # GOSNETINC(3,1)

    # convert to half precision
    if hypar["model_digit"] == "half":
        net.half()
        for layer in net.modules():
            # Referenced through ``torch.nn`` so this does not depend on
            # ``nn`` leaking out of the star import above.
            if isinstance(layer, torch.nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"] != "":
        weights_path = os.path.join(hypar["model_path"], hypar["restore_model"])
        net.load_state_dict(torch.load(weights_path, map_location=device))

    net.eval()
    return net


def resize_image(image, size=1024):
    """Shrink ``image`` so that its longest side is at most ``size`` pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (same object).

    Args:
        image: array whose first two dimensions are (height, width).
        size: maximum allowed length of the longest side.
    """
    height, width = image.shape[:2]
    longest_side = max(height, width)

    if longest_side <= size:
        print("Image is already within the desired size.")
        return image

    scale_factor = size / longest_side
    # cv2.resize takes (width, height); never let a side collapse to 0 px
    # for extremely elongated inputs.
    new_dimensions = (
        max(1, int(width * scale_factor)),
        max(1, int(height * scale_factor)),
    )
    resized_image = cv2.resize(image, new_dimensions, interpolation=cv2.INTER_AREA)
    print(f"Image resized to {new_dimensions}")
    return resized_image


def predict(net, im):
    """Run ``net`` on ``im`` and build a black/white output image.

    The image is shrunk to fit the model input, padded onto a square canvas
    and fed to the network. The prediction is min-max normalized, inverted
    and thresholded at 0.5. In the output, pixels where the thresholded
    inverted prediction is 0 are white; every other pixel is white only if
    the input pixel is within a distance of 200 from pure white, else black.

    Args:
        net: the segmentation network (already in eval mode).
        im: image array of shape (H, W) or (H, W, C) with values in 0..255.

    Returns:
        A uint8 array of shape (h, w, 3) with values 0 or 255, where (h, w)
        is the size of the image after ``resize_image``.
    """
    im = resize_image(im, MODEL_INPUT_SIZE)

    # Grayscale -> 3 channels. This must happen before padding, otherwise the
    # assignment into the 3-channel canvas below cannot broadcast.
    if im.ndim < 3:
        im = np.stack([im] * 3, axis=-1)
    h, w = im.shape[0], im.shape[1]

    # Pad onto a fixed-size canvas; only the top-left (h, w) region holds the
    # image and only that region is returned at the end.
    canvas = np.ones((MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 3))
    canvas[:h, :w] = im[..., :3]  # drop a possible alpha channel

    input_size = [MODEL_INPUT_SIZE, MODEL_INPUT_SIZE]
    im_shp = canvas.shape[0:2]

    im_tensor = torch.tensor(canvas, dtype=torch.float32).permute(2, 0, 1)
    im_tensor = F.interpolate(
        torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear"
    ).type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    # The input has to live on the same device (and use the same float type)
    # as the model weights, otherwise the forward pass fails on GPU.
    first_param = next(net.parameters(), None)
    if first_param is not None:
        image = image.to(device=first_param.device, dtype=first_param.dtype)

    outputs = net(image)
    result = F.interpolate(outputs[0][0], im_shp, mode="bilinear")
    result = torch.squeeze(result, 0).float()

    # Min-max normalize; the clamp avoids 0/0 when the prediction is constant.
    ma = torch.max(result)
    mi = torch.min(result)
    result = (result - mi) / torch.clamp(ma - mi, min=1e-8)

    if result.dim() == 2:
        result = result.unsqueeze(0)
    # Ensure result has 3 channels
    if result.shape[0] == 1:
        result = result.repeat(3, 1, 1)
    result = 1 - result  # Invert the mask here

    # Binarize so that only pure 0 / 1 values remain.
    threshold = 0.50  # Adjust as needed
    result = (result >= threshold).to(result.dtype)

    # Pixels whose colour lies within a distance of 200 from pure white.
    distance = np.sqrt(np.sum((canvas - [255, 255, 255]) ** 2, axis=-1))
    near_white = (distance < 200).astype(np.uint8) * 255
    near_white = np.stack([near_white] * 3, axis=-1)

    result = (result.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)

    # Keep everything uint8: the values are exactly 0 or 255, and a float
    # array in the 0..255 range is not a valid image for the UI layer.
    white = np.full_like(near_white, 255)
    cropped = np.where(result == 0, white, near_white)
    return cropped[:h, :w]


# Set Parameters
hypar = {}  # parameters for inferencing

hypar["model_path"] = "./saved_models"  ## load trained weights from this path
hypar["restore_model"] = "isnet.pth"  ## name of the to-be-loaded weights
hypar["interm_sup"] = False  ## indicate if activate intermediate feature supervision

## choose floating point accuracy --
hypar["model_digit"] = "full"  ## indicates "half" or "full" accuracy of float number
hypar["seed"] = 0

hypar["cache_size"] = [1024, 1024]  ## cached input spatial resolution, can be configured into different size

## data augmentation parameters ---
hypar["input_size"] = [1024, 1024]  ## model input spatial size, usually use the same value hypar["cache_size"], which means we don't further resize the images
hypar["crop_size"] = [1024, 1024]  ## random crop size from the input, it is usually set as smaller than hypar["cache_size"], e.g., [920,920] for data augmentation

hypar["model"] = ISNetDIS()

# Build Model
net = build_model(hypar, device)


def inference(image):
    """Gradio callback: run the model on the image stored at path ``image``.

    Returns:
        A two-element list containing the same output image twice (the
        gallery output shows both entries).

    Raises:
        gr.Error: if no image was provided or the file cannot be decoded.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    image_bgr = cv2.imread(image)
    if image_bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise gr.Error(f"Could not read image file: {image}")

    # OpenCV decodes to BGR.
    # NOTE(review): converted to RGB on the assumption that the model was
    # trained on RGB input (as read by the project's im_reader) - confirm.
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    with torch.no_grad():
        mask = predict(net, image_rgb)

    return [mask, mask]


title = "Highly Accurate Dichotomous Image Segmentation"
description = (
    "This is an unofficial demo for DIS, a model that can remove the background from a given image. "
    "To use it, simply upload your image, or click one of the examples to load them. "
    "Read more at the links below.\n"
    "GitHub: https://github.com/xuebinqin/DIS\n"
    "Telegram bot: https://t.me/restoration_photo_bot\n"
    "[![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)"
)
article = "\nvisitor badge\n"

interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath'),
    outputs=gr.Gallery(format="png"),
    examples=[['test1.jpg'], ['test2.jpg']],
    title=title,
    description=description,
    article=article,
    flagging_mode="never",
    cache_mode="lazy",
)
interface.queue().launch(show_api=True, show_error=True)