import torch
from PIL import Image
from torchvision import transforms
from clipseg import CLIPDensePredT

# Preprocessing: convert to a tensor, normalize with ImageNet statistics,
# and resize to the 352x352 resolution the model expects.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((352, 352)),
])

# Load the CLIPSeg model (ViT-B/16 backbone, 64-dim reduction). strict=False is
# needed because the checkpoint stores only the decoder weights, not the CLIP backbone.
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
model.eval()
model.load_state_dict(
    torch.load('weights/rd64-uni.pth', map_location=torch.device('cpu')),
    strict=False,
)


def predict(image, prompts):
    """
    Predict segmentation masks for the given image based on the provided prompts.

    Parameters:
    - image (PIL.Image): The input image.
    - prompts (str): A comma-separated string of prompts.

    Returns:
    - tuple: The resized input image and a list of (mask, prompt) pairs,
      one per prompt.
    """
    img = transform(image).unsqueeze(0)

    # Split the prompts string into a list of individual prompts
    prompts = [p.strip() for p in prompts.split(',')]
    num_prompts = len(prompts)

    # Ensure no gradient computation during prediction for performance
    with torch.no_grad():
        # Get model predictions for each prompt (the image is repeated once per prompt)
        preds = model(img.repeat(num_prompts, 1, 1, 1), prompts)[0]

    # Convert model logits to per-prompt masks with values in [0, 1]
    masks = [torch.sigmoid(preds[i][0]) for i in range(num_prompts)]
    masks = [(m.numpy(), prompts[i]) for i, m in enumerate(masks)]

    # Return the resized input image and the list of segmentation masks
    return (image.resize((352, 352), Image.LANCZOS), masks)


def get_examples():
    examples = [
        ['images/000010.jpg', 'deer, tree, grass'],
        ['images/000002.jpg', 'train, tracks, electric pole, house'],
        ['images/00125.jpg', 'dog, flowers'],
        ['images/000010.jpg', 'horse, man, fence, buildings, hill'],
        ['images/000004.jpg', 'car, truck, building, sky, traffic light, tree, clouds'],
    ]
    return examples


def get_html():
    html_string = """
    Upload an image and provide multiple text prompts separated by commas.
    Get segmented masks for each prompt.
    """
    return html_string
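
# --- Usage sketch (assumption, not part of the original file) ---
# The get_examples()/get_html() helpers suggest this script backs a Gradio demo.
# Below is a minimal sketch of that wiring, assuming the gradio package is
# available; the gradio_predict() wrapper, component choices, and labels are
# illustrative, not the original app's exact layout.
import gradio as gr
import numpy as np


def gradio_predict(image, prompts):
    # Guard against RGBA/grayscale uploads, which would break the 3-channel normalization.
    image = image.convert('RGB')
    resized, masks = predict(image, prompts)
    # Convert each soft mask (floats in [0, 1]) to an 8-bit grayscale image
    # so the Gallery component can display it next to its prompt.
    gallery = [(np.uint8(mask * 255), prompt) for mask, prompt in masks]
    return resized, gallery


demo = gr.Interface(
    fn=gradio_predict,
    inputs=[gr.Image(type='pil'), gr.Textbox(label='Prompts (comma-separated)')],
    outputs=[gr.Image(label='Input (resized)'), gr.Gallery(label='Segmentation masks')],
    examples=get_examples(),
    description=get_html(),
)

if __name__ == '__main__':
    demo.launch()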