vijul.shah committed on
Commit 0f2d9f6 · 1 Parent(s): 51ba5d6

End-to-End Pipeline Configured

.gitignore CHANGED
@@ -1 +1 @@
- __pycache__/
+ __pycache__
SR_Inference/codeformer/codeformer_arch.py ADDED
@@ -0,0 +1,271 @@
import math
import torch
from .vqgan_arch import *
from typing import Optional
from torch import nn, Tensor
import torch.nn.functional as F

def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have similar color and illumination
    to those of the degraded features.

    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention Is All You Need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of the feedforward model (MLP)
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

        # self-attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # FFN
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt

class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2 * in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out

class CodeFormerArch(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize', 'generator'], vqgan_path=None):
        super(CodeFormerArch, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)

        if vqgan_path is not None:
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd * 2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after second residual block for > 16, before attn layer for == 16
        self.fuse_encoder_block = {'512': 2, '256': 5, '128': 8, '64': 11, '32': 14, '16': 18}
        # after first residual block for > 16, before attn layer for == 16
        self.fuse_generator_block = {'16': 6, '32': 9, '64': 12, '128': 15, '256': 18, '512': 21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb)  # (hw)bn
        logits = logits.permute(1, 0, 2)  # (hw)bn -> b(hw)n

        if code_only:  # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0], 16, 16, 256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach()  # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list:  # fuse after i-th block
                f_size = str(x.shape[-1])
                if w > 0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
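
For orientation, a minimal usage sketch of CodeFormerArch as configured above (a sketch only, not part of the commit; the importable package path and the 'params_ema' checkpoint key are assumptions):

# Sketch: restore a 512x512 face crop with fidelity weight w (paths/keys below are assumptions).
import torch
from SR_Inference.codeformer.codeformer_arch import CodeFormerArch  # assumes the package is importable

net = CodeFormerArch(dim_embd=512, n_head=8, n_layers=9, codebook_size=1024,
                     connect_list=['32', '64', '128', '256']).eval()
# state = torch.load('SR_Inference/codeformer/weights/codeformer_v0.1.0.pth',
#                    map_location='cpu')['params_ema']  # key layout is an assumption
# net.load_state_dict(state, strict=False)

with torch.no_grad():
    lq = torch.rand(1, 3, 512, 512)                      # degraded input in [0, 1]
    restored, logits, lq_feat = net(lq, w=0.5, adain=True)
print(restored.shape, logits.shape)                      # (1, 3, 512, 512), (1, 256, 1024)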
SR_Inference/codeformer/vqgan_arch.py ADDED
@@ -0,0 +1,418 @@
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger

def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

@torch.jit.script
def swish(x):
    return x * torch.sigmoid(x)

# Define VQVAE classes
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        # [0-1], higher score, higher confidence
        # min_encoding_scores = torch.exp(-min_encoding_scores/10)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "mean_distance": mean_distance
        }

    def get_codebook_feat(self, indices, shape):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1, 1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q

class GumbelQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        hard = self.straight_through if self.training else True

        logits = self.proj(z)

        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)

        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {
            "min_encoding_indices": min_encoding_indices
        }

class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x

class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h * w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_

class Encoder(nn.Module):
    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,) + tuple(ch_mult)

        blocks = []
        # initial convolution
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention on smaller res (16x16)
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x

class Generator(nn.Module):
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions - 1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x

class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta  # 0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats

# patch based discriminator
class VQGANDiscriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        return self.main(x)
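
As a sanity check on the autoencoder added above, a minimal round-trip sketch using the same configuration that CodeFormerArch passes to VQAutoEncoder (illustrative only; the import path is an assumption):

# Sketch: encode -> quantize ('nearest' codebook lookup) -> decode at the CodeFormer settings.
import torch
from SR_Inference.codeformer.vqgan_arch import VQAutoEncoder  # assumes the package is importable

vqgan = VQAutoEncoder(img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8], quantizer='nearest',
                      res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256).eval()

with torch.no_grad():
    x = torch.rand(1, 3, 512, 512)
    recon, codebook_loss, stats = vqgan(x)
# five Downsample stages take 512 -> 16, so the quantized latent holds 16*16 = 256 codebook indices
print(recon.shape, stats['min_encoding_indices'].shape)  # (1, 3, 512, 512), (256, 1)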
SR_Inference/codeformer/weights/codeformer_v0.1.0.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1009e537e0c2a07d4cabce6355f53cb66767cd4b4297ec7a4a64ca4b8a5684b7
size 376637898
SR_Inference/gfpgan/weights/GFPGANv1.3.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70
size 348632874
SR_Inference/gfpgan/weights/detection_Resnet50_Final.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
size 109497761
SR_Inference/gfpgan/weights/parsing_parsenet.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
size 85331193
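
The four .pth entries above are Git LFS pointer files (version/oid/size), not the binary weights themselves; the real files have to be fetched (for example with git lfs pull) before inference. A small check sketch, with an illustrative helper name:

# Sketch: detect LFS pointer stubs among the checkpoint files before trying to torch.load them.
from pathlib import Path

def is_lfs_pointer(path):
    head = Path(path).read_bytes()[:48]
    return head.startswith(b'version https://git-lfs.github.com/spec/')

for p in Path('SR_Inference').rglob('*.pth'):
    if is_lfs_pointer(p):
        print(f'{p}: LFS pointer only - run `git lfs pull` to fetch the actual weights')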
SR_Inference/hat/hat_arch.py ADDED
@@ -0,0 +1,979 @@
import math
import torch
import torch.nn as nn
from einops import rearrange
from basicsr.archs.arch_util import to_2tuple, trunc_normal_

def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class ChannelAttention(nn.Module):
    """Channel attention used in RCAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y


class CAB(nn.Module):

    def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
        super(CAB, self).__init__()

        self.cab = nn.Sequential(
            nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor)
        )

    def forward(self, x):
        return self.cab(x)


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (b, h, w, c)
        window_size (int): window size

    Returns:
        windows: (num_windows*b, window_size, window_size, c)
    """
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
    return windows


def window_reverse(windows, window_size, h, w):
    """
    Args:
        windows: (num_windows*b, window_size, window_size, c)
        window_size (int): Window size
        h (int): Height of image
        w (int): Width of image

    Returns:
        x: (b, h, w, c)
    """
    b = int(windows.shape[0] / (h * w / window_size / window_size))
    x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, rpi, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*b, n, c)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        b_, n, c = x.shape
        qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nw = mask.shape[0]
            attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class HAB(nn.Module):
    r""" Hybrid Attention Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 compress_ratio=3,
                 squeeze_factor=30,
                 conv_scale=0.01,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in 0-window_size'

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.conv_scale = conv_scale
        self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, x_size, rpi_sa, attn_mask):
        h, w = x_size
        b, _, c = x.shape
        # assert seq_len == h * w, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(b, h, w, c)

        # Conv_X
        conv_x = self.conv_block(x.permute(0, 3, 1, 2))
        conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            attn_mask = attn_mask
        else:
            shifted_x = x
            attn_mask = None

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nw*b, window_size, window_size, c
        x_windows = x_windows.view(-1, self.window_size * self.window_size, c)  # nw*b, window_size*window_size, c

        # W-MSA/SW-MSA (compatible with testing on images whose shapes are a multiple of the window size)
        attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
        shifted_x = window_reverse(attn_windows, self.window_size, h, w)  # b h' w' c

        # reverse cyclic shift
        if self.shift_size > 0:
            attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attn_x = shifted_x
        attn_x = attn_x.view(b, h * w, c)

        # FFN
        x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: b, h*w, c
        """
        h, w = self.input_resolution
        b, seq_len, c = x.shape
        assert seq_len == h * w, 'input feature has wrong size'
        assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'

        x = x.view(b, h, w, c)

        x0 = x[:, 0::2, 0::2, :]  # b h/2 w/2 c
        x1 = x[:, 1::2, 0::2, :]  # b h/2 w/2 c
        x2 = x[:, 0::2, 1::2, :]  # b h/2 w/2 c
        x3 = x[:, 1::2, 1::2, :]  # b h/2 w/2 c
        x = torch.cat([x0, x1, x2, x3], -1)  # b h/2 w/2 4*c
        x = x.view(b, -1, 4 * c)  # b h/2*w/2 4*c

        x = self.norm(x)
        x = self.reduction(x)

        return x


class OCAB(nn.Module):
    # overlapping cross-attention block

    def __init__(self, dim,
                 input_resolution,
                 window_size,
                 overlap_ratio,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 mlp_ratio=2,
                 norm_layer=nn.LayerNorm
                 ):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.overlap_win_size = int(window_size * overlap_ratio) + window_size

        self.norm1 = norm_layer(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size - window_size) // 2)

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), num_heads))

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

        self.proj = nn.Linear(dim, dim)

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)

    def forward(self, x, x_size, rpi):
        h, w = x_size
        b, _, c = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(b, h, w, c)

        qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2)  # 3, b, c, h, w
        q = qkv[0].permute(0, 2, 3, 1)  # b, h, w, c
        kv = torch.cat((qkv[1], qkv[2]), dim=1)  # b, 2*c, h, w

        # partition windows
        q_windows = window_partition(q, self.window_size)  # nw*b, window_size, window_size, c
        q_windows = q_windows.view(-1, self.window_size * self.window_size, c)  # nw*b, window_size*window_size, c

        kv_windows = self.unfold(kv)  # b, c*w*w, nw
        kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch', nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous()  # 2, nw*b, ow*ow, c
        k_windows, v_windows = kv_windows[0], kv_windows[1]  # nw*b, ow*ow, c

        b_, nq, _ = q_windows.shape
        _, n, _ = k_windows.shape
        d = self.dim // self.num_heads
        q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3)  # nw*b, nH, nq, d
        k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)  # nw*b, nH, n, d
        v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)  # nw*b, nH, n, d

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
            self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, -1)  # ws*ws, wse*wse, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, ws*ws, wse*wse
        attn = attn + relative_position_bias.unsqueeze(0)

        attn = self.softmax(attn)
        attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
        x = window_reverse(attn_windows, self.window_size, h, w)  # b h w c
        x = x.view(b, h * w, self.dim)

        x = self.proj(x) + shortcut

        x = x + self.mlp(self.norm2(x))
        return x


class AttenBlocks(nn.Module):
    """ A series of attention blocks for one RHAG.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 compress_ratio,
                 squeeze_factor,
                 conv_scale,
                 overlap_ratio,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            HAB(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                compress_ratio=compress_ratio,
                squeeze_factor=squeeze_factor,
                conv_scale=conv_scale,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer) for i in range(depth)
        ])

        # OCAB
        self.overlap_attn = OCAB(
            dim=dim,
            input_resolution=input_resolution,
            window_size=window_size,
            overlap_ratio=overlap_ratio,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size, params):
        for blk in self.blocks:
            x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])

        x = self.overlap_attn(x, x_size, params['rpi_oca'])

        if self.downsample is not None:
            x = self.downsample(x)
        return x


class RHAG(nn.Module):
    """Residual Hybrid Attention Group (RHAG).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before residual connection.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 compress_ratio,
                 squeeze_factor,
                 conv_scale,
                 overlap_ratio,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 img_size=224,
                 patch_size=4,
                 resi_connection='1conv'):
        super(RHAG, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = AttenBlocks(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            compress_ratio=compress_ratio,
            squeeze_factor=squeeze_factor,
            conv_scale=conv_scale,
            overlap_ratio=overlap_ratio,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == 'identity':
            self.conv = nn.Identity()

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)

    def forward(self, x, x_size, params):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # b Ph*Pw c
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1])  # b Ph*Pw c
        return x


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


class HATArch(nn.Module):
    r""" Hybrid Attention Transformer
    A PyTorch implementation of: `Activating More Pixels in Image Super-Resolution Transformer`.
    Some code is based on SwinIR.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
    """

    def __init__(self,
                 img_size=64,
                 patch_size=1,
                 in_chans=3,
                 embed_dim=96,
                 depths=(6, 6, 6, 6),
                 num_heads=(6, 6, 6, 6),
                 window_size=7,
                 compress_ratio=3,
                 squeeze_factor=30,
                 conv_scale=0.01,
                 overlap_ratio=0.5,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 upscale=2,
                 img_range=1.,
                 upsampler='',
                 resi_connection='1conv',
                 **kwargs):
        super(HATArch, self).__init__()

        self.window_size = window_size
        self.shift_size = window_size // 2
        self.overlap_ratio = overlap_ratio

        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler

        # relative position index
        relative_position_index_SA = self.calculate_rpi_sa()
        relative_position_index_OCA = self.calculate_rpi_oca()
        self.register_buffer('relative_position_index_SA', relative_position_index_SA)
        self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)

        # ------------------------- 1, shallow feature extraction ------------------------- #
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # ------------------------- 2, deep feature extraction ------------------------- #
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Hybrid Attention Groups (RHAG)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RHAG(
                dim=embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                compress_ratio=compress_ratio,
                squeeze_factor=squeeze_factor,
                conv_scale=conv_scale,
                overlap_ratio=overlap_ratio,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection)
849
+ self.layers.append(layer)
850
+ self.norm = norm_layer(self.num_features)
851
+
852
+ # build the last conv layer in deep feature extraction
853
+ if resi_connection == '1conv':
854
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
855
+ elif resi_connection == 'identity':
856
+ self.conv_after_body = nn.Identity()
857
+
858
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
859
+ if self.upsampler == 'pixelshuffle':
860
+ # for classical SR
861
+ self.conv_before_upsample = nn.Sequential(
862
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
863
+ self.upsample = Upsample(upscale, num_feat)
864
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
865
+
866
+ self.apply(self._init_weights)
867
+
868
+ def _init_weights(self, m):
869
+ if isinstance(m, nn.Linear):
870
+ trunc_normal_(m.weight, std=.02)
871
+ if isinstance(m, nn.Linear) and m.bias is not None:
872
+ nn.init.constant_(m.bias, 0)
873
+ elif isinstance(m, nn.LayerNorm):
874
+ nn.init.constant_(m.bias, 0)
875
+ nn.init.constant_(m.weight, 1.0)
876
+
877
+ def calculate_rpi_sa(self):
878
+ # calculate relative position index for SA
879
+ coords_h = torch.arange(self.window_size)
880
+ coords_w = torch.arange(self.window_size)
881
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
882
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
883
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
884
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
885
+ relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
886
+ relative_coords[:, :, 1] += self.window_size - 1
887
+ relative_coords[:, :, 0] *= 2 * self.window_size - 1
888
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
889
+ return relative_position_index
890
+
891
+ def calculate_rpi_oca(self):
892
+ # calculate relative position index for OCA
893
+ window_size_ori = self.window_size
894
+ window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
895
+
896
+ coords_h = torch.arange(window_size_ori)
897
+ coords_w = torch.arange(window_size_ori)
898
+ coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws
899
+ coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws
900
+
901
+ coords_h = torch.arange(window_size_ext)
902
+ coords_w = torch.arange(window_size_ext)
903
+ coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse
904
+ coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse
905
+
906
+ relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None] # 2, ws*ws, wse*wse
907
+
908
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # ws*ws, wse*wse, 2
909
+ relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1 # shift to start from 0
910
+ relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
911
+
912
+ relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
913
+ relative_position_index = relative_coords.sum(-1)
914
+ return relative_position_index
915
+
916
+ def calculate_mask(self, x_size):
917
+ # calculate attention mask for SW-MSA
918
+ h, w = x_size
919
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
920
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
921
+ -self.shift_size), slice(-self.shift_size, None))
922
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
923
+ -self.shift_size), slice(-self.shift_size, None))
924
+ cnt = 0
925
+ for h in h_slices:
926
+ for w in w_slices:
927
+ img_mask[:, h, w, :] = cnt
928
+ cnt += 1
929
+
930
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
931
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
932
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
933
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
934
+
935
+ return attn_mask
936
+
937
+ @torch.jit.ignore
938
+ def no_weight_decay(self):
939
+ return {'absolute_pos_embed'}
940
+
941
+ @torch.jit.ignore
942
+ def no_weight_decay_keywords(self):
943
+ return {'relative_position_bias_table'}
944
+
945
+ def forward_features(self, x):
946
+ x_size = (x.shape[2], x.shape[3])
947
+
948
+ # Calculate attention mask and relative position index in advance to speed up inference.
949
+ # The original code is very time-consuming for large window size.
950
+ attn_mask = self.calculate_mask(x_size).to(x.device)
951
+ params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
952
+
953
+ x = self.patch_embed(x)
954
+ if self.ape:
955
+ x = x + self.absolute_pos_embed
956
+ x = self.pos_drop(x)
957
+
958
+ for layer in self.layers:
959
+ x = layer(x, x_size, params)
960
+
961
+ x = self.norm(x) # b seq_len c
962
+ x = self.patch_unembed(x, x_size)
963
+
964
+ return x
965
+
966
+ def forward(self, x):
967
+ self.mean = self.mean.type_as(x)
968
+ x = (x - self.mean) * self.img_range
969
+
970
+ if self.upsampler == 'pixelshuffle':
971
+ # for classical SR
972
+ x = self.conv_first(x)
973
+ x = self.conv_after_body(self.forward_features(x)) + x
974
+ x = self.conv_before_upsample(x)
975
+ x = self.conv_last(self.upsample(x))
976
+
977
+ x = x / self.img_range + self.mean
978
+
979
+ return x
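Note on the Upsample block added above: for a 2^n scale it stacks n rounds of a 3x3 convolution that expands channels 4x followed by PixelShuffle(2), so spatial resolution doubles per round while the channel count is preserved. A minimal sketch of a single x2 step (the tensor shapes are illustrative, not taken from this commit):

    import torch
    import torch.nn as nn

    num_feat = 64                                    # intermediate channel count, as in HATArch above
    upsample_x2 = nn.Sequential(
        nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1),  # expand channels 4x
        nn.PixelShuffle(2),                          # trade channels for 2x spatial size
    )
    feat = torch.randn(1, num_feat, 30, 40)          # b, c, h, w
    print(upsample_x2(feat).shape)                   # torch.Size([1, 64, 60, 80])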
SR_Inference/hat/weights/HAT-L_SRx2_ImageNet-pretrain.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2818c7ca8d72ec4cc5f31c93203d55252a662dd35cda34ce1a69661f97dcd38f
3
+ size 165182573
SR_Inference/hat/weights/HAT_SRx2_ImageNet-pretrain.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82ebd911263bcc886fbef46b30cf97b92a932a27a3cba30163d4577afb09b9d7
3
+ size 84546053
SR_Inference/hat/weights/HAT_SRx4_ImageNet-pretrain.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ee053c42461187846dc0e93aa5abd34591c0725a8e044a59000e92ee215e833
3
+ size 85137601
SR_Inference/inference_codeformer.py ADDED
@@ -0,0 +1,126 @@
1
+ import os
2
+ import cv2
3
+ import sys
4
+ import torch
5
+ import os.path as osp
6
+ from basicsr.utils import img2tensor, tensor2img
7
+ from torchvision.transforms.functional import normalize
8
+ from facexlib.utils.face_restoration_helper import FaceRestoreHelper
9
+
10
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11
+
12
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
13
+ sys.path.append(root_path)
14
+ from SR_Inference.codeformer.codeformer_arch import CodeFormerArch
15
+ from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo
16
+
17
+
18
+ class CodeFormer:
19
+
20
+ def __init__(
21
+ self,
22
+ upscale=2,
23
+ bg_upsampler_name="realesrgan",
24
+ prefered_net_in_upsampler="RRDBNet",
25
+ fidelity_weight=0.8,
26
+ ):
27
+
28
+ self.upscale = int(upscale)
29
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ self.fidelity_weight = fidelity_weight
31
+
32
+ # ------------------------ set up background upsampler ------------------------
33
+ upsampler_zoo = RealEsrUpsamplerZoo(
34
+ upscale=self.upscale,
35
+ bg_upsampler_name=bg_upsampler_name,
36
+ prefered_net_in_upsampler=prefered_net_in_upsampler,
37
+ )
38
+ self.bg_upsampler = upsampler_zoo.bg_upsampler
39
+
40
+ # ------------------ set up FaceRestoreHelper -------------------
41
+ gfpgan_weights_path = os.path.join(
42
+ ROOT_DIR, "SR_Inference", "gfpgan", "weights"
43
+ )
44
+ self.face_restorer_helper = FaceRestoreHelper(
45
+ upscale_factor=self.upscale,
46
+ face_size=512,
47
+ crop_ratio=(1, 1),
48
+ det_model="retinaface_resnet50",
49
+ save_ext="png",
50
+ use_parse=True,
51
+ device=self.device,
52
+ # model_rootpath="gfpgan/weights",
53
+ model_rootpath=gfpgan_weights_path,
54
+ )
55
+
56
+ # ------------------ load model -------------------
57
+ self.sr_model = CodeFormerArch().to(self.device)
58
+ ckpt_path = os.path.join(
59
+ ROOT_DIR, "SR_Inference", "codeformer", "weights", "codeformer_v0.1.0.pth"
60
+ )
61
+ loadnet = torch.load(ckpt_path, map_location=self.device)
62
+ if "params_ema" in loadnet:
63
+ keyname = "params_ema"
64
+ else:
65
+ keyname = "params"
66
+
67
+ self.sr_model.load_state_dict(loadnet[keyname])
68
+ self.sr_model.eval()
69
+
70
+ @torch.no_grad()
71
+ def __call__(self, img):
72
+
73
+ bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
74
+
75
+ self.face_restorer_helper.clean_all()
76
+ self.face_restorer_helper.read_image(img)
77
+ self.face_restorer_helper.get_face_landmarks_5(
78
+ only_keep_largest=True, only_center_face=False, eye_dist_threshold=5
79
+ )
80
+ self.face_restorer_helper.align_warp_face()
81
+
82
+ if len(self.face_restorer_helper.cropped_faces) > 0:
83
+
84
+ cropped_face = self.face_restorer_helper.cropped_faces[0]
85
+
86
+ cropped_face_t = img2tensor(
87
+ imgs=cropped_face / 255.0, bgr2rgb=True, float32=True
88
+ )
89
+ normalize(
90
+ tensor=cropped_face_t,
91
+ mean=(0.5, 0.5, 0.5),
92
+ std=(0.5, 0.5, 0.5),
93
+ inplace=True,
94
+ )
95
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
96
+
97
+ # ------------------- restore/enhance image using CodeFormerArch model -------------------
98
+ output = self.sr_model(cropped_face_t, w=self.fidelity_weight, adain=True)[
99
+ 0
100
+ ]
101
+
102
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
103
+ restored_face = restored_face.astype("uint8")
104
+
105
+ self.face_restorer_helper.add_restored_face(restored_face)
106
+ self.face_restorer_helper.get_inverse_affine(None)
107
+
108
+ sr_img = self.face_restorer_helper.paste_faces_to_input_image(
109
+ upsample_img=bg_img
110
+ )
111
+ else:
112
+ sr_img = bg_img
113
+
114
+ return sr_img
115
+
116
+
117
+ if __name__ == "__main__":
118
+
119
+ codeformer = CodeFormer(upscale=2, fidelity_weight=1.0)
120
+
121
+ img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
122
+ sr_img = codeformer(img=img)
123
+
124
+ saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
125
+ os.makedirs(saving_dir, exist_ok=True)
126
+ cv2.imwrite(f"{saving_dir}/sr_img.png", sr_img)
SR_Inference/inference_gfpgan.py ADDED
@@ -0,0 +1,76 @@
1
+ import os
2
+ import cv2
3
+ import sys
4
+ import torch
5
+ import os.path as osp
6
+ from gfpgan import GFPGANer
7
+ from basicsr.utils.download_util import load_file_from_url
8
+
9
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10
+
11
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
12
+ sys.path.append(root_path)
13
+ from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo
14
+
15
+
16
+ class GFPGAN:
17
+
18
+ def __init__(
19
+ self,
20
+ upscale=2,
21
+ bg_upsampler_name="realesrgan",
22
+ prefered_net_in_upsampler="RRDBNet",
23
+ ):
24
+
25
+ upscale = int(upscale)
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+
28
+ # ------------------------ set up background upsampler ------------------------
29
+ upsampler_zoo = RealEsrUpsamplerZoo(
30
+ upscale=upscale,
31
+ bg_upsampler_name=bg_upsampler_name,
32
+ prefered_net_in_upsampler=prefered_net_in_upsampler,
33
+ )
34
+ bg_upsampler = upsampler_zoo.bg_upsampler
35
+
36
+ # ------------------------ load model ------------------------
37
+ gfpgan_weights_path = os.path.join(
38
+ ROOT_DIR, "SR_Inference", "gfpgan", "weights"
39
+ )
40
+ gfpgan_model_path = os.path.join(gfpgan_weights_path, "GFPGANv1.3.pth")
41
+
42
+ if not os.path.isfile(gfpgan_model_path):
43
+ url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
44
+ gfpgan_model_path = load_file_from_url(
45
+ url=url,
46
+ model_dir=gfpgan_weights_path,
47
+ progress=True,
48
+ file_name="GFPGANv1.3.pth",
49
+ )
50
+
51
+ self.sr_model = GFPGANer(
52
+ upscale=upscale,
53
+ bg_upsampler=bg_upsampler,
54
+ model_path=gfpgan_model_path,
55
+ device=device,
56
+ )
57
+
58
+ def __call__(self, img):
59
+ # ------------------------ restore/enhance image using GFPGAN model ------------------------
60
+ cropped_faces, sr_faces, sr_img = self.sr_model.enhance(img)
61
+
62
+ return sr_img
63
+
64
+
65
+ if __name__ == "__main__":
66
+
67
+ gfpgan = GFPGAN(
68
+ upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
69
+ )
70
+
71
+ img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
72
+ sr_img = gfpgan(img=img)
73
+
74
+ saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
75
+ os.makedirs(saving_dir, exist_ok=True)
76
+ cv2.imwrite(f"{saving_dir}/sr_img_gfpgan.png", sr_img)
SR_Inference/inference_hat.py ADDED
@@ -0,0 +1,104 @@
1
+ import os
2
+ import cv2
3
+ import sys
4
+ import torch
5
+ import numpy as np
6
+ import os.path as osp
7
+ from PIL import Image
8
+ from basicsr.utils import img2tensor
9
+
10
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11
+
12
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
13
+ sys.path.append(root_path)
14
+ from SR_Inference.hat.hat_arch import HATArch
15
+
16
+
17
+ class HAT:
18
+
19
+ def __init__(
20
+ self,
21
+ upscale=2,
22
+ in_chans=3,
23
+ img_size=(480, 640),
24
+ window_size=16,
25
+ compress_ratio=3,
26
+ squeeze_factor=30,
27
+ conv_scale=0.01,
28
+ overlap_ratio=0.5,
29
+ img_range=1.0,
30
+ depths=[6, 6, 6, 6, 6, 6],
31
+ embed_dim=180,
32
+ num_heads=[6, 6, 6, 6, 6, 6],
33
+ mlp_ratio=2,
34
+ upsampler="pixelshuffle",
35
+ resi_connection="1conv",
36
+ ):
37
+ upscale = int(upscale)
38
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+
40
+ # ------------------ load model for img enhancement -------------------
41
+ self.sr_model = HATArch(
42
+ img_size=img_size,
43
+ upscale=upscale,
44
+ in_chans=in_chans,
45
+ window_size=window_size,
46
+ compress_ratio=compress_ratio,
47
+ squeeze_factor=squeeze_factor,
48
+ conv_scale=conv_scale,
49
+ overlap_ratio=overlap_ratio,
50
+ img_range=img_range,
51
+ depths=depths,
52
+ embed_dim=embed_dim,
53
+ num_heads=num_heads,
54
+ mlp_ratio=mlp_ratio,
55
+ upsampler=upsampler,
56
+ resi_connection=resi_connection,
57
+ ).to(self.device)
58
+
59
+ ckpt_path = os.path.join(
60
+ ROOT_DIR,
61
+ "SR_Inference",
62
+ "hat",
63
+ "weights",
64
+ f"HAT_SRx{str(upscale)}_ImageNet-pretrain.pth",
65
+ )
66
+ loadnet = torch.load(ckpt_path, map_location=self.device)
67
+ if "params_ema" in loadnet:
68
+ keyname = "params_ema"
69
+ else:
70
+ keyname = "params"
71
+
72
+ self.sr_model.load_state_dict(loadnet[keyname])
73
+ self.sr_model.eval()
74
+
75
+ @torch.no_grad()
76
+ def __call__(self, img):
77
+ img_tensor = (
78
+ img2tensor(imgs=img / 255.0, bgr2rgb=True, float32=True)
79
+ .unsqueeze(0)
80
+ .to(self.device)
81
+ )
82
+ restored_img = self.sr_model(img_tensor)[0]
83
+ restored_img = restored_img.permute(1, 2, 0).cpu().numpy()
84
+ restored_img = (restored_img - restored_img.min()) / (
85
+ restored_img.max() - restored_img.min()
86
+ )
87
+ restored_img = (restored_img * 255).astype(np.uint8)
88
+ restored_img = Image.fromarray(restored_img)
89
+ restored_img = np.array(restored_img)
90
+ sr_img = cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
91
+
92
+ return sr_img
93
+
94
+
95
+ if __name__ == "__main__":
96
+
97
+ hat = HAT(upscale=2)
98
+
99
+ img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
100
+ sr_img = hat(img=img)
101
+
102
+ saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
103
+ os.makedirs(saving_dir, exist_ok=True)
104
+ cv2.imwrite(f"{saving_dir}/sr_img_hat.png", sr_img)
SR_Inference/inference_realesr.py ADDED
@@ -0,0 +1,52 @@
1
+ import os
2
+ import cv2
3
+ import sys
4
+ import torch
5
+ import os.path as osp
6
+
7
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8
+
9
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
10
+ sys.path.append(root_path)
11
+ from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo
12
+
13
+
14
+ class RealEsr:
15
+
16
+ def __init__(
17
+ self,
18
+ upscale=2,
19
+ bg_upsampler_name="realesrgan",
20
+ prefered_net_in_upsampler="RRDBNet",
21
+ ):
22
+
23
+ self.upscale = int(upscale)
24
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+
26
+ # ------------------------ set up background upsampler ------------------------
27
+ self.upsampler_zoo = RealEsrUpsamplerZoo(
28
+ upscale=self.upscale,
29
+ bg_upsampler_name=bg_upsampler_name,
30
+ prefered_net_in_upsampler=prefered_net_in_upsampler,
31
+ )
32
+ self.bg_upsampler = self.upsampler_zoo.bg_upsampler
33
+
34
+ def __call__(self, img):
35
+ # ---------------- restore/enhance image using the selected RealESR model ----------------
36
+ sr_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
37
+
38
+ return sr_img
39
+
40
+
41
+ if __name__ == "__main__":
42
+
43
+ realesr = RealEsr(
44
+ upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
45
+ )
46
+
47
+ img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
48
+ sr_img = realesr(img=img)
49
+
50
+ saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
51
+ os.makedirs(saving_dir, exist_ok=True)
52
+ cv2.imwrite(f"{saving_dir}/sr_img.png", sr_img)
SR_Inference/inference_sr_utils.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ from realesrgan import RealESRGANer
5
+ from basicsr.archs.rrdbnet_arch import RRDBNet
6
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
7
+ from basicsr.utils.download_util import load_file_from_url
8
+
9
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10
+
11
+
12
+ class RealEsrUpsamplerZoo:
13
+
14
+ def __init__(
15
+ self,
16
+ upscale=2,
17
+ bg_upsampler_name="realesrgan",
18
+ prefered_net_in_upsampler="RRDBNet",
19
+ ):
20
+
21
+ self.upscale = int(upscale)
22
+
23
+ # ------------------------ set up background upsampler ------------------------
24
+ weights_path = os.path.join(
25
+ ROOT_DIR, "SR_Inference", f"{bg_upsampler_name}", "weights"
26
+ )
27
+
28
+ if bg_upsampler_name == "realesrgan":
29
+ model = self.get_prefered_net(prefered_net_in_upsampler, upscale)
30
+ if self.upscale == 2:
31
+ model_path = os.path.join(weights_path, "RealESRGAN_x2plus.pth")
32
+ url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
33
+ elif self.upscale == 4:
34
+ model_path = os.path.join(weights_path, "RealESRGAN_x4plus.pth")
35
+ url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
36
+ else:
37
+ raise Exception(
38
+ f"{bg_upsampler_name} model not available for upscaling x{str(self.upscale)}"
39
+ )
40
+ elif bg_upsampler_name == "realesrnet":
41
+ model = self.get_prefered_net(prefered_net_in_upsampler, upscale)
42
+ if self.upscale == 4:
43
+ model_path = os.path.join(weights_path, "RealESRNet_x4plus.pth")
44
+ url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
45
+ else:
46
+ raise Exception(
47
+ f"{bg_upsampler_name} model not available for upscaling x{str(self.upscale)}"
48
+ )
49
+ elif bg_upsampler_name == "anime":
50
+ model = self.get_prefered_net(prefered_net_in_upsampler, upscale)
51
+ if self.upscale == 4:
52
+ model_path = os.path.join(
53
+ weights_path, "RealESRGAN_x4plus_anime_6B.pth"
54
+ )
55
+ url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
56
+ else:
57
+ raise Exception(
58
+ f"{bg_upsampler_name} model not available for upscaling x{str(self.upscale)}"
59
+ )
60
+ else:
61
+ raise Exception(f"No model implemented for: {bg_upsampler_name}")
62
+
63
+ # ------------------------ load background upsampler model ------------------------
64
+ if not os.path.isfile(model_path):
65
+ model_path = load_file_from_url(
66
+ url=url, model_dir=weights_path, progress=True, file_name=None
67
+ )
68
+
69
+ self.bg_upsampler = RealESRGANer(
70
+ scale=int(upscale),
71
+ model_path=model_path,
72
+ model=model,
73
+ tile=0,
74
+ tile_pad=0,
75
+ pre_pad=0,
76
+ half=False,
77
+ )
78
+
79
+ @staticmethod
80
+ def get_prefered_net(prefered_net_in_upsampler, upscale=2):
81
+ if prefered_net_in_upsampler == "RRDBNet":
82
+ model = RRDBNet(
83
+ num_in_ch=3,
84
+ num_out_ch=3,
85
+ num_feat=64,
86
+ num_block=23,
87
+ num_grow_ch=32,
88
+ scale=int(upscale),
89
+ )
90
+ elif prefered_net_in_upsampler == "SRVGGNetCompact":
91
+ model = SRVGGNetCompact(
92
+ num_in_ch=3,
93
+ num_out_ch=3,
94
+ num_feat=64,
95
+ num_conv=16,
96
+ upscale=int(upscale),
97
+ act_type="prelu",
98
+ )
99
+ else:
100
+ raise Exception(f"No net named: {prefered_net_in_upsampler} implemented!")
101
+ return model
SR_Inference/inference_srresnet.py ADDED
@@ -0,0 +1,78 @@
1
+ import os
2
+ import cv2
3
+ import sys
4
+ import torch
5
+ import numpy as np
6
+ import os.path as osp
7
+ from PIL import Image
8
+ from basicsr.utils import img2tensor
9
+ from basicsr.archs.srresnet_arch import MSRResNet
10
+
11
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12
+
13
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
14
+ sys.path.append(root_path)
15
+
16
+
17
+ class SRResNet:
18
+
19
+ def __init__(self, upscale=2, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16):
20
+
21
+ self.upscale = int(upscale)
22
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
+ # ------------------ load model for img enhancement -------------------
25
+ self.sr_model = MSRResNet(
26
+ upscale=self.upscale,
27
+ num_in_ch=num_in_ch,
28
+ num_out_ch=num_out_ch,
29
+ num_feat=num_feat,
30
+ num_block=num_block,
31
+ ).to(self.device)
32
+
33
+ ckpt_path = os.path.join(
34
+ ROOT_DIR,
35
+ "SR_Inference",
36
+ "srresnet",
37
+ "weights",
38
+ f"SRResNet_{str(self.upscale)}x.pth",
39
+ )
40
+ loadnet = torch.load(ckpt_path, map_location=self.device)
41
+ if "params_ema" in loadnet:
42
+ keyname = "params_ema"
43
+ else:
44
+ keyname = "params"
45
+
46
+ self.sr_model.load_state_dict(loadnet[keyname])
47
+ self.sr_model.eval()
48
+
49
+ @torch.no_grad()
50
+ def __call__(self, img):
51
+ img_tensor = (
52
+ img2tensor(imgs=img / 255.0, bgr2rgb=True, float32=True)
53
+ .unsqueeze(0)
54
+ .to(self.device)
55
+ )
56
+ restored_img = self.sr_model(img_tensor)[0]
57
+ restored_img = restored_img.permute(1, 2, 0).cpu().numpy()
58
+ restored_img = (restored_img - restored_img.min()) / (
59
+ restored_img.max() - restored_img.min()
60
+ )
61
+ restored_img = (restored_img * 255).astype(np.uint8)
62
+ restored_img = Image.fromarray(restored_img)
63
+ restored_img = np.array(restored_img)
64
+ sr_img = cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
65
+
66
+ return sr_img
67
+
68
+
69
+ if __name__ == "__main__":
70
+
71
+ srresnet = SRResNet(upscale=2)
72
+
73
+ img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
74
+ sr_img = srresnet(img=img)
75
+
76
+ saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
77
+ os.makedirs(saving_dir, exist_ok=True)
78
+ cv2.imwrite(f"{saving_dir}/sr_img_srresnet.png", sr_img)
SR_Inference/realesrgan/weights/RealESRGAN_x2plus.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49fafd45f8fd7aa8d31ab2a22d14d91b536c34494a5cfe31eb5d89c2fa266abb
3
+ size 67061725
SR_Inference/realesrgan/weights/RealESRGAN_x4plus.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1
3
+ size 67040989
SR_Inference/srresnet/weights/SRResNet_2x.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4d2a531ecd6e8f15cc5eccf4bca58cdfea69f76c09baa1b208694977b0f6f5e
3
+ size 5492202
SR_Inference/srresnet/weights/SRResNet_4x.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:112f2ec947bb497b0350b149a1e06c3f73de77497cec64ce0c7ba268b8398023
3
+ size 6083374
app.py CHANGED
@@ -4,8 +4,9 @@
4
  from io import BytesIO
5
  import os
6
  import sys
 
7
  import matplotlib.pyplot as plt
8
- import requests
9
  import streamlit as st
10
  import torch
11
  from PIL import Image
@@ -21,6 +22,7 @@ import os.path as osp
21
  root_path = osp.abspath(osp.join(__file__, osp.pardir))
22
  sys.path.append(root_path)
23
 
 
24
  from utils import get_model
25
  from registry_utils import import_registered_modules
26
 
@@ -39,15 +41,13 @@ CAM_METHODS = [
39
  # "LayerCAM",
40
  ]
41
  TV_MODELS = [
42
- "resnet18",
43
- # "resnet50",
44
- ]
45
- SR_METHODS = ["GFPGAN", "RealESRGAN", "SRResNet", "CodeFormer", "HAT"]
46
- UPSCALE = ["2", "3", "4"]
47
- LABEL_MAP = [
48
- "left_eye",
49
- "right_eye",
50
  ]
 
 
 
 
51
 
52
 
53
  @torch.no_grad()
@@ -79,150 +79,287 @@ def main():
79
 
80
  # Sidebar
81
  # File selection
82
- st.sidebar.title("Input selection")
83
  # Disabling warning
84
  st.set_option("deprecation.showfileUploaderEncoding", False)
85
  # Choose your own image
86
  uploaded_file = st.sidebar.file_uploader(
87
- "Upload files", type=["png", "jpeg", "jpg"]
88
  )
89
  if uploaded_file is not None:
90
- img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
 
 
 
 
91
 
92
- cols[0].image(img, use_column_width=True)
 
 
 
 
93
 
94
  # Model selection
95
- st.sidebar.title("Setup")
96
  tv_model = st.sidebar.selectbox(
97
  "Classification model",
98
  TV_MODELS,
99
- help="Supported models from Torchvision",
100
  )
101
 
102
- # class_choices = [
103
- # f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)
104
- # ]
105
- # class_selection = st.sidebar.selectbox(
106
- # "Class selection", ["Predicted class (argmax)", *class_choices]
 
 
 
 
 
107
  # )
108
 
109
- img_configs = {"img_size": [32, 64], "means": None, "stds": None}
110
- # For newline
111
  st.sidebar.write("\n")
112
 
113
- if st.sidebar.button("Compute CAM"):
114
  if uploaded_file is None:
115
  st.sidebar.error("Please upload an image first")
116
 
117
  else:
118
  with st.spinner("Analyzing..."):
 
 
 
 
 
 
119
 
120
- preprocess_steps = [transforms.ToTensor()]
 
 
 
 
 
121
 
122
- image_size = img_configs["img_size"]
123
- if image_size is not None:
124
- preprocess_steps.append(
125
- transforms.Resize(
126
- [image_size[0], image_size[-1]],
127
- interpolation=transforms.InterpolationMode.BICUBIC,
128
- antialias=True,
 
 
 
 
 
 
129
  )
 
 
 
 
 
 
130
  )
 
131
 
132
- means = img_configs["means"]
133
- stds = img_configs["stds"]
134
- if means is not None and stds is not None:
135
- preprocess_steps.append(transforms.Normalize(means, stds))
136
 
137
- preprocess_function = transforms.Compose(preprocess_steps)
138
- input_img = preprocess_function(img)
139
- input_img = input_img.unsqueeze(0).to(device="cpu")
140
-
141
- model_configs = {
142
- "model_path": root_path
143
- + "/pre_trained_models/ResNet18/left_eye.pt",
144
- "registered_model_name": "ResNet18",
145
- "num_classes": 1,
146
- }
147
- registered_model_name = model_configs["registered_model_name"]
148
- # default_layer = ""
149
- if tv_model is not None:
150
- with st.spinner("Loading model..."):
151
- model = _load_model(model_configs)
152
-
153
- if torch.cuda.is_available():
154
- model = model.cuda()
155
-
156
- if registered_model_name == "ResNet18":
157
- target_layer = model.resnet.layer4[-1].conv2
158
- elif registered_model_name == "ResNet50":
159
- target_layer = model.resnet.layer4[-1].conv3
160
- else:
161
- raise Exception(
162
- f"No target layer available for selected model: {registered_model_name}"
163
  )
164
 
165
- # target_layer = st.sidebar.text_input(
166
- # "Target layer",
167
- # default_layer,
168
- # help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")',
169
- # )
170
- cam_method = "CAM"
171
- # cam_method = st.sidebar.selectbox(
172
- # "CAM method",
173
- # CAM_METHODS,
174
- # help="The way your class activation map will be computed",
175
- # )
176
- if cam_method is not None:
177
- # cam_extractor = methods.__dict__[cam_method](
178
- # model,
179
- # target_layer=(
180
- # [s.strip() for s in target_layer.split("+")]
181
- # if len(target_layer) > 0
182
- # else None
183
- # ),
184
- # )
185
- cam_extractor = torchcam_methods.__dict__[cam_method](
186
- model,
187
- target_layer=target_layer,
188
- fc_layer=model.resnet.fc,
189
- input_shape=(3, 32, 64),
 
 
 
190
  )
191
- # with torch.no_grad():
192
- # if input_mask is not None:
193
- # out = self.model(input_img, input_mask)
194
- # else:
195
- # out = self.model(input_img)
196
- # activation_map = cam_extractor(class_idx=target_class)
197
-
198
- # Forward the image to the model
199
- out = model(input_img)
200
- print("out = ", out)
201
-
202
- # Select the target class
203
- # if class_selection == "Predicted class (argmax)":
204
- # class_idx = out.squeeze(0).argmax().item()
205
- # else:
206
- # class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
207
-
208
- # Retrieve the CAM
209
- # act_maps = cam_extractor(class_idx=target_class)
210
- act_maps = cam_extractor(0, out)
211
- # Fuse the CAMs if there are several
212
- activation_map = (
213
- act_maps[0]
214
- if len(act_maps) == 1
215
- else cam_extractor.fuse_cams(act_maps)
216
- )
217
-
218
- # Overlayed CAM
219
- fig, ax = plt.subplots()
220
- result = overlay_mask(
221
- img, to_pil_image(activation_map, mode="F"), alpha=0.5
222
- )
223
- ax.imshow(result)
224
- ax.axis("off")
225
- cols[-1].pyplot(fig)
226
 
227
 
228
  if __name__ == "__main__":
 
4
  from io import BytesIO
5
  import os
6
  import sys
7
+ import cv2
8
  import matplotlib.pyplot as plt
9
+ import numpy as np
10
  import streamlit as st
11
  import torch
12
  from PIL import Image
 
22
  root_path = osp.abspath(osp.join(__file__, osp.pardir))
23
  sys.path.append(root_path)
24
 
25
+ from preprocessing.dataset_creation import EyeDentityDatasetCreation
26
  from utils import get_model
27
  from registry_utils import import_registered_modules
28
 
 
41
  # "LayerCAM",
42
  ]
43
  TV_MODELS = [
44
+ "ResNet18",
45
+ "ResNet50",
 
 
 
 
 
 
46
  ]
47
+ SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
48
+ UPSCALE = [2, 4]
49
+ UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
50
+ LABEL_MAP = ["left_pupil", "right_pupil"]
51
 
52
 
53
  @torch.no_grad()
 
79
 
80
  # Sidebar
81
  # File selection
82
+ st.sidebar.title("Upload Face or Eye")
83
  # Disabling warning
84
  st.set_option("deprecation.showfileUploaderEncoding", False)
85
  # Choose your own image
86
  uploaded_file = st.sidebar.file_uploader(
87
+ "Upload Image", type=["png", "jpeg", "jpg"]
88
  )
89
  if uploaded_file is not None:
90
+ input_img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
91
+ # print("input_img before = ", input_img.size)
92
+ max_size = [input_img.size[0], input_img.size[1]]
93
+ cols[0].text(f"Input Image: {max_size[0]} x {max_size[1]}")
94
+ if input_img.size[0] == input_img.size[1] and input_img.size[0] >= 256:
95
+ max_size[0] = 256
96
+ max_size[1] = 256
97
+ else:
98
+ if input_img.size[0] >= 640:
99
+ max_size[0] = 640
100
+ elif input_img.size[0] < 64:
101
+ max_size[0] = 64
102
+ if input_img.size[1] >= 480:
103
+ max_size[1] = 480
104
+ elif input_img.size[1] < 32:
105
+ max_size[1] = 32
106
+ input_img.thumbnail((max_size[0], max_size[1])) # Bicubic resampling
107
+ # print("input_img after = ", input_img.size)
108
+ # cols[0].image(input_img)
109
+ fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
110
+ # Display the input image
111
+ axs0.imshow(input_img)
112
+ axs0.axis("off")
113
+ axs0.set_title("Input Image")
114
+
115
+ # Display the plot
116
+ cols[0].pyplot(fig0)
117
+ cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
118
+
119
+ st.sidebar.title("Setup")
120
 
121
+ # Upscale selection
122
+ upscale = "-"
123
+ # upscale = st.sidebar.selectbox(
124
+ # "Upscale",
125
+ # ["-"] + UPSCALE,
126
+ # help="Upscale the uploaded image 2 or 4 times. Keep blank for no upscaling",
127
+ # )
128
+
129
+ # Upscale method selection
130
+ if upscale != "-":
131
+ upscale_method_or_model = st.sidebar.selectbox(
132
+ "Upscale Method / Model",
133
+ UPSCALE_METHODS + SR_METHODS,
134
+ help="Select a method or model to upscale the uploaded image",
135
+ )
136
+ else:
137
+ upscale_method_or_model = None
138
+
139
+ # Pupil selection
140
+ pupil_selection = st.sidebar.selectbox(
141
+ "Pupil Selection",
142
+ ["-"] + LABEL_MAP,
143
+ help="Select left or right pupil OR keep blank for both pupil diameter estimation",
144
+ )
145
 
146
  # Model selection
 
147
  tv_model = st.sidebar.selectbox(
148
  "Classification model",
149
  TV_MODELS,
150
+ help="Supported Models for Pupil Diameter Estimation",
151
  )
152
 
153
+ cam_method = "CAM"
154
+ # cam_method = st.sidebar.selectbox(
155
+ # "CAM method",
156
+ # CAM_METHODS,
157
+ # help="The way your class activation map will be computed",
158
+ # )
159
+ # target_layer = st.sidebar.text_input(
160
+ # "Target layer",
161
+ # default_layer,
162
+ # help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")',
163
  # )
164
 
 
 
165
  st.sidebar.write("\n")
166
 
167
+ if st.sidebar.button("Predict Diameter & Compute CAM"):
168
  if uploaded_file is None:
169
  st.sidebar.error("Please upload an image first")
170
 
171
  else:
172
  with st.spinner("Analyzing..."):
173
+ if upscale == "-":
174
+ sr_configs = None
175
+ else:
176
+ sr_configs = {
177
+ "method": upscale_method_or_model,
178
+ "params": {"upscale": upscale},
179
+ }
180
+ config_file = {
181
+ "sr_configs": sr_configs,
182
+ "feature_extraction_configs": {
183
+ "blink_detection": False,
184
+ "upscale": upscale,
185
+ "extraction_library": "mediapipe",
186
+ },
187
+ }
188
+
189
+ img = np.array(input_img)
190
+ # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
191
+ # if img.shape[0] > max_size or img.shape[1] > max_size:
192
+ # img = cv2.resize(img, (max_size, max_size))
193
+
194
+ ds_results = EyeDentityDatasetCreation(
195
+ feature_extraction_configs=config_file[
196
+ "feature_extraction_configs"
197
+ ],
198
+ sr_configs=config_file["sr_configs"],
199
+ )(img)
200
+ # if ds_results is not None:
201
+ # print("ds_results = ", ds_results.keys())
202
+
203
+ preprocess_steps = [
204
+ transforms.ToTensor(),
205
+ transforms.Resize(
206
+ [32, 64],
207
+ # interpolation=transforms.InterpolationMode.BILINEAR,
208
+ interpolation=transforms.InterpolationMode.BICUBIC,
209
+ antialias=True,
210
+ ),
211
+ ]
212
+ preprocess_function = transforms.Compose(preprocess_steps)
213
 
214
+ left_eye = None
215
+ right_eye = None
216
+
217
+ if ds_results is None:
218
+ # print("type of input_img = ", type(input_img))
219
+ input_img = preprocess_function(input_img)
220
+ input_img = input_img.unsqueeze(0)
221
+ if pupil_selection == "left_pupil":
222
+ left_eye = input_img
223
+ elif pupil_selection == "right_pupil":
224
+ right_eye = input_img
225
+ else:
226
+ left_eye = input_img
227
+ right_eye = input_img
228
+ # print("type of left_eye = ", type(left_eye))
229
+ # print("type of right_eye = ", type(right_eye))
230
+ elif "eyes" in ds_results.keys():
231
+ if (
232
+ "left_eye" in ds_results["eyes"].keys()
233
+ and ds_results["eyes"]["left_eye"] is not None
234
+ ):
235
+ left_eye = ds_results["eyes"]["left_eye"]
236
+ # print("type of left_eye = ", type(left_eye))
237
+ left_eye = to_pil_image(left_eye).convert("RGB")
238
+ # print("type of left_eye = ", type(left_eye))
239
+
240
+ left_eye = preprocess_function(left_eye)
241
+ # print("type of left_eye = ", type(left_eye))
242
+
243
+ left_eye = left_eye.unsqueeze(0)
244
+ if (
245
+ "right_eye" in ds_results["eyes"].keys()
246
+ and ds_results["eyes"]["right_eye"] is not None
247
+ ):
248
+ right_eye = ds_results["eyes"]["right_eye"]
249
+ # print("type of right_eye = ", type(right_eye))
250
+ right_eye = to_pil_image(right_eye).convert("RGB")
251
+ # print("type of right_eye = ", type(right_eye))
252
+
253
+ right_eye = preprocess_function(right_eye)
254
+ # print("type of right_eye = ", type(right_eye))
255
+
256
+ right_eye = right_eye.unsqueeze(0)
257
+ else:
258
+ # print("type of input_img = ", type(input_img))
259
+ input_img = preprocess_function(input_img)
260
+ input_img = input_img.unsqueeze(0)
261
+ if pupil_selection == "left_pupil":
262
+ left_eye = input_img
263
+ elif pupil_selection == "right_pupil":
264
+ right_eye = input_img
265
+ else:
266
+ left_eye = input_img
267
+ right_eye = input_img
268
+ # print("type of left_eye = ", type(left_eye))
269
+ # print("type of right_eye = ", type(right_eye))
270
+
271
+ # print("left_eye = ", left_eye.shape)
272
+ # print("right_eye = ", right_eye.shape)
273
+
274
+ if pupil_selection == "-":
275
+ selected_eyes = ["left_eye", "right_eye"]
276
+ elif pupil_selection == "left_pupil":
277
+ selected_eyes = ["left_eye"]
278
+ elif pupil_selection == "right_pupil":
279
+ selected_eyes = ["right_eye"]
280
+
281
+ for eye_type in selected_eyes:
282
+
283
+ model_configs = {
284
+ "model_path": root_path
285
+ + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
286
+ "registered_model_name": tv_model,
287
+ "num_classes": 1,
288
+ }
289
+ registered_model_name = model_configs["registered_model_name"]
290
+ model = _load_model(model_configs)
291
+
292
+ if registered_model_name == "ResNet18":
293
+ target_layer = model.resnet.layer4[-1].conv2
294
+ elif registered_model_name == "ResNet50":
295
+ target_layer = model.resnet.layer4[-1].conv3
296
+ else:
297
+ raise Exception(
298
+ f"No target layer available for selected model: {registered_model_name}"
299
+ )
300
 
301
+ if left_eye is not None and eye_type == "left_eye":
302
+ input_img = left_eye
303
+ elif right_eye is not None and eye_type == "right_eye":
304
+ input_img = right_eye
305
+ else:
306
+ raise Exception("Wrong Data")
307
+
308
+ if cam_method is not None:
309
+ cam_extractor = torchcam_methods.__dict__[cam_method](
310
+ model,
311
+ target_layer=target_layer,
312
+ fc_layer=model.resnet.fc,
313
+ input_shape=input_img.shape,
314
  )
315
+
316
+ # with torch.no_grad():
317
+ out = model(input_img)
318
+ cols[-1].markdown(
319
+ f"<h3>Predicted Pupil Diameter: {out[0].item():.2f} mm</h3>",
320
+ unsafe_allow_html=True,
321
  )
322
+ # cols[-1].text(f"Predicted Pupil Diameter: {out[0].item():.2f}")
323
 
324
+ # Retrieve the CAM
325
+ act_maps = cam_extractor(0, out)
 
 
326
 
327
+ # Fuse the CAMs if there are several
328
+ activation_map = (
329
+ act_maps[0]
330
+ if len(act_maps) == 1
331
+ else cam_extractor.fuse_cams(act_maps)
 
 
 
 
332
  )
333
 
334
+ # Convert input image and activation map to PIL images
335
+ input_image_pil = to_pil_image(input_img.squeeze(0))
336
+ activation_map_pil = to_pil_image(activation_map, mode="F")
337
+
338
+ # Create the overlayed CAM result
339
+ result = overlay_mask(
340
+ input_image_pil,
341
+ activation_map_pil,
342
+ alpha=0.5,
343
+ )
344
+
345
+ # Create a subplot with 1 row and 2 columns
346
+ fig, axs = plt.subplots(1, 2, figsize=(10, 5))
347
+
348
+ # Display the input image
349
+ axs[0].imshow(input_image_pil)
350
+ axs[0].axis("off")
351
+ axs[0].set_title("Input Image")
352
+
353
+ # Display the overlayed CAM result
354
+ axs[1].imshow(result)
355
+ axs[1].axis("off")
356
+ axs[1].set_title("Overlayed CAM")
357
+
358
+ # Display the plot
359
+ cols[-1].pyplot(fig)
360
+ cols[-1].text(
361
+ f"eye image size: {input_img.shape[-1]} x {input_img.shape[-2]}"
362
  )
 
 
 
 
 
363
 
364
 
365
  if __name__ == "__main__":
config.yml CHANGED
@@ -19,9 +19,9 @@ xai_configs:
19
  "InputXGradient",
20
  "GuidedBackprop",
21
  "Deconvolution",
22
- "GuidedGradCam",
23
- "LayerGradCam",
24
- "LayerGradientXActivation",
25
  ]
26
  cam_methods: [
27
  "CAM",
 
19
  "InputXGradient",
20
  "GuidedBackprop",
21
  "Deconvolution",
22
+ # "GuidedGradCam",
23
+ # "LayerGradCam",
24
+ # "LayerGradientXActivation",
25
  ]
26
  cam_methods: [
27
  "CAM",
feature_extraction/extractor_mediapipe.py ADDED
@@ -0,0 +1,365 @@
1
+ import cv2
2
+ import torch
3
+ import warnings
4
+ import numpy as np
5
+ from PIL import Image
6
+ from math import sqrt
7
+ import mediapipe as mp
8
+ from transformers import pipeline
9
+
10
+ warnings.filterwarnings("ignore")
11
+
12
+
13
+ class ExtractorMediaPipe:
14
+
15
+ def __init__(self, upscale=1):
16
+
17
+ self.upscale = int(upscale)
18
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+
20
+ # ========== Face Extraction ==========
21
+ self.face_detector = mp.solutions.face_detection.FaceDetection(
22
+ model_selection=0, min_detection_confidence=0.5
23
+ )
24
+ self.face_mesh = mp.solutions.face_mesh.FaceMesh(
25
+ max_num_faces=1,
26
+ static_image_mode=True,
27
+ refine_landmarks=True,
28
+ min_detection_confidence=0.5,
29
+ min_tracking_confidence=0.5,
30
+ )
31
+
32
+ # ========== Eyes Extraction ==========
33
+ self.RIGHT_EYE = [
34
+ 362,
35
+ 382,
36
+ 381,
37
+ 380,
38
+ 374,
39
+ 373,
40
+ 390,
41
+ 249,
42
+ 263,
43
+ 466,
44
+ 388,
45
+ 387,
46
+ 386,
47
+ 385,
48
+ 384,
49
+ 398,
50
+ ]
51
+ self.LEFT_EYE = [
52
+ 33,
53
+ 7,
54
+ 163,
55
+ 144,
56
+ 145,
57
+ 153,
58
+ 154,
59
+ 155,
60
+ 133,
61
+ 173,
62
+ 157,
63
+ 158,
64
+ 159,
65
+ 160,
66
+ 161,
67
+ 246,
68
+ ]
69
+ # https://huggingface.co/dima806/closed_eyes_image_detection
70
+ # https://www.kaggle.com/code/dima806/closed-eye-image-detection-vit
71
+ self.pipe = pipeline(
72
+ "image-classification",
73
+ model="dima806/closed_eyes_image_detection",
74
+ device=self.device,
75
+ )
76
+ self.blink_lower_thresh = 0.22
77
+ self.blink_upper_thresh = 0.25
78
+ self.blink_confidence = 0.50
79
+
80
+ # ========== Iris Extraction ==========
81
+ self.RIGHT_IRIS = [474, 475, 476, 477]
82
+ self.LEFT_IRIS = [469, 470, 471, 472]
83
+
84
+ def extract_face(self, image):
85
+
86
+ tmp_image = image.copy()
87
+ results = self.face_detector.process(tmp_image)
88
+
89
+ if not results.detections:
90
+ # print("No face detected")
91
+ return None
92
+ else:
93
+ bboxC = results.detections[0].location_data.relative_bounding_box
94
+ ih, iw, _ = image.shape
95
+
96
+ # Get bounding box coordinates
97
+ x, y, w, h = (
98
+ int(bboxC.xmin * iw),
99
+ int(bboxC.ymin * ih),
100
+ int(bboxC.width * iw),
101
+ int(bboxC.height * ih),
102
+ )
103
+
104
+ # Calculate the center of the bounding box
105
+ center_x = x + w // 2
106
+ center_y = y + h // 2
107
+
108
+ # Calculate new bounds ensuring they fit within the image dimensions
109
+ half_size = 128 * self.upscale
110
+ x1 = max(center_x - half_size, 0)
111
+ y1 = max(center_y - half_size, 0)
112
+ x2 = min(center_x + half_size, iw)
113
+ y2 = min(center_y + half_size, ih)
114
+
115
+ # Adjust x1, x2, y1, and y2 to ensure the cropped region is exactly (256 * self.upscale) x (256 * self.upscale)
116
+ if x2 - x1 < (256 * self.upscale):
117
+ if x1 == 0:
118
+ x2 = min((256 * self.upscale), iw)
119
+ elif x2 == iw:
120
+ x1 = max(iw - (256 * self.upscale), 0)
121
+
122
+ if y2 - y1 < (256 * self.upscale):
123
+ if y1 == 0:
124
+ y2 = min((256 * self.upscale), ih)
125
+ elif y2 == ih:
126
+ y1 = max(ih - (256 * self.upscale), 0)
127
+
128
+ cropped_face = image[y1:y2, x1:x2]
129
+
130
+ # bicubic upsampling
131
+ # if self.upscale != 1:
132
+ # cropped_face = cv2.resize(
133
+ # cropped_face,
134
+ # (256 * self.upscale, 256 * self.upscale),
135
+ # interpolation=cv2.INTER_CUBIC,
136
+ # )
137
+
138
+ return cropped_face
139
+
140
+ @staticmethod
141
+ def landmarksDetection(image, results, draw=False):
142
+ image_height, image_width = image.shape[:2]
143
+ mesh_coordinates = [
144
+ (int(point.x * image_width), int(point.y * image_height))
145
+ for point in results.multi_face_landmarks[0].landmark
146
+ ]
147
+ if draw:
148
+ [cv2.circle(image, i, 2, (0, 255, 0), -1) for i in mesh_coordinates]
149
+ return mesh_coordinates
150
+
151
+ @staticmethod
152
+ def euclideanDistance(point, point1):
153
+ x, y = point
154
+ x1, y1 = point1
155
+ distance = sqrt((x1 - x) ** 2 + (y1 - y) ** 2)
156
+ return distance
157
+
158
+ def blinkRatio(self, landmarks, right_indices, left_indices):
159
+
160
+ right_eye_landmark1 = landmarks[right_indices[0]]
161
+ right_eye_landmark2 = landmarks[right_indices[8]]
162
+
163
+ right_eye_landmark3 = landmarks[right_indices[12]]
164
+ right_eye_landmark4 = landmarks[right_indices[4]]
165
+
166
+ left_eye_landmark1 = landmarks[left_indices[0]]
167
+ left_eye_landmark2 = landmarks[left_indices[8]]
168
+
169
+ left_eye_landmark3 = landmarks[left_indices[12]]
170
+ left_eye_landmark4 = landmarks[left_indices[4]]
171
+
172
+ right_eye_horizontal_distance = self.euclideanDistance(
173
+ right_eye_landmark1, right_eye_landmark2
174
+ )
175
+ right_eye_vertical_distance = self.euclideanDistance(
176
+ right_eye_landmark3, right_eye_landmark4
177
+ )
178
+
179
+ left_eye_vertical_distance = self.euclideanDistance(
180
+ left_eye_landmark3, left_eye_landmark4
181
+ )
182
+ left_eye_horizontal_distance = self.euclideanDistance(
183
+ left_eye_landmark1, left_eye_landmark2
184
+ )
185
+
186
+ right_eye_ratio = right_eye_vertical_distance / right_eye_horizontal_distance
187
+ left_eye_ratio = left_eye_vertical_distance / left_eye_horizontal_distance
188
+
189
+ eyes_ratio = (right_eye_ratio + left_eye_ratio) / 2
190
+
191
+ return eyes_ratio
192
+
193
+ def extract_eyes_regions(self, image, landmarks, eye_indices):
194
+ h, w, _ = image.shape
195
+ points = [
196
+ (int(landmarks[idx].x * w), int(landmarks[idx].y * h))
197
+ for idx in eye_indices
198
+ ]
199
+
200
+ x_min = min([p[0] for p in points])
201
+ x_max = max([p[0] for p in points])
202
+ y_min = min([p[1] for p in points])
203
+ y_max = max([p[1] for p in points])
204
+
205
+ center_x = (x_min + x_max) // 2
206
+ center_y = (y_min + y_max) // 2
207
+
208
+ target_width = 32 * self.upscale
209
+ target_height = 16 * self.upscale
210
+
211
+ x1 = max(center_x - target_width // 2, 0)
212
+ y1 = max(center_y - target_height // 2, 0)
213
+ x2 = x1 + target_width
214
+ y2 = y1 + target_height
215
+
216
+ if x2 > w:
217
+ x1 = w - target_width
218
+ x2 = w
219
+ if y2 > h:
220
+ y1 = h - target_height
221
+ y2 = h
222
+
223
+ return image[y1:y2, x1:x2]
224
+
225
+ def blink_detection_model(self, left_eye, right_eye):
226
+
227
+ left_eye = cv2.cvtColor(left_eye, cv2.COLOR_RGB2GRAY)
228
+ left_eye = Image.fromarray(left_eye)
229
+ preds_left = self.pipe(left_eye)
230
+ if preds_left[0]["label"] == "closeEye":
231
+ closed_left = preds_left[0]["score"] >= self.blink_confidence
232
+ else:
233
+ closed_left = preds_left[1]["score"] >= self.blink_confidence
234
+
235
+ right_eye = cv2.cvtColor(right_eye, cv2.COLOR_RGB2GRAY)
236
+ right_eye = Image.fromarray(right_eye)
237
+ preds_right = self.pipe(right_eye)
238
+ if preds_right[0]["label"] == "closeEye":
239
+ closed_right = preds_right[0]["score"] >= self.blink_confidence
240
+ else:
241
+ closed_right = preds_right[1]["score"] >= self.blink_confidence
242
+
243
+ # print("preds_left = ", preds_left)
244
+ # print("preds_right = ", preds_right)
245
+
246
+ return closed_left or closed_right
247
+
248
+ def extract_eyes(self, image, blink_detection=False):
249
+
250
+ tmp_face = image.copy()
251
+         results = self.face_mesh.process(tmp_face)
+
+         if results.multi_face_landmarks is None:
+             return None
+
+         face_landmarks = results.multi_face_landmarks[0].landmark
+
+         left_eye = self.extract_eyes_regions(image, face_landmarks, self.LEFT_EYE)
+         right_eye = self.extract_eyes_regions(image, face_landmarks, self.RIGHT_EYE)
+         blinked = False
+
+         if blink_detection:
+             mesh_coordinates = self.landmarksDetection(image, results, False)
+             eyes_ratio = self.blinkRatio(
+                 mesh_coordinates, self.RIGHT_EYE, self.LEFT_EYE
+             )
+             if (
+                 eyes_ratio > self.blink_lower_thresh
+                 and eyes_ratio <= self.blink_upper_thresh
+             ):
+                 # print(
+                 #     "I think person blinked. eyes_ratio = ",
+                 #     eyes_ratio,
+                 #     "Confirming with ViT model...",
+                 # )
+                 blinked = self.blink_detection_model(
+                     left_eye=left_eye, right_eye=right_eye
+                 )
+                 # if blinked:
+                 #     print("Yes, person blinked. Confirmed by model")
+                 # else:
+                 #     print("No, person didn't blink. False alarm")
+             elif eyes_ratio <= self.blink_lower_thresh:
+                 blinked = True
+                 # print("Surely person blinked. eyes_ratio = ", eyes_ratio)
+             else:
+                 blinked = False
+
+         return {"left_eye": left_eye, "right_eye": right_eye, "blinked": blinked}
+
+     @staticmethod
+     def segment_iris(iris_img):
+         # Convert the RGB crop to grayscale
+         iris_img_gray = cv2.cvtColor(iris_img, cv2.COLOR_RGB2GRAY)
+
+         # Apply Gaussian blur for denoising
+         iris_img_blur = cv2.GaussianBlur(iris_img_gray, (5, 5), 0)
+
+         # Threshold with Otsu's method
+         _, iris_img_mask = cv2.threshold(
+             iris_img_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
+         )
+
+         # Invert the mask and keep only the iris pixels
+         segmented_mask = cv2.bitwise_not(iris_img_mask)
+         segmented_mask = cv2.cvtColor(segmented_mask, cv2.COLOR_GRAY2RGB)
+         segmented_iris = cv2.bitwise_and(iris_img, segmented_mask)
+
+         return {
+             "segmented_iris": segmented_iris,
+             "segmented_mask": segmented_mask,
+         }
+
+     def extract_iris(self, image):
+         ih, iw, _ = image.shape
+         tmp_face = image.copy()
+         results = self.face_mesh.process(tmp_face)
+
+         if results.multi_face_landmarks is None:
+             return None
+
+         mesh_coordinates = self.landmarksDetection(image, results, False)
+         mesh_points = np.array(mesh_coordinates)
+
+         (l_cx, l_cy), l_radius = cv2.minEnclosingCircle(mesh_points[self.LEFT_IRIS])
+         (r_cx, r_cy), r_radius = cv2.minEnclosingCircle(mesh_points[self.RIGHT_IRIS])
+
+         # Crop the left iris to a (16 * upscale) x (16 * upscale) patch
+         l_x1 = max(int(l_cx) - (8 * self.upscale), 0)
+         l_y1 = max(int(l_cy) - (8 * self.upscale), 0)
+         l_x2 = min(int(l_cx) + (8 * self.upscale), iw)
+         l_y2 = min(int(l_cy) + (8 * self.upscale), ih)
+
+         cropped_left_iris = image[l_y1:l_y2, l_x1:l_x2]
+
+         left_iris_segmented_data = self.segment_iris(
+             cv2.cvtColor(cropped_left_iris, cv2.COLOR_BGR2RGB)
+         )
+
+         # Crop the right iris to a (16 * upscale) x (16 * upscale) patch
+         r_x1 = max(int(r_cx) - (8 * self.upscale), 0)
+         r_y1 = max(int(r_cy) - (8 * self.upscale), 0)
+         r_x2 = min(int(r_cx) + (8 * self.upscale), iw)
+         r_y2 = min(int(r_cy) + (8 * self.upscale), ih)
+
+         cropped_right_iris = image[r_y1:r_y2, r_x1:r_x2]
+
+         right_iris_segmented_data = self.segment_iris(
+             cv2.cvtColor(cropped_right_iris, cv2.COLOR_BGR2RGB)
+         )
+
+         return {
+             "left_iris": {
+                 "img": cropped_left_iris,
+                 "segmented_iris": left_iris_segmented_data["segmented_iris"],
+                 "segmented_mask": left_iris_segmented_data["segmented_mask"],
+             },
+             "right_iris": {
+                 "img": cropped_right_iris,
+                 "segmented_iris": right_iris_segmented_data["segmented_iris"],
+                 "segmented_mask": right_iris_segmented_data["segmented_mask"],
+             },
+         }
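For reference, a minimal usage sketch of the extractor methods added above. It assumes the ExtractorMediaPipe constructor takes the upscale factor positionally (as it is called from FeaturesExtractor later in this commit) and that the input is a BGR frame read with OpenCV; the file names are illustrative.

import cv2
from feature_extraction.extractor_mediapipe import ExtractorMediaPipe

extractor = ExtractorMediaPipe(1)            # upscale factor, assumed positional
frame = cv2.imread("sample_face.png")        # illustrative input image (BGR)

eyes = extractor.extract_eyes(frame, False)  # second argument toggles blink detection
if eyes is not None and not eyes["blinked"]:
    iris = extractor.extract_iris(frame)     # dict with left/right iris crops and masks
    if iris is not None:
        cv2.imwrite("left_iris.png", iris["left_iris"]["img"])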
feature_extraction/features_extractor.py ADDED
@@ -0,0 +1,50 @@
+ import os
+ import sys
+ import warnings
+ import os.path as osp
+
+ ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ sys.path.append(root_path)
+
+ from feature_extraction.extractor_mediapipe import ExtractorMediaPipe
+
+ warnings.filterwarnings("ignore")
+
+
+ class FeaturesExtractor:
+
+     def __init__(
+         self, extraction_library="mediapipe", blink_detection=False, upscale=1
+     ):
+         self.upscale = upscale
+         self.blink_detection = blink_detection
+         self.extraction_library = extraction_library
+         self.feature_extractor = ExtractorMediaPipe(self.upscale)
+
+     def __call__(self, image):
+         results = {}
+         face = self.feature_extractor.extract_face(image)
+         if face is None:
+             # print("No face found. Skipped feature extraction!")
+             return None
+         else:
+             results["img"] = image
+             results["face"] = face
+             eyes_data = self.feature_extractor.extract_eyes(image, self.blink_detection)
+             if eyes_data is None:
+                 # print("No eyes found. Skipped feature extraction!")
+                 return results
+             else:
+                 results["eyes"] = eyes_data
+                 if eyes_data["blinked"]:
+                     # print("Found blinked eyes!")
+                     return results
+                 else:
+                     iris_data = self.feature_extractor.extract_iris(image)
+                     if iris_data is None:
+                         # print("No iris found. Skipped feature extraction!")
+                         return results
+                     else:
+                         results["iris"] = iris_data
+                         return results
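A short usage sketch of the FeaturesExtractor added above; the image path is illustrative and the returned dictionary keys follow the code in this file.

import cv2
from feature_extraction.features_extractor import FeaturesExtractor

features_extractor = FeaturesExtractor(
    extraction_library="mediapipe", blink_detection=True, upscale=1
)
image = cv2.imread("subject_frame.png")      # illustrative input image (BGR)
features = features_extractor(image)

if features is None:
    print("no face detected")                # extractor returned None
elif "iris" not in features:
    print("face found, but eyes were missing or a blink was detected")
else:
    left_iris_crop = features["iris"]["left_iris"]["img"]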
packages.txt DELETED
File without changes
preprocessing/dataset_creation.py ADDED
@@ -0,0 +1,40 @@
+ import sys
+ import cv2
+ import os.path as osp
+
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ sys.path.append(root_path)
+
+ from preprocessing.dataset_creation_utils import get_sr_method
+ from feature_extraction.features_extractor import FeaturesExtractor
+
+
+ class EyeDentityDatasetCreation:
+
+     def __init__(self, feature_extraction_configs, sr_configs=None):
+         self.extraction_library = feature_extraction_configs["extraction_library"]
+         self.sr_configs = sr_configs
+         if self.sr_configs:
+             self.sr_method_name = sr_configs["method"]
+             self.upscale = sr_configs["params"]["upscale"]
+             if self.sr_method_name != "-":
+                 self.sr_method = get_sr_method(self, sr_configs)
+         else:
+             self.upscale = 1
+
+         self.blink_detection = feature_extraction_configs["blink_detection"]
+         self.features_extractor = FeaturesExtractor(
+             extraction_library=self.extraction_library,
+             blink_detection=self.blink_detection,
+             upscale=self.upscale,
+         )
+
+     def __call__(self, img):
+         # img = cv2.imread(img)
+         if self.sr_configs is None or self.sr_method_name == "-":
+             img = img
+         else:
+             img = self.sr_method(img)
+
+         result_dict = self.features_extractor(img)
+         return result_dict
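The class above is driven by two plain config dictionaries. A minimal sketch of how they might look, assuming only the keys read in __init__ ("extraction_library", "blink_detection", "method", "params", "upscale"); any further SR parameters are defined by the individual wrappers in SR_Inference and are not shown here.

import cv2
from preprocessing.dataset_creation import EyeDentityDatasetCreation

feature_extraction_configs = {
    "extraction_library": "mediapipe",
    "blink_detection": False,
}
sr_configs = {
    "method": "-",              # "-" skips super-resolution; otherwise an SR class name
    "params": {"upscale": 1},
}

pipeline = EyeDentityDatasetCreation(feature_extraction_configs, sr_configs)
result = pipeline(cv2.imread("frame_0001.png"))   # illustrative file name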
preprocessing/dataset_creation_utils.py ADDED
@@ -0,0 +1,29 @@
+ import os
+ import torch
+ import random
+ import numpy as np
+ from SR_Inference.inference_hat import HAT
+ from SR_Inference.inference_gfpgan import GFPGAN
+ from SR_Inference.inference_realesr import RealEsr
+ from SR_Inference.inference_srresnet import SRResNet
+ from SR_Inference.inference_codeformer import CodeFormer
+
+
+ def seed_everything(seed=42):
+     random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.backends.cudnn.benchmark = True
+     torch.backends.cudnn.deterministic = True
+
+
+ def get_sr_method(self, sr_configs):
+     sr_method_class = globals().get(self.sr_method_name)
+     if sr_method_class is not None:
+         return sr_method_class(**sr_configs["params"])
+     else:
+         raise Exception(
+             f"No such SR method called '{self.sr_method_name}' implemented!"
+         )
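get_sr_method resolves the configured method name against the SR classes imported at the top of this file (HAT, GFPGAN, RealEsr, SRResNet, CodeFormer), so "method" in the config must match one of those class names exactly. A small sketch, assuming an "upscale" parameter is accepted by the chosen wrapper (as the pipeline config suggests); the stand-in namespace only mimics the attribute the function reads.

from types import SimpleNamespace
from preprocessing.dataset_creation_utils import seed_everything, get_sr_method

seed_everything(42)  # fix RNG seeds before any model is instantiated

owner = SimpleNamespace(sr_method_name="RealEsr")             # stand-in for the dataset-creation object
sr_configs = {"method": "RealEsr", "params": {"upscale": 2}}  # params here are assumptions
sr_method = get_sr_method(owner, sr_configs)                  # instantiates RealEsr(**params)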
registrations/models.py CHANGED
@@ -121,4 +121,4 @@ class ResNet50(nn.Module):
          return x
  
  
- print("Registered models in MODEL_REGISTRY:", MODEL_REGISTRY.keys())
+ # print("Registered models in MODEL_REGISTRY:", MODEL_REGISTRY.keys())
registry_utils.py CHANGED
@@ -58,22 +58,22 @@ def import_registered_modules(registration_folder="registrations"):
          list: List of imported modules.
      """
  
-     print("\n")
+     # print("\n")
  
      registration_modules_folder = (
          osp.dirname(osp.abspath(__file__)) + f"/{registration_folder}"
      )
-     print("registration_modules_folder = ", registration_modules_folder)
+     # print("registration_modules_folder = ", registration_modules_folder)
  
      registration_modules_file_names = [
          osp.splitext(osp.basename(v))[0]
          for v in scandir(dir_path=registration_modules_folder)
      ]
-     print("registration_modules_file_names = ", registration_modules_file_names)
+     # print("registration_modules_file_names = ", registration_modules_file_names)
  
      imported_modules = [
          importlib.import_module(f"{registration_folder}.{file_name}")
          for file_name in registration_modules_file_names
      ]
-     print("imported_modules = ", imported_modules)
-     print("\n")
+     # print("imported_modules = ", imported_modules)
+     # print("\n")
requirements.txt CHANGED
@@ -1,3 +1,4 @@
+ # https://huggingface.co/docs/hub/en/spaces-dependencies
  tqdm
  PyYAML
  numpy
@@ -10,7 +11,7 @@ scikit_learn
  torch
  captum
  evaluate
- # basicsr
+ basicsr
  facexlib
  realesrgan
  opencv_python
@@ -18,7 +19,7 @@ cmake
  dlib
  einops
  transformers
- # gfpgan
+ gfpgan
  # streamlit
  mediapipe
  imutils
xx-packages.txt ADDED
@@ -0,0 +1,28 @@
+ # https://huggingface.co/docs/hub/en/spaces-dependencies
+ # tqdm
+ # PyYAML
+ # numpy
+ # pandas
+ # matplotlib
+ # seaborn
+ # mlflow
+ # pillow
+ # scikit_learn
+ # torch
+ # captum
+ # evaluate
+ # basicsr
+ # facexlib
+ # realesrgan
+ # opencv_python
+ # cmake
+ # dlib
+ # einops
+ # transformers
+ # gfpgan
+ # streamlit
+ # mediapipe
+ # imutils
+ # scipy
+ # torchvision==0.16.0
+ # torchcam