Sa2VA-simple-demo / projects /glamm /datasets /region_level_dataset.py
fffiloni's picture
Migrated from GitHub
d59f323 verified
raw
history blame
11.3 kB
import copy
import random
import glob
import json
import logging
import os
import torch
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils
from xtuner.registry import BUILDER
from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn
from projects.glamm.datasets.utils.utils import expand2square
from projects.glamm.datasets.utils.utils import ANSWER_LIST, REGION_QUESTIONS
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
class RegionDataset(Dataset):
def __init__(self,
image_folder,
image_processor,
data_path=None,
tokenizer=None,
template_map_fn=None,
max_length=2048,
pad_image_to_square=False,
repeats=1,
num_classes_per_sample=3,
extra_image_processor=None):
super().__init__()
self.begin_str = f"""{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
self.question_templates = REGION_QUESTIONS
if extra_image_processor is not None:
self.extra_image_processor = BUILDER.build(extra_image_processor)
self.num_classes_per_sample = num_classes_per_sample
self.tokenizer = BUILDER.build(tokenizer)
self.tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
)
reg_tokens = ['<bbox>', '<point>']
segmentation_tokens = ['[SEG]']
phrase_tokens = ['<p>', '</p>']
special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
self.max_length = max_length
self.template_map_fn = BUILDER.build(template_map_fn)
self.text_data = self._load_annotations(data_path, image_folder)
self.image_folder = image_folder
self.image_processor = BUILDER.build(image_processor)
size = self.image_processor.crop_size
if isinstance(size, dict):
self.image_w, self.image_h = size['width'], size['height']
elif isinstance(size, int):
self.image_h, self.image_w = size, size
else:
self.image_w, self.image_h = size
self.pad_image_to_square = pad_image_to_square
self.repeats = repeats
def _load_annotations(self, data_path, image_folder=None):
self.coco = COCO(data_path)
img_ids = self.coco.getImgIds()
data_infos = []
for img_id in img_ids:
info = self.coco.loadImgs([img_id])[0]
info['filename'] = info['file_name'].split('_')[-1]
info['height'] = int(info['height'])
info['width'] = int(info['width'])
if min(info['height'], info['width']) < 32:
continue
data_infos.append(info)
return data_infos
@property
def modality_length(self):
length_list = []
for data_dict in self.text_data:
cur_len = 100
length_list.append(cur_len)
return length_list * self.repeats
def __len__(self):
return len(self.text_data) * self.repeats
def real_len(self):
return len(self.text_data)
def region_processor(self, orig_size, post_size, bboxes, labels):
orig_h, orig_w = orig_size
post_h, post_w = post_size
y_scale = post_h / orig_h
x_scale = post_w / orig_w
shuffle_ids = torch.randperm(len(labels))[:self.num_classes_per_sample]
selected_bboxes = bboxes[shuffle_ids]
# Ensure selected_bboxes is two-dimensional
if len(selected_bboxes.shape) == 1:
selected_bboxes = np.expand_dims(selected_bboxes, axis=0)
selected_labels = [labels[i] for i in shuffle_ids]
selected_bboxes[:, [0, 2]] *= x_scale
selected_bboxes[:, [1, 3]] *= y_scale
selected_bboxes = torch.tensor(
selected_bboxes, dtype=torch.float32) / post_h
return selected_bboxes, selected_labels
def _parse_annotations(self, img_info):
data_dict = {}
bboxes, captions = [], []
ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
image_path = os.path.join(self.image_folder, img_info['file_name'])
image = Image.open(image_path).convert('RGB')
if hasattr(self, 'extra_image_processor'):
g_image = np.array(image) # for grounding
g_image = self.extra_image_processor.apply_image(g_image)
g_pixel_values = torch.from_numpy(
g_image).permute(2, 0, 1).contiguous()
data_dict['g_pixel_values'] = g_pixel_values
orig_w, orig_h = image.size
if self.pad_image_to_square:
image = expand2square(
image, tuple(int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
post_h, post_w = image.shape[1:3]
data_dict['pixel_values'] = image
for ann in ann_info:
if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
if inter_w * inter_h == 0:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if bbox:
bboxes.append(bbox)
captions.append(img_info['caption'])
if len(bboxes) == 0:
return self.__getitem__(0)
bboxes = np.array(bboxes, dtype=np.float32)
seg_map = img_info['file_name'].replace('jpg', 'png')
bboxes, captions = self.region_processor((orig_h, orig_w), (post_h, post_w), bboxes, captions)
data_dict['bboxes'] = bboxes
data_dict['captions'] = captions
data_dict['seg_map'] = seg_map
return data_dict
def create_conversation(self, captions):
questions = []
answers = []
for i, label in enumerate(captions):
question = random.choice(self.question_templates).strip().replace('<region>', f'region{i + 1} <bbox>')
questions.append(question)
answers.append(label)
conversation = []
for i, (question, answer) in enumerate(zip(questions, answers)):
if i == 0:
question = self.begin_str + question
conversation.append({'input': question, 'output': answer})
return conversation
def __getitem__(self, index):
index = index % self.real_len()
data_dict = {}
ann_info = copy.deepcopy(self.text_data[index])
ann_info = self._parse_annotations(ann_info)
data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values', None)
data_dict['pixel_values'] = ann_info.pop('pixel_values')
data_dict['bboxes'] = ann_info.pop('bboxes', None)
conversation = self.create_conversation(ann_info['captions'])
data_dict['conversation'] = conversation
result = self.template_map_fn(data_dict)
data_dict.update(result)
result = encode_fn(data_dict, tokenizer=self.tokenizer,
max_length=self.max_length, with_image_token=True)
data_dict.update(result)
return data_dict
class RefCocoGRegionDataset(RegionDataset):
pass
class VisualGenomeRegionDataset(RegionDataset):
def _parse_annotations(self, img_info):
data_dict = {}
bboxes, captions = [], []
ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
image_path = os.path.join(self.image_folder, img_info['file_name'])
image = Image.open(image_path).convert('RGB')
if hasattr(self, 'extra_image_processor'):
g_image = np.array(image) # for grounding
g_image = self.extra_image_processor.apply_image(g_image)
g_pixel_values = torch.from_numpy(
g_image).permute(2, 0, 1).contiguous()
data_dict['g_pixel_values'] = g_pixel_values
orig_w, orig_h = image.size
if self.pad_image_to_square:
image = expand2square(
image, tuple(int(x * 255) for x in self.image_processor.image_mean))
image = self.image_processor.preprocess(
image, return_tensors='pt')['pixel_values'][0]
post_h, post_w = image.shape[1:3]
data_dict['pixel_values'] = image
for ann in ann_info:
if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
continue
x1, y1, w, h = ann['bbox']
inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
if inter_w * inter_h == 0:
continue
bbox = [x1, y1, x1 + w, y1 + h]
if bbox:
bboxes.append(bbox)
captions.append(ann['caption'].strip())
if len(bboxes) == 0:
return self.__getitem__(0)
bboxes = np.array(bboxes, dtype=np.float32)
seg_map = img_info['file_name'].replace('jpg', 'png')
bboxes, captions = self.region_processor((orig_h, orig_w), (post_h, post_w), bboxes, captions)
data_dict['bboxes'] = bboxes
data_dict['captions'] = captions
data_dict['seg_map'] = seg_map
return data_dict
if __name__ == '__main__':
from transformers import CLIPImageProcessor, AutoTokenizer
from third_parts.segment_anything.utils.transforms import ResizeLongestSide
pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
tokenizer = dict(
type=AutoTokenizer.from_pretrained,
pretrained_model_name_or_path=llm_name_or_path)
image_processor = dict(
type=CLIPImageProcessor.from_pretrained,
pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
extra_image_processor = dict(
type=ResizeLongestSide,
target_length=1024,
)
from xtuner.utils.templates import PROMPT_TEMPLATE
prompt_template = PROMPT_TEMPLATE.vicuna
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
dataset = VisualGenomeRegionDataset(
image_folder='./data/visual_genome/images',
image_processor=image_processor,
data_path='data/visual_genome/train.json',
tokenizer=tokenizer,
template_map_fn=dict(
type=template_map_fn_factory, template=prompt_template),
max_length=2048,
pad_image_to_square=False,
repeats=1,
num_classes_per_sample=3,
extra_image_processor=None)
for i in range(1000):
print(dataset[i])