from pytorch_lightning import (
    LightningModule,
    Trainer,
)
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
)
from fengshen.models.clip import (
    TaiyiCLIPModel,
    TaiyiCLIPProcessor,
)
from fengshen.models.model_utils import (
    add_module_args,
    configure_optimizers,
    get_total_steps,
)
import torch
import torch.nn.functional as F
import argparse
import math
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.data.taiyi_stable_diffusion_datasets.taiyi_datasets import add_data_args, load_data
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
import os
import numpy as np
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor

# Channel-wise statistics of OpenAI's CLIP training data, so images are
# normalized the way the pretrained vision tower expects.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


class Collator:
    def __init__(self, args, processor):
        self.processor = processor
        self.seq_length = args.seq_length
        self.transforms = Compose([
            ToTensor(),
            RandomResizedCrop(args.resolution, scale=(0.9, 1.0),
                              interpolation=InterpolationMode.BICUBIC),
            Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
        ])

    def __call__(self, inputs):
        # Pad to the longest caption in the batch, capped at seq_length.
        max_length = min(self.seq_length, max(len(i['caption']) for i in inputs))
        images = []
        texts = []
        labels = []
        for i in inputs:
            # Images are stored on disk as pre-decoded numpy arrays.
            instance_image = np.load(i['npy_path'])
            images.append(self.transforms(instance_image))
            texts.append(i['caption'])
            labels.append(i['labels'] if 'labels' in i else -100)

        texts_input = self.processor(text=texts,
                                     max_length=max_length,
                                     padding='max_length',
                                     truncation=True,
                                     return_tensors='pt')

        return {'pixel_values': torch.stack(images)}, texts_input, labels

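# A sketch of the sample format the collator above assumes from load_data
# (field names taken from the collator code; the values are illustrative):
#   {'npy_path': '/path/to/image.npy', 'caption': 'some caption text', 'labels': 3}
# 'labels' is optional and only used by the validation retrieval metrics;
# samples without it fall back to -100.

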
class TaiyiCLIP(LightningModule):
    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = parent_parser.add_argument_group('Taiyi CLIP')
        parser.add_argument('--loss_type', choices=['local', 'global'], default='local')
        parser.add_argument('--seq_length', type=int, default=77)
        parser.add_argument('--gather_with_grad', default=False, action='store_true')
        parser.add_argument('--freeze_image_tower', default=False, action='store_true')
        return parent_parser

    def __init__(self, args, **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters(args)

        self.model = TaiyiCLIPModel.from_pretrained(args.model_path)
        self.processor = TaiyiCLIPProcessor.from_pretrained(args.model_path)

        self.local_loss = args.loss_type == 'local'

        if args.freeze_image_tower:
            for param in self.model.vision_model.parameters():
                param.requires_grad = False
            # requires_grad_() actually freezes the projection's parameters;
            # plain attribute assignment on a Module would have no effect.
            self.model.visual_projection.requires_grad_(False)

        # Cache the contrastive-loss targets between steps; they only depend
        # on the number of logits and the device.
        self.cache_labels = True
        self.prev_num_logits = 0
        self.labels = {}

    def setup(self, stage) -> None:
        if stage == 'fit':
            self.total_steps = get_total_steps(self.trainer, self.hparams)
            print('Total steps: {}'.format(self.total_steps))
        elif stage == 'validate':
            self.total_steps = 100

    def configure_optimizers(self):
        return configure_optimizers(self)

    def forward(self, image, text):
        assert image is not None
        assert text is not None
        image_features = self.model.get_image_features(**image)
        text_features = self.model.get_text_features(**text)

        # L2-normalize both towers so dot products become cosine similarities,
        # scaled by the learned temperature returned below.
        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

        return image_features, text_features, self.model.logit_scale.exp()

    def gather_features(self, features):
        if self.trainer.world_size == 1:
            return features
        all_features = self.all_gather(
            features, sync_grads=self.hparams.gather_with_grad)
        if not self.local_loss and not self.hparams.gather_with_grad:
            # Gathered copies carry no gradient; re-insert this rank's own
            # features so it still receives a gradient through the loss.
            all_features[self.global_rank] = features
        all_features = all_features.view(-1, all_features.shape[-1])
        return all_features

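    # Note on the two loss layouts (see clip_loss below): with
    # loss_type='local', each rank scores its own B queries against all
    # W*B gathered candidates, giving (B, W*B) logit matrices with targets
    # offset by B * global_rank; with 'global', every rank rebuilds the
    # full (W*B, W*B) matrix and uses plain diagonal targets.
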
    def clip_loss(self, image_features, text_features, logit_scale):
        logits_per_image = None

        # With a frozen image tower and a local loss, only the text->image
        # direction trains anything, so skip gathering text features and fall
        # back to a single cross-entropy at the end.
        if self.hparams.freeze_image_tower and self.local_loss:
            all_text_features = None
        else:
            all_text_features = self.gather_features(
                text_features)
        all_image_features = self.gather_features(
            image_features)

        if self.local_loss:
            if all_text_features is not None:
                logits_per_image = logit_scale * image_features @ all_text_features.T
            logits_per_text = logit_scale * text_features @ all_image_features.T
        else:
            # Global loss: score the full gathered batch in both directions.
            logits_per_image = logit_scale * all_image_features @ all_text_features.T
            logits_per_text = logits_per_image.T

        # Build (and cache) the targets; with a local loss, each rank's
        # positives sit at an offset of num_logits * global_rank.
        num_logits = logits_per_text.shape[0]
        if self.prev_num_logits != num_logits or self.device not in self.labels:
            labels = torch.arange(num_logits, device=self.device, dtype=torch.long)
            if self.trainer.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.global_rank
            if self.cache_labels:
                self.labels[self.device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[self.device]

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2 if logits_per_image is not None else F.cross_entropy(logits_per_text, labels)
        return total_loss

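    # For intuition, on a single GPU clip_loss reduces to the standard CLIP
    # InfoNCE objective; a minimal sketch (illustrative, not a separate code
    # path in this module):
    #   logits = logit_scale * image_features @ text_features.T   # (B, B)
    #   labels = torch.arange(B)
    #   loss = (F.cross_entropy(logits, labels)
    #           + F.cross_entropy(logits.T, labels)) / 2
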
    def training_step(self, batch, batch_idx):
        image, text, _ = batch
        image_features, text_features, logit_scale = self(image, text)
        total_loss = self.clip_loss(image_features, text_features, logit_scale)
        self.log('train_loss', total_loss, sync_dist=False)
        return total_loss

    def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
        # Clamp the learnable temperature so exp(logit_scale) stays in
        # [1, 100], as in the original CLIP training recipe.
        with torch.no_grad():
            self.model.logit_scale.clamp_(0, math.log(100))

    def get_metrics(self, image_features, text_features, labels, logit_scale):
        # Retrieval metrics (R@1/5/10) over the whole validation set, in both
        # the image->text and text->image directions.
        metrics = {}
        logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
        logits_per_text = logits_per_image.t().detach().cpu()

        logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}

        # Group sample indices by label: several captions can describe the
        # same image, so a query may have multiple correct candidates.
        label2idx = {}
        repeat_id = []
        for i, label in enumerate(labels):
            if label not in label2idx:
                label2idx[label] = [i]
            else:
                # Duplicate caption of an already-seen image.
                label2idx[label].append(i)
                repeat_id.append(i)

        ground_truth = [label2idx[label] for label in labels]

        for name, logit in logits.items():
            if name == 'text_to_image':
                # Mask duplicate image columns so each image appears as a
                # candidate only once.
                logit[:, repeat_id] -= 1e8
            r_stat = {1: [], 5: [], 10: []}
            # Rank candidates per query and record a hit at every K for which
            # a ground-truth index appears within the top K.
            ranking = torch.argsort(logit, descending=True)
            for i, each_query in enumerate(ranking[:, :10]):
                for j, q in enumerate(each_query):
                    found = False
                    if q in ground_truth[i]:
                        for k, v in r_stat.items():
                            if j < k:
                                found = True
                                v.append(1)
                    if found:
                        break
            for k, v in r_stat.items():
                metrics[f'{name}_R@{k}'] = sum(v) / len(logit)
        return metrics

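    # Worked example for the R@K bookkeeping above: if ground_truth[i] == [4]
    # and index 4 appears at rank position j=2 (0-based) for query i, the
    # query counts as a hit for R@5 and R@10 but not R@1; each R@K is then
    # the number of hits divided by the number of queries.
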
    def validation_step(self, batch, batch_idx):
        image, text, label = batch
        image_features, text_features, logit_scale = self(image, text)
        return image_features, text_features, logit_scale, text['input_ids'].shape[0], label

    def validation_epoch_end(self, val_outputs):
        all_image_features = []
        all_text_features = []
        all_labels = []
        sample_size = 0
        for o in val_outputs:
            all_image_features.append(o[0])
            all_text_features.append(o[1])
            sample_size += o[3]
            all_labels += o[4]
        if len(all_image_features) == 0 or len(all_text_features) == 0:
            return
        all_image_features = torch.cat(all_image_features)
        all_text_features = torch.cat(all_text_features)
        logit_scale = val_outputs[0][2].mean()
        # Contrastive loss over the entire validation set treated as one batch.
        logits_per_image = logit_scale * all_image_features @ all_text_features.t()
        logits_per_text = logits_per_image.t()

        labels = torch.arange(sample_size, device=self.device).long()
        total_loss = (F.cross_entropy(logits_per_image, labels)
                      + F.cross_entropy(logits_per_text, labels)) / 2

        val_metrics = self.get_metrics(
            image_features=all_image_features,
            text_features=all_text_features,
            logit_scale=logit_scale,
            labels=all_labels)
        loss = total_loss / sample_size
        self.log('val_loss', loss, sync_dist=False)
        for k, v in val_metrics.items():
            self.log(f'val_{k}', v, sync_dist=False)

    def on_load_checkpoint(self, checkpoint) -> None:
        # Restore the step count so resumed runs continue their LR schedule.
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset

    def on_save_checkpoint(self, checkpoint) -> None:
        # Alongside the Lightning checkpoint, export a HuggingFace-style
        # snapshot of the model and processor from rank 0 only.
        if self.global_rank == 0:
            dir_path = os.path.join(
                self.hparams.default_root_dir, f'hf_out_{self.trainer.current_epoch}_{self.trainer.global_step}')
            if not os.path.exists(dir_path):
                os.mkdir(dir_path)
            self.model.save_pretrained(dir_path)
            self.processor.save_pretrained(dir_path)


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser = add_module_args(args_parser)
    args_parser = add_data_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = TaiyiCLIP.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args = args_parser.parse_args()

    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_callback = UniversalCheckpoint(args)

    trainer = Trainer.from_argparse_args(args,
                                         callbacks=[
                                             lr_monitor,
                                             checkpoint_callback])

    model = TaiyiCLIP(args)
    processor = model.processor
    collate_fn = Collator(args, processor)
    datasets = load_data(args, global_rank=trainer.global_rank)

    # Use the Flickr30k-CNA (Chinese) test split as the validation retrieval set.
    from fengshen.examples.pretrain_taiyi_clip.flickr_datasets import flickr30k_CNA
    img_root = '/shared_space/ccnl/mm_data/Flickr30k-CNA/flickr30k/images'
    text_annot_path = '/shared_space/ccnl/mm_data/Flickr30k-CNA/test/flickr30k_cn_test.txt'

    datasets[args.val_datasets_field] = flickr30k_CNA(img_root, text_annot_path, collate_fn)

    datamodule = UniversalDataModule(
        tokenizer=None, collate_fn=collate_fn, args=args, datasets=datasets)

    trainer.fit(model, datamodule, ckpt_path=args.load_ckpt_path)
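# Example launch (illustrative only: the script name is hypothetical, the
# model path is a placeholder, and the full flag set comes from the
# add_*_args helpers above; --gpus/--strategy are standard pytorch_lightning
# Trainer argparse flags):
#   python finetune.py --model_path <path-or-hf-id-of-Taiyi-CLIP> \
#       --loss_type local --resolution 224 --gpus 2 --strategy ddp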