# Copyright (c) OpenMMLab. All rights reserved.
import os.path
import cv2
import mmengine
from mmengine.runner import ValLoop as MMENGINE_ValLoop
from mmengine.dist import (broadcast_object_list, is_main_process,
                           get_world_size, get_rank, barrier, collect_results)
import math
import torch
from mmengine.model import is_model_wrapper
from types import MethodType
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.tools.utils import get_stop_criteria
from transformers import GenerationConfig
from pycocotools import mask as _mask
from mmengine.visualization.visualizer import Visualizer
from vlm.utils import VideoReader

TORCH_DTYPE_MAP = dict(fp16=torch.float16, bf16=torch.bfloat16,
                       fp32=torch.float32, auto='auto')
VID_INTERVAL = 4  # sample every 4th frame when visualizing a video file


def visualize(data_batch, prediction, visualize_path='work_dirs/visualize'):
    """Dump the raw frames and the predicted masks (drawn in green) for one sample."""
    if 'video_path' in data_batch:
        # Video input: sample every VID_INTERVAL-th frame from the video file.
        vid_frames = VideoReader(data_batch['video_path'])[::VID_INTERVAL]
        vid_id = os.path.basename(data_batch['video_path']).split('.')[0]
        text_prompts = data_batch['text_prompts']
        mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id))
        visualizer = Visualizer()

        # Save the raw sampled frames.
        mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id, "vid"))
        for id_frame, img in enumerate(vid_frames):
            out_path = os.path.join(visualize_path, vid_id, "vid",
                                    "{:06d}.jpg".format(id_frame))
            cv2.imwrite(out_path, img)

        # For each text prompt, save the prompt and the frames overlaid with the predicted masks.
        for id_text, text in enumerate(text_prompts):
            mmengine.mkdir_or_exist(os.path.join(visualize_path, vid_id, "sample_{:06d}".format(id_text)))
            mmengine.put_text(text, os.path.join(visualize_path, vid_id,
                                                 "sample_{:06d}".format(id_text), 'text.txt'))
            for id_frame, img in enumerate(vid_frames):
                visualizer.set_image(img)
                mask = prediction['prediction_masks'][id_text][id_frame]
                mask = _mask.decode(mask).astype(bool)  # RLE -> binary mask
                visualizer.draw_binary_masks(mask, colors='g')
                visual_result = visualizer.get_image()
                out_path = os.path.join(visualize_path, vid_id, "sample_{:06d}".format(id_text),
                                        "{:06d}.jpg".format(id_frame))
                cv2.imwrite(out_path, visual_result)
    else:
        # Image-sequence input: frames are given as image files inside an image folder.
        images_files = data_batch['images']
        vid_id = data_batch['video_id']
        text_prompts = data_batch['text_prompts']
        image_folder = data_batch['image_folder']
        mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id)))
        visualizer = Visualizer()

        # Save the raw frames.
        mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id), "vid"))
        for id_frame, img_file in enumerate(images_files):
            img = cv2.imread(os.path.join(image_folder, img_file))
            out_path = os.path.join(visualize_path, "{:06d}".format(vid_id), "vid",
                                    os.path.basename(img_file))
            cv2.imwrite(out_path, img)

        # For each text prompt, save the prompt and the frames overlaid with the predicted masks.
        for id_text, text in enumerate(text_prompts):
            mmengine.mkdir_or_exist(os.path.join(visualize_path, "{:06d}".format(vid_id),
                                                 "sample_{:06d}".format(id_text)))
            mmengine.put_text(text, os.path.join(visualize_path, "{:06d}".format(vid_id),
                                                 "sample_{:06d}".format(id_text), 'text.txt'))
            for id_frame, img_file in enumerate(images_files):
                img = cv2.imread(os.path.join(image_folder, img_file))
                visualizer.set_image(img)
                mask = prediction['prediction_masks'][id_text][id_frame]
                mask = _mask.decode(mask).astype(bool)  # RLE -> binary mask
                visualizer.draw_binary_masks(mask, colors='g')
                visual_result = visualizer.get_image()
                out_path = os.path.join(visualize_path, "{:06d}".format(vid_id),
                                        "sample_{:06d}".format(id_text), os.path.basename(img_file))
                cv2.imwrite(out_path, visual_result)


class VideoTestLoop(MMENGINE_ValLoop):
    """Test loop for video datasets wrapped in a ConcatDataset, with optional mask visualization."""

    def __init__(self, runner, dataloader, torch_dtype='fp16', select_metric='first',
                 visualize=None, evaluator=None) -> None:
        # The test dataset must be a ConcatDataset: run() iterates over its sub-datasets.
        # Skip ValLoop.__init__ (which requires an evaluator) and initialize BaseLoop directly.
        super(MMENGINE_ValLoop, self).__init__(runner, dataloader)
        self._runner = runner
        self.torch_dtype = torch_dtype
        if torch_dtype is not None:
            self.torch_dtype = TORCH_DTYPE_MAP[torch_dtype]
        self.select_metric = select_metric
        self.visualize = visualize
        self.evaluator = evaluator

    def run(self) -> dict:
        """Launch test."""
        self.runner.logger.info('==================== Start test loop ====================')
        self.runner.call_hook('before_test')
        self.runner.call_hook('before_test_epoch')

        if is_model_wrapper(self.runner.model):
            model = self.runner.model.module
        else:
            model = self.runner.model
        model.gradient_checkpointing_disable()
        model.eval()
        model.cuda()

        rank = get_rank()
        metrics = []
        # Ensure that eta and log are displayed correctly.
        current_run_total_ids = 0
        for _, dataset in enumerate(self.dataloader.dataset.datasets):
            if not hasattr(model, 'preparing_for_generation'):
                model.preparing_for_generation = MethodType(default_preparing_for_generation, model)
                print('Warning: the model does not have a preparing_for_generation() method; '
                      'falling back to the default!')
            model.preparing_for_generation(dataset.metainfo)

            # Split the dataset evenly across ranks; out-of-range indices repeat the last sample
            # so that every rank runs the same number of iterations.
            results = []
            n_samples = len(dataset)
            per_rank_samples = math.ceil(n_samples / get_world_size())
            running_tot = per_rank_samples * get_world_size()
            assert running_tot >= n_samples
            per_rank_ids = range(per_rank_samples * rank, per_rank_samples * (rank + 1))
            for idx in per_rank_ids:
                if n_samples <= idx:
                    data_batch = dataset[n_samples - 1]
                else:
                    data_batch = dataset[idx]
                self.run_iter(current_run_total_ids, data_batch, results, model)
                current_run_total_ids += 1

            barrier()
            self.runner.logger.info('==================== Start collect results ====================')
            results = collect_results(results, n_samples)
            self.runner.logger.info('==================== Start dataset evaluation ====================')
            if is_main_process():
                metric = dataset.evaluate(results, self.runner.work_dir)
                objects = [metric]
            else:
                objects = [None]
            broadcast_object_list(objects)
            metric = objects[0]
            metrics.append(metric)

        # Select metrics.
        if self.select_metric == 'first':
            metrics = metrics[0]
        else:
            raise NotImplementedError
        self.runner.logger.info('==================== End test loop ====================')

        self.runner.call_hook('after_test_epoch', metrics=metrics)
        self.runner.call_hook('after_test')
        return metrics

    @torch.no_grad()
    def run_iter(self, idx, data_batch, results, model):
        """Run inference on a single sample and append the prediction to ``results``."""
        prediction = {'video_id': data_batch['video_id']}
        self.runner.call_hook('before_test_iter', batch_idx=idx, data_batch=data_batch)
        outputs = model.predict_forward(**data_batch)
        prediction.update(outputs)
        results.append(prediction)
        if self.visualize:
            # if not prediction['is_exists'][0].all():
            #     print(prediction['is_exists'])
            visualize(data_batch=data_batch, prediction=prediction, visualize_path=self.visualize)
        self.runner.call_hook('after_test_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs)


def default_preparing_for_generation(self, metainfo):
    # Set stop criteria and generation configs for the model.
    assert hasattr(self, 'tokenizer'), 'The model does not have a tokenizer!'
    self.bot_name = 'BOT'
    template = PROMPT_TEMPLATE['internlm2_chat']
    self.template = template
    stop_words = []
    stop_words += template.get('STOP_WORDS', [])
    stop_criteria = get_stop_criteria(tokenizer=self.tokenizer, stop_words=stop_words)
    self.stop_criteria = stop_criteria

    default_generation_kwargs = dict(
        max_new_tokens=2048,
        do_sample=False,
        eos_token_id=self.tokenizer.eos_token_id,
        pad_token_id=(
            self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        ),
    )
    default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
    self.gen_config = GenerationConfig(**default_generation_kwargs)
    return
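
# --------------------------------------------------------------------------- #
# Hypothetical usage sketch (not part of the original file): how this loop
# might be referenced from an xtuner/mmengine config. All values below are
# placeholders; the runner injects `runner`, `dataloader` and `evaluator`
# itself, and the test dataloader must wrap a ConcatDataset because run()
# iterates over `dataloader.dataset.datasets`.
#
# test_cfg = dict(
#     type=VideoTestLoop,
#     torch_dtype='fp16',
#     select_metric='first',
#     visualize='work_dirs/visualize',  # or None to skip saving mask overlays
# )
# --------------------------------------------------------------------------- #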