Spaces:
Running
on
Zero
Running
on
Zero
import copy | |
import random | |
import glob | |
import json | |
import logging | |
import os | |
import torch | |
from mmengine import print_log | |
from mmengine.config import Config, ConfigDict | |
from PIL import Image | |
from torch.utils.data import Dataset | |
import numpy as np | |
import torch.nn.functional as F | |
from pycocotools.coco import COCO | |
from pycocotools import mask as mask_utils | |
from xtuner.registry import BUILDER | |
from xtuner.dataset.utils import encode_fn | |
from xtuner.dataset.map_fns import llava_map_fn | |
from projects.glamm.datasets.utils.utils import expand2square | |
from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST | |
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
from third_parts.mmdet.datasets.refcoco import RefCocoDataset | |
class ReferSegmDataset(RefCocoDataset): | |
def __init__(self, | |
data_root, | |
ann_file=None, | |
split_file=None, | |
image_processor=None, | |
extra_image_processor=None, | |
data_prefix=dict(img_path='train2014/'), | |
tokenizer=None, | |
template_map_fn=None, | |
max_length=2048, | |
pad_image_to_square=False, | |
num_classes_per_sample=3): | |
super().__init__( | |
data_root=data_root, | |
data_prefix=data_prefix, | |
pipeline=None, | |
ann_file=ann_file, | |
split_file=split_file, | |
) | |
self.begin_str = f"""{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" | |
self.question_templates = SEG_QUESTIONS | |
if extra_image_processor is not None: | |
self.extra_image_processor = BUILDER.build(extra_image_processor) | |
self.num_classes_per_sample = num_classes_per_sample | |
self.tokenizer = BUILDER.build(tokenizer) | |
self.tokenizer.add_tokens( | |
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True | |
) | |
reg_tokens = ['<bbox>', '<point>'] | |
segmentation_tokens = ['[SEG]'] | |
phrase_tokens = ['<p>', '</p>'] | |
special_tokens = reg_tokens + segmentation_tokens + phrase_tokens | |
self.tokenizer.add_tokens(special_tokens, special_tokens=True) | |
self.max_length = max_length | |
self.template_map_fn = BUILDER.build(template_map_fn) | |
self.image_processor = BUILDER.build(image_processor) | |
size = self.image_processor.crop_size | |
if isinstance(size, dict): | |
self.image_w, self.image_h = size['width'], size['height'] | |
self.pad_image_to_square = pad_image_to_square | |
def modality_length(self): | |
import pickle | |
length_list = [] | |
for idx in range(len(self)): | |
length_list.append(100) | |
# for idx in range(len(self)): | |
# if self.serialize_data: | |
# start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() | |
# end_addr = self.data_address[idx].item() | |
# bytes = memoryview( | |
# self.data_bytes[start_addr:end_addr]) # type: ignore | |
# data_dict = pickle.loads(bytes) | |
# else: | |
# data_dict = copy.deepcopy(self.data_list[idx]) | |
return length_list | |
def _parse_annotations(self, ann_info): | |
image_path = ann_info['img_path'] | |
image = Image.open(image_path).convert('RGB') | |
if hasattr(self, 'extra_image_processor'): | |
g_image = np.array(image) # for grounding | |
g_image = self.extra_image_processor.apply_image(g_image) | |
g_pixel_values = torch.from_numpy( | |
g_image).permute(2, 0, 1).contiguous() | |
ann_info['g_pixel_values'] = g_pixel_values | |
width, height = image.size | |
if self.pad_image_to_square: | |
image = expand2square( | |
image, tuple(int(x * 255) for x in self.image_processor.image_mean)) | |
image = self.image_processor.preprocess( | |
image, return_tensors='pt')['pixel_values'][0] | |
ann_info['pixel_values'] = image | |
masks, phrases = [], [] | |
instances, text = ann_info['instances'], ann_info['text'] | |
index = np.random.choice(range(len(instances)), min( | |
len(instances), self.num_classes_per_sample)) | |
for idx in index: | |
inst = instances[idx] | |
phrase = text[idx].lower() | |
phrases.append(phrase) | |
binary_mask = np.zeros((height, width), dtype=np.uint8) | |
for seg in inst["mask"]: | |
rles = mask_utils.frPyObjects([seg], height, width) | |
m = mask_utils.decode(rles) | |
m = m.astype(np.uint8) | |
binary_mask += m.squeeze() | |
masks.append(binary_mask) | |
ann_info.update({ | |
'masks': masks, | |
'phrases': phrases, | |
}) | |
return ann_info | |
def __getitem__(self, idx): | |
data_dict = {} | |
ann_info = super().__getitem__(idx) | |
ann_info = self._parse_annotations(ann_info) | |
data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values') | |
data_dict['pixel_values'] = ann_info.pop('pixel_values') | |
if len(ann_info['masks']) == 0: | |
return self.__getitem__(0) | |
data_dict['masks'] = torch.from_numpy( | |
np.stack(ann_info['masks'], axis=0)) | |
conversation = [] | |
for i, phrase in enumerate(ann_info['phrases']): | |
question = random.choice(SEG_QUESTIONS).format(class_name=phrase) | |
conversation.append( | |
{'input': question, 'output': random.choice(ANSWER_LIST)}) | |
data_dict['conversation'] = conversation | |
result = self.template_map_fn(data_dict) | |
data_dict.update(result) | |
result = encode_fn(data_dict, tokenizer=self.tokenizer, | |
max_length=self.max_length, with_image_token=True) | |
data_dict.update(result) | |
return data_dict | |
if __name__ == '__main__': | |
from transformers import CLIPImageProcessor, AutoTokenizer | |
from third_parts.segment_anything.utils.transforms import ResizeLongestSide | |
pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained' | |
llm_name_or_path = 'lmsys/vicuna-7b-v1.5' | |
tokenizer = dict( | |
type=AutoTokenizer.from_pretrained, | |
pretrained_model_name_or_path=llm_name_or_path) | |
image_processor = dict( | |
type=CLIPImageProcessor.from_pretrained, | |
pretrained_model_name_or_path='openai/clip-vit-large-patch14-336') | |
extra_image_processor = dict( | |
type=ResizeLongestSide, | |
target_length=1024, | |
) | |
from xtuner.utils.templates import PROMPT_TEMPLATE | |
prompt_template = PROMPT_TEMPLATE.vicuna | |
from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn | |
from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn | |
dataset = ReferSegmDataset( | |
tokenizer=tokenizer, | |
image_processor=image_processor, | |
template_map_fn=dict( | |
type=template_map_fn_factory, template=prompt_template), | |
extra_image_processor=extra_image_processor, | |
data_root='data/coco/', | |
data_prefix=dict(img_path='train2014/'), | |
ann_file='refcoco+/instances.json', | |
split_file='refcoco+/refs(unc).p', | |
) | |
for i in range(1000): | |
dataset[i] | |