from pytorch_lightning import (
    LightningModule,
    Trainer,
)
from pytorch_lightning.callbacks import (
    LearningRateMonitor,
)
from fengshen.models.clip import (
    TaiyiCLIPModel,
    TaiyiCLIPProcessor,
)
from fengshen.models.model_utils import (
    add_module_args,
    configure_optimizers,
    get_total_steps,
)
import torch
import torch.nn.functional as F
import argparse
import math
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.data.taiyi_stable_diffusion_datasets.taiyi_datasets import add_data_args, load_data
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
import os
import numpy as np
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


class Collator:
    def __init__(self, args, processor):
        self.processor = processor
        self.seq_length = args.seq_length
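        # Image pipeline: convert the pre-decoded numpy array to a tensor, take a mild
        # random resized crop at the target resolution, and normalize with CLIP's statistics.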
        self.transforms = Compose([
            ToTensor(),
            RandomResizedCrop(args.resolution, scale=(0.9, 1.0),
                              interpolation=InterpolationMode.BICUBIC),
            Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
        ])

    def __call__(self, inputs):
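        # Pad text only up to the longest caption in the batch (capped at seq_length),
        # using the caption's character count as a proxy for its token length.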
        max_length = min(self.seq_length, max([len(i['caption']) for i in inputs]))
        images = []
        texts = []
        labels = []
        for i in inputs:
            # instance_image = Image.open(i['img_path'])
            # instance_image = jpeg4py.JPEG(i['img_path']).decode()
            instance_image = np.load(i['npy_path'])
            images.append(self.transforms(instance_image))
            texts.append(i['caption'])
            labels.append(i['labels'] if 'labels' in i else -100)
        # images_input = self.processor(images=images, return_tensors="pt")
        texts_input = self.processor(text=texts,
                                     max_length=max_length,
                                     padding='max_length',
                                     truncation=True,
                                     return_tensors='pt')
        # return images_input, texts_input, labels
        return {'pixel_values': torch.stack(images)}, texts_input, labels


class TaiyiCLIP(LightningModule):
    @staticmethod
    def add_module_specific_args(parent_parser):
        parser = parent_parser.add_argument_group('Taiyi CLIP')
        parser.add_argument('--loss_type', choices=['local', 'global'], default='local')
        parser.add_argument('--seq_length', default=77, type=int)
        parser.add_argument('--gather_with_grad', default=False, action='store_true')
        parser.add_argument('--freeze_image_tower', default=False, action='store_true')
        return parent_parser

    def __init__(self, args, **kwargs) -> None:
        super().__init__()
        self.save_hyperparameters(args)

        self.model = TaiyiCLIPModel.from_pretrained(args.model_path)
        self.processor = TaiyiCLIPProcessor.from_pretrained(args.model_path)

        self.local_loss = args.loss_type == 'local'
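        # With the local loss, each rank computes cross-entropy only over its own
        # rows of the similarity matrix (against features gathered from all ranks);
        # with the global loss the full similarity matrix is rebuilt on every rank.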

        if args.freeze_image_tower:
            # Freeze the vision tower and its projection so only the text side is trained.
            for param in self.model.vision_model.parameters():
                param.requires_grad = False
            # requires_grad_() freezes the projection's parameters; assigning the
            # attribute on the Module itself would have no effect.
            self.model.visual_projection.requires_grad_(False)

        # cache for the contrastive-loss target labels (see clip_loss)
        self.cache_labels = True
        self.prev_num_logits = 0
        self.labels = {}

    def setup(self, stage) -> None:
        if stage == 'fit':
            self.total_steps = get_total_steps(self.trainer, self.hparams)
            print('Total steps: {}'.format(self.total_steps))
        elif stage == 'validate':
            self.total_steps = 100

    def configure_optimizers(self):
        return configure_optimizers(self)

    def forward(self, image, text):
        assert image is not None
        assert text is not None
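        # Encode both modalities and L2-normalize the embeddings so that the dot
        # products used in the loss are cosine similarities; logit_scale is the
        # learned temperature (stored in log space, hence the exp()).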
        image_features = self.model.get_image_features(**image)
        text_features = self.model.get_text_features(**text)

        image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)

        return image_features, text_features, self.model.logit_scale.exp()

    def gather_features(self, features):
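        # Gather this rank's features from every process; Lightning's all_gather returns
        # a (world_size, batch, dim) tensor, which is then flattened to (world_size * batch, dim).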
        if self.trainer.world_size == 1:
            return features
        all_features = self.all_gather(
            features, sync_grads=self.hparams.gather_with_grad)
        if not self.local_loss and not self.hparams.gather_with_grad:
            # For the global loss without gather_with_grad, the gathered tensor carries
            # no gradient, so write this rank's grad-carrying features back into its slot.
            all_features[self.global_rank] = features
        all_features = all_features.view(-1, all_features.shape[-1])
        return all_features

    def clip_loss(self, image_features, text_features, logit_scale):
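        # Symmetric InfoNCE: cross-entropy over image->text and text->image logits,
        # with matching pairs (the diagonal) as targets.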

        logits_per_image = None

        # If the image tower is frozen and the local loss is used, there is no need
        # to gather text features: the image->text logits they would feed are only
        # needed to train the (frozen) ViT.
        if self.hparams.freeze_image_tower and self.local_loss:
            all_text_features = None
        else:
            all_text_features = self.gather_features(text_features)
        all_image_features = self.gather_features(image_features)

        if self.local_loss:
            if all_text_features is not None:
                logits_per_image = logit_scale * image_features @ all_text_features.T
            logits_per_text = logit_scale * text_features @ all_image_features.T
        else:
            # With the global loss, all_text_features has always been gathered above.
            logits_per_image = logit_scale * all_image_features @ all_text_features.T
            logits_per_text = logits_per_image.T

        num_logits = logits_per_text.shape[0]
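        # Targets are the diagonal indices of the similarity matrix; with the local
        # loss each rank's targets are offset by its slice of the gathered batch.
        # They are cached because the batch size rarely changes between steps.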
        if self.prev_num_logits != num_logits or self.device not in self.labels:
            labels = torch.arange(num_logits, device=self.device, dtype=torch.long)
            if self.trainer.world_size > 1 and self.local_loss:
                labels = labels + num_logits * self.global_rank
            if self.cache_labels:
                self.labels[self.device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[self.device]

        total_loss = (
            F.cross_entropy(logits_per_image, labels) +
            F.cross_entropy(logits_per_text, labels)
        ) / 2 if logits_per_image is not None else F.cross_entropy(logits_per_text, labels)
        return total_loss

    def training_step(self, batch):
        image, text, _ = batch
        image_features, text_features, logit_scale = self(image, text)
        total_loss = self.clip_loss(image_features, text_features, logit_scale)
        self.log('train_loss', total_loss, sync_dist=False)
        return total_loss

    def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
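        # Clamp the learnable temperature (stored in log space) to [0, ln(100)],
        # i.e. exp(logit_scale) stays within [1, 100], as in the original CLIP setup.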
        with torch.no_grad():
            self.model.logit_scale.clamp_(0, math.log(100))

    def get_metrics(self, image_features, text_features, labels, logit_scale):
        # Compute retrieval metrics, supporting multiple captions per image.
        # This matters for image->text, since one image may correspond to several texts;
        # it is not needed for text->image (a text usually has a single matching image).
        metrics = {}
        logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
        logits_per_text = logits_per_image.t().detach().cpu()

        logits = {"image_to_text": logits_per_image, "text_to_image": logits_per_text}

        label2idx = {}  # map each label to the list of indices that carry it
        repeat_id = []
        for i, label in enumerate(labels):
            if label not in label2idx:
                label2idx[label] = [i]
            else:
                # This label has appeared before; record the index so its score can be
                # suppressed when computing the text->image ranking below.
                label2idx[label].append(i)
                repeat_id.append(i)

        ground_truth = [label2idx[label] for label in labels]

        for name, logit in logits.items():
            if name == 'text_to_image':
                logit[:, repeat_id] -= 1e8   # suppress duplicate images so they are ignored in the ranking
            r_stat = {1: [], 5: [], 10: []}
            # r1_stat, r5_stat, r10_stat = [], [], []
            # ranking: candidate indices sorted from the largest logit to the smallest
            ranking = torch.argsort(logit, descending=True)
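            # A query counts as a hit for R@K if any of its ground-truth indices
            # appears among its top-K ranked candidates.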
            for i, each_query in enumerate(ranking[:, :10]):
                for j, q in enumerate(each_query):
                    found = False
                    if q in ground_truth[i]:
                        for k, v in r_stat.items():
                            if j < k:
                                found = True
                                v.append(1)
                    if found:
                        break
            for k, v in r_stat.items():
                metrics[f'{name}_R@{k}'] = sum(v)/len(logit)
        return metrics

    def validation_step(self, batch, batch_idx):
        image, text, label = batch
        image_features, text_features, logit_scale = self(image, text)
        return image_features, text_features, logit_scale, text['input_ids'].shape[0], label

    def validation_epoch_end(self, val_outputs):
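        # Pool image/text features from all validation batches, then compute the
        # symmetric contrastive loss and retrieval metrics over the whole validation set.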
        all_image_features = []
        all_text_features = []
        all_labels = []
        sample_size = 0
        for o in val_outputs:
            all_image_features.append(o[0])
            all_text_features.append(o[1])
            sample_size += o[3]
            all_labels += o[4]
        if len(all_image_features) == 0 or len(all_text_features) == 0:
            return
        all_image_features = torch.cat(all_image_features)
        all_text_features = torch.cat(all_text_features)
        logit_scale = val_outputs[0][2].mean()
        logits_per_image = logit_scale * all_image_features @ all_text_features.t()
        logits_per_text = logits_per_image.t()

        labels = torch.arange(sample_size, device=self.device).long()
        total_loss = (F.cross_entropy(logits_per_image, labels)
                      + F.cross_entropy(logits_per_text, labels)) / 2

        val_metrics = self.get_metrics(
            image_features=all_image_features,
            text_features=all_text_features,
            logit_scale=logit_scale,
            labels=all_labels)
        loss = total_loss / sample_size
        self.log('val_loss', loss, sync_dist=False)
        for k, v in val_metrics.items():
            self.log(f'val_{k}', v, sync_dist=False)

    def on_load_checkpoint(self, checkpoint) -> None:
        # Compatibility with older Lightning versions, which reset the step counter
        # to 0 when resuming from a checkpoint.
        global_step_offset = checkpoint["global_step"]
        if 'global_samples' in checkpoint:
            self.consumed_samples = checkpoint['global_samples']
        self.trainer.fit_loop.epoch_loop._batches_that_stepped = global_step_offset

    def on_save_checkpoint(self, checkpoint) -> None:
        # Export the weights in Hugging Face format every time a checkpoint is saved.
        if self.global_rank == 0:
            dir_path = os.path.join(
                self.hparams.default_root_dir, f'hf_out_{self.trainer.current_epoch}_{self.trainer.global_step}')
            if not os.path.exists(dir_path):
                os.mkdir(dir_path)
            self.model.save_pretrained(dir_path)
            self.processor.save_pretrained(dir_path)


if __name__ == '__main__':
    args_parser = argparse.ArgumentParser()
    args_parser = add_module_args(args_parser)
    args_parser = add_data_args(args_parser)
    args_parser = UniversalDataModule.add_data_specific_args(args_parser)
    args_parser = Trainer.add_argparse_args(args_parser)
    args_parser = TaiyiCLIP.add_module_specific_args(args_parser)
    args_parser = UniversalCheckpoint.add_argparse_args(args_parser)
    args = args_parser.parse_args()

    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_callback = UniversalCheckpoint(args)

    trainer = Trainer.from_argparse_args(args,
                                         callbacks=[
                                             lr_monitor,
                                             checkpoint_callback])

    model = TaiyiCLIP(args)
    processor = model.processor
    collate_fn = Collator(args, processor)
    datasets = load_data(args, global_rank=trainer.global_rank)

    # Load a single validation set. NOTE: this is a temporary hack to verify that the
    # code works and will be removed once validated.
    from fengshen.examples.pretrain_taiyi_clip.flickr_datasets import flickr30k_CNA
    img_root = '/shared_space/ccnl/mm_data/Flickr30k-CNA/flickr30k/images'
    text_annot_path = '/shared_space/ccnl/mm_data/Flickr30k-CNA/test/flickr30k_cn_test.txt'

    datasets[args.val_datasets_field] = flickr30k_CNA(img_root, text_annot_path, collate_fn)

    datamodule = UniversalDataModule(
        tokenizer=None, collate_fn=collate_fn, args=args, datasets=datasets)

    trainer.fit(model, datamodule, ckpt_path=args.load_ckpt_path)