import torch
from model import MaskedAutoencoderViT, mae_vit_base_patch16
import numpy as np
from PIL import Image
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer
from collections import OrderedDict
from huggingface_hub import hf_hub_download
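
# Demo: load a MUG-pretrained MAE ViT-Base checkpoint, reconstruct masked input images,
# and generate captions for them through a small Gradio interface.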

# BERT tokenizer supplies the vocabulary used by the captioning head.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Download the pretrained MUG ViT-Base checkpoint from the Hugging Face Hub.
ckpt = torch.load(hf_hub_download('tennant/MUG', 'laion_mug_vit_base_5ep.pth'), map_location='cpu')

# The checkpoint stores weights under an 'image_encoder.model.' prefix;
# strip it so the keys match the MaskedAutoencoderViT state dict.
new_dict = OrderedDict()
for k, v in ckpt.items():
    k = k[len('image_encoder.model.'):]
    new_dict[k] = v

model = mae_vit_base_patch16(uni_dim=768, uni_heads=12, less_u=True)

# strict=False tolerates key mismatches; the returned message lists missing/unexpected keys.
msg = model.load_state_dict(new_dict, strict=False)
print(msg)
if torch.cuda.is_available():
    model.cuda()
model.eval()

@torch.no_grad()
def visual_recon(x, model, mask_ratio=0.75):
    # Per-patch mean and variance of the input are needed to undo the normalized-pixel target.
    target = model.patchify(x)
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)

    # Encode the visible patches, then decode to predict the masked ones.
    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    y, _ = model.forward_decoder(latent, ids_restore)
    y = y * (var + 1.e-6) ** .5 + mean  # de-normalize predictions back to pixel space
    y = model.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y)
    
    # Broadcast the binary patch mask to pixel space; 1 marks removed patches, 0 marks kept ones.
    mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[0] ** 2 * 3)  # (N, H*W, p*p*3)
    mask = model.unpatchify(mask)
    mask = torch.einsum('nchw->nhwc', mask)
    
    x = torch.einsum('nchw->nhwc', x)

    # Returns: masked input, masked input with reconstruction pasted in, full reconstruction, latent.
    return x * (1 - mask), x * (1 - mask) + y * mask, y, latent

@torch.no_grad()
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
    assert latent.shape[0] == 1, 'can only caption one image at a time'
    
    # Tokenize the prefix and drop the trailing [SEP] so the model predicts the next token.
    x_l = torch.tensor(tokenizer([prefix])['input_ids'])[:, :-1]
    seq = x_l.shape[1]
    if torch.cuda.is_available():
        x_l = x_l.cuda()

    cls_mask = rearrange(x_l != 0, 'b j -> b 1 j')
    attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

    # Embed the text tokens, then fuse them with the image latent via the multimodal layers.
    x_l = model.embed_text(x_l)

    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)

    pred = model.to_logits(x_l)
    # Suppress special tokens so they are never predicted:
    # 0 = [PAD], 100 = [UNK], 101 = [CLS], 103 = [MASK] in the bert-base-uncased vocab.
    pred[:, :, 103] = -100
    pred[:, :, 101] = -100
    pred[:, :, 100] = -100
    pred[:, :, 0] = -100
    # Greedily take the most likely token at the last position.
    next_word = pred.argmax(dim=-1)[0, -1]
    next_word = tokenizer.decode(next_word)
    
    return next_word

def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
    # Greedy word-by-word decoding: re-run the captioner on the growing prefix
    # until [SEP] is produced or max_len words are reached.
    words = prefix.split()
    while len(words) < max_len:
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        words.append(next_word)
        if next_word == '[SEP]':
            break
    return ' '.join(words)


def gr_caption(x, mask_ratio=0.75, max_len=20, prefix='a'):
    # Normalize the input image with ImageNet statistics and move it to NCHW.
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    x = np.array(x) / 255.
    x = (x - imagenet_mean) / imagenet_std

    x = torch.tensor(x).float()
    x = x.unsqueeze(0)
    x = torch.einsum('nhwc->nchw', x)
    if torch.cuda.is_available():
        x = x.cuda()
        
    def unnorm_pix(img):
        # Undo the ImageNet normalization and clamp to the displayable [0, 1] range.
        img = img.squeeze(0).cpu().detach().numpy()
        img = img * imagenet_std + imagenet_mean
        return np.clip(img, a_min=0., a_max=1.)

    masked, masked_recon, recon, latent = visual_recon(x, model, mask_ratio=mask_ratio)
    caption_from_model = caption(max_len, latent, model, tokenizer, prefix=prefix)
    
    # Stitch the masked input, pasted reconstruction, and full reconstruction side by side.
    masked, masked_recon, recon = map(unnorm_pix, (masked, masked_recon, recon))
    return_img = np.concatenate([masked, masked_recon, recon], axis=1)
    
    return return_img, caption_from_model

import gradio as gr

# Wire everything into a Gradio interface: image, mask ratio, max caption length,
# and caption prefix in; side-by-side reconstruction and caption text out.
demo = gr.Interface(gr_caption,
                    inputs=[gr.Image(value='cat.jpeg', shape=(224, 224)),
                            gr.Number(value=0.75, label='mask ratio'),
                            gr.Number(value=20, label='max length'),
                            gr.Textbox(value='a photo of a', label='caption prefix')], 
                    outputs=[gr.Image(shape=(224, 224 * 3)), 
                             'text'],
                    # examples=[['cat.jpeg', 0.75, 20, 'a photo of a']],
                )
demo.launch()