Spaces: Running
vijul.shah committed · Commit 0f2d9f6 · 1 Parent(s): 51ba5d6
End-to-End Pipeline Configured
Browse files
- .gitignore +1 -1
- SR_Inference/codeformer/codeformer_arch.py +271 -0
- SR_Inference/codeformer/vqgan_arch.py +418 -0
- SR_Inference/codeformer/weights/codeformer_v0.1.0.pth +3 -0
- SR_Inference/gfpgan/weights/GFPGANv1.3.pth +3 -0
- SR_Inference/gfpgan/weights/detection_Resnet50_Final.pth +3 -0
- SR_Inference/gfpgan/weights/parsing_parsenet.pth +3 -0
- SR_Inference/hat/hat_arch.py +979 -0
- SR_Inference/hat/weights/HAT-L_SRx2_ImageNet-pretrain.pth +3 -0
- SR_Inference/hat/weights/HAT_SRx2_ImageNet-pretrain.pth +3 -0
- SR_Inference/hat/weights/HAT_SRx4_ImageNet-pretrain.pth +3 -0
- SR_Inference/inference_codeformer.py +126 -0
- SR_Inference/inference_gfpgan.py +76 -0
- SR_Inference/inference_hat.py +104 -0
- SR_Inference/inference_realesr.py +52 -0
- SR_Inference/inference_sr_utils.py +101 -0
- SR_Inference/inference_srresnet.py +78 -0
- SR_Inference/realesrgan/weights/RealESRGAN_x2plus.pth +3 -0
- SR_Inference/realesrgan/weights/RealESRGAN_x4plus.pth +3 -0
- SR_Inference/srresnet/weights/SRResNet_2x.pth +3 -0
- SR_Inference/srresnet/weights/SRResNet_4x.pth +3 -0
- app.py +258 -121
- config.yml +3 -3
- feature_extraction/extractor_mediapipe.py +365 -0
- feature_extraction/features_extractor.py +50 -0
- packages.txt +0 -0
- preprocessing/dataset_creation.py +40 -0
- preprocessing/dataset_creation_utils.py +29 -0
- registrations/models.py +1 -1
- registry_utils.py +5 -5
- requirements.txt +3 -2
- xx-packages.txt +28 -0
.gitignore
CHANGED
@@ -1 +1 @@
-__pycache__
+__pycache__
SR_Inference/codeformer/codeformer_arch.py
ADDED
@@ -0,0 +1,271 @@
import math
import torch
from .vqgan_arch import *
from typing import Optional
from torch import nn, Tensor
import torch.nn.functional as F

def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have the similar color and illuminations
    as those in the degradate features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degradate features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

        # self attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt

class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        self.scale = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out

class CodeFormerArch(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize','generator'], vqgan_path=None):
        super(CodeFormerArch, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        if vqgan_path is not None:
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
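For reference, a minimal usage sketch of the class added above (not part of the commit). It assumes SR_Inference/codeformer is importable as a package and that the bundled codeformer_v0.1.0.pth checkpoint stores its weights under 'params_ema' (the same key the class itself expects for vqgan_path); the input tensor is a stand-in for a normalized 512x512 face crop.

import torch
from SR_Inference.codeformer.codeformer_arch import CodeFormerArch

net = CodeFormerArch(dim_embd=512, n_head=8, n_layers=9,
                     codebook_size=1024, connect_list=['32', '64', '128', '256'])
ckpt = torch.load('SR_Inference/codeformer/weights/codeformer_v0.1.0.pth', map_location='cpu')
net.load_state_dict(ckpt['params_ema'])  # assumed checkpoint layout
net.eval()

with torch.no_grad():
    face = torch.randn(1, 3, 512, 512)                        # dummy normalized face crop
    restored, logits, lq_feat = net(face, w=0.5, adain=True)  # w trades fidelity vs. quality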
SR_Inference/codeformer/vqgan_arch.py
ADDED
@@ -0,0 +1,418 @@
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from basicsr.utils import get_root_logger

def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

@torch.jit.script
def swish(x):
    return x*torch.sigmoid(x)

# Define VQVAE classes
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        # [0-1], higher score, higher confidence
        # min_encoding_scores = torch.exp(-min_encoding_scores/10)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # compute loss for embedding
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "mean_distance": mean_distance
            }

    def get_codebook_feat(self, indices, shape):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1,1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q

class GumbelQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        hard = self.straight_through if self.training else True

        logits = self.proj(z)

        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)

        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {
            "min_encoding_indices": min_encoding_indices
            }

class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x

class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h*w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x+h_

class Encoder(nn.Module):
    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,)+tuple(ch_mult)

        blocks = []
        # initial convultion
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention on smaller res (16x16)
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x

class Generator(nn.Module):
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions-1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)


    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x

class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta #0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError(f'Wrong params!')


    def forward(self, x):
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats

# patch based discriminator
class VQGANDiscriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
            else:
                raise ValueError(f'Wrong params!')

    def forward(self, x):
        return self.main(x)
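A minimal sketch of how the VQ autoencoder defined above fits together (illustrative, not part of the commit). It uses the same hyperparameters CodeFormerArch passes to its VQAutoEncoder base and runs a random image through the encode-quantize-decode path; the input is a dummy tensor and requires basicsr to be installed for get_root_logger.

import torch
from SR_Inference.codeformer.vqgan_arch import VQAutoEncoder

# same settings CodeFormerArch passes to its VQAutoEncoder base class
vqgan = VQAutoEncoder(img_size=512, nf=64, ch_mult=[1, 2, 2, 4, 4, 8],
                      quantizer='nearest', res_blocks=2, attn_resolutions=[16],
                      codebook_size=1024)
vqgan.eval()

with torch.no_grad():
    img = torch.randn(1, 3, 512, 512)            # dummy RGB input
    recon, codebook_loss, stats = vqgan(img)     # encode -> quantize -> decode
    print(recon.shape, stats['min_encoding_indices'].shape)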
SR_Inference/codeformer/weights/codeformer_v0.1.0.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1009e537e0c2a07d4cabce6355f53cb66767cd4b4297ec7a4a64ca4b8a5684b7
size 376637898
SR_Inference/gfpgan/weights/GFPGANv1.3.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70
size 348632874
SR_Inference/gfpgan/weights/detection_Resnet50_Final.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d
size 109497761
SR_Inference/gfpgan/weights/parsing_parsenet.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2
size 85331193
SR_Inference/hat/hat_arch.py
ADDED
@@ -0,0 +1,979 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops import rearrange
|
5 |
+
from basicsr.archs.arch_util import to_2tuple, trunc_normal_
|
6 |
+
|
7 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
8 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
9 |
+
|
10 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
11 |
+
"""
|
12 |
+
if drop_prob == 0. or not training:
|
13 |
+
return x
|
14 |
+
keep_prob = 1 - drop_prob
|
15 |
+
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
16 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
17 |
+
random_tensor.floor_() # binarize
|
18 |
+
output = x.div(keep_prob) * random_tensor
|
19 |
+
return output
|
20 |
+
|
21 |
+
|
22 |
+
class DropPath(nn.Module):
|
23 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
24 |
+
|
25 |
+
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, drop_prob=None):
|
29 |
+
super(DropPath, self).__init__()
|
30 |
+
self.drop_prob = drop_prob
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
return drop_path(x, self.drop_prob, self.training)
|
34 |
+
|
35 |
+
|
36 |
+
class ChannelAttention(nn.Module):
|
37 |
+
"""Channel attention used in RCAN.
|
38 |
+
Args:
|
39 |
+
num_feat (int): Channel number of intermediate features.
|
40 |
+
squeeze_factor (int): Channel squeeze factor. Default: 16.
|
41 |
+
"""
|
42 |
+
|
43 |
+
def __init__(self, num_feat, squeeze_factor=16):
|
44 |
+
super(ChannelAttention, self).__init__()
|
45 |
+
self.attention = nn.Sequential(
|
46 |
+
nn.AdaptiveAvgPool2d(1),
|
47 |
+
nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
|
48 |
+
nn.ReLU(inplace=True),
|
49 |
+
nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
|
50 |
+
nn.Sigmoid())
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
y = self.attention(x)
|
54 |
+
return x * y
|
55 |
+
|
56 |
+
|
57 |
+
class CAB(nn.Module):
|
58 |
+
|
59 |
+
def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
|
60 |
+
super(CAB, self).__init__()
|
61 |
+
|
62 |
+
self.cab = nn.Sequential(
|
63 |
+
nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
|
64 |
+
nn.GELU(),
|
65 |
+
nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
|
66 |
+
ChannelAttention(num_feat, squeeze_factor)
|
67 |
+
)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
return self.cab(x)
|
71 |
+
|
72 |
+
|
73 |
+
class Mlp(nn.Module):
|
74 |
+
|
75 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
76 |
+
super().__init__()
|
77 |
+
out_features = out_features or in_features
|
78 |
+
hidden_features = hidden_features or in_features
|
79 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
80 |
+
self.act = act_layer()
|
81 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
82 |
+
self.drop = nn.Dropout(drop)
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
x = self.fc1(x)
|
86 |
+
x = self.act(x)
|
87 |
+
x = self.drop(x)
|
88 |
+
x = self.fc2(x)
|
89 |
+
x = self.drop(x)
|
90 |
+
return x
|
91 |
+
|
92 |
+
|
93 |
+
def window_partition(x, window_size):
|
94 |
+
"""
|
95 |
+
Args:
|
96 |
+
x: (b, h, w, c)
|
97 |
+
window_size (int): window size
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
windows: (num_windows*b, window_size, window_size, c)
|
101 |
+
"""
|
102 |
+
b, h, w, c = x.shape
|
103 |
+
x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
|
104 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
|
105 |
+
return windows
|
106 |
+
|
107 |
+
|
108 |
+
def window_reverse(windows, window_size, h, w):
|
109 |
+
"""
|
110 |
+
Args:
|
111 |
+
windows: (num_windows*b, window_size, window_size, c)
|
112 |
+
window_size (int): Window size
|
113 |
+
h (int): Height of image
|
114 |
+
w (int): Width of image
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
x: (b, h, w, c)
|
118 |
+
"""
|
119 |
+
b = int(windows.shape[0] / (h * w / window_size / window_size))
|
120 |
+
x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
|
121 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
|
122 |
+
return x
|
123 |
+
|
124 |
+
|
125 |
+
class WindowAttention(nn.Module):
|
126 |
+
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
127 |
+
It supports both of shifted and non-shifted window.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
dim (int): Number of input channels.
|
131 |
+
window_size (tuple[int]): The height and width of the window.
|
132 |
+
num_heads (int): Number of attention heads.
|
133 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
134 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
|
135 |
+
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
136 |
+
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
137 |
+
"""
|
138 |
+
|
139 |
+
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
|
140 |
+
|
141 |
+
super().__init__()
|
142 |
+
self.dim = dim
|
143 |
+
self.window_size = window_size # Wh, Ww
|
144 |
+
self.num_heads = num_heads
|
145 |
+
head_dim = dim // num_heads
|
146 |
+
self.scale = qk_scale or head_dim**-0.5
|
147 |
+
|
148 |
+
# define a parameter table of relative position bias
|
149 |
+
self.relative_position_bias_table = nn.Parameter(
|
150 |
+
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
151 |
+
|
152 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
153 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
154 |
+
self.proj = nn.Linear(dim, dim)
|
155 |
+
|
156 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
157 |
+
|
158 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
159 |
+
self.softmax = nn.Softmax(dim=-1)
|
160 |
+
|
161 |
+
def forward(self, x, rpi, mask=None):
|
162 |
+
"""
|
163 |
+
Args:
|
164 |
+
x: input features with shape of (num_windows*b, n, c)
|
165 |
+
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
166 |
+
"""
|
167 |
+
b_, n, c = x.shape
|
168 |
+
qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
|
169 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
170 |
+
|
171 |
+
q = q * self.scale
|
172 |
+
attn = (q @ k.transpose(-2, -1))
|
173 |
+
|
174 |
+
relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
|
175 |
+
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
176 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
177 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
178 |
+
|
179 |
+
if mask is not None:
|
180 |
+
nw = mask.shape[0]
|
181 |
+
attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
|
182 |
+
attn = attn.view(-1, self.num_heads, n, n)
|
183 |
+
attn = self.softmax(attn)
|
184 |
+
else:
|
185 |
+
attn = self.softmax(attn)
|
186 |
+
|
187 |
+
attn = self.attn_drop(attn)
|
188 |
+
|
189 |
+
x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
|
190 |
+
x = self.proj(x)
|
191 |
+
x = self.proj_drop(x)
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class HAB(nn.Module):
|
196 |
+
r""" Hybrid Attention Block.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
dim (int): Number of input channels.
|
200 |
+
input_resolution (tuple[int]): Input resolution.
|
201 |
+
num_heads (int): Number of attention heads.
|
202 |
+
window_size (int): Window size.
|
203 |
+
shift_size (int): Shift size for SW-MSA.
|
204 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
205 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
206 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
207 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
208 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
209 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
210 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
211 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
212 |
+
"""
|
213 |
+
|
214 |
+
def __init__(self,
|
215 |
+
dim,
|
216 |
+
input_resolution,
|
217 |
+
num_heads,
|
218 |
+
window_size=7,
|
219 |
+
shift_size=0,
|
220 |
+
compress_ratio=3,
|
221 |
+
squeeze_factor=30,
|
222 |
+
conv_scale=0.01,
|
223 |
+
mlp_ratio=4.,
|
224 |
+
qkv_bias=True,
|
225 |
+
qk_scale=None,
|
226 |
+
drop=0.,
|
227 |
+
attn_drop=0.,
|
228 |
+
drop_path=0.,
|
229 |
+
act_layer=nn.GELU,
|
230 |
+
norm_layer=nn.LayerNorm):
|
231 |
+
super().__init__()
|
232 |
+
self.dim = dim
|
233 |
+
self.input_resolution = input_resolution
|
234 |
+
self.num_heads = num_heads
|
235 |
+
self.window_size = window_size
|
236 |
+
self.shift_size = shift_size
|
237 |
+
self.mlp_ratio = mlp_ratio
|
238 |
+
if min(self.input_resolution) <= self.window_size:
|
239 |
+
# if window size is larger than input resolution, we don't partition windows
|
240 |
+
self.shift_size = 0
|
241 |
+
self.window_size = min(self.input_resolution)
|
242 |
+
assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
|
243 |
+
|
244 |
+
self.norm1 = norm_layer(dim)
|
245 |
+
self.attn = WindowAttention(
|
246 |
+
dim,
|
247 |
+
window_size=to_2tuple(self.window_size),
|
248 |
+
num_heads=num_heads,
|
249 |
+
qkv_bias=qkv_bias,
|
250 |
+
qk_scale=qk_scale,
|
251 |
+
attn_drop=attn_drop,
|
252 |
+
proj_drop=drop)
|
253 |
+
|
254 |
+
self.conv_scale = conv_scale
|
255 |
+
self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
|
256 |
+
|
257 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
258 |
+
self.norm2 = norm_layer(dim)
|
259 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
260 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
261 |
+
|
262 |
+
def forward(self, x, x_size, rpi_sa, attn_mask):
|
263 |
+
h, w = x_size
|
264 |
+
b, _, c = x.shape
|
265 |
+
# assert seq_len == h * w, "input feature has wrong size"
|
266 |
+
|
267 |
+
shortcut = x
|
268 |
+
x = self.norm1(x)
|
269 |
+
x = x.view(b, h, w, c)
|
270 |
+
|
271 |
+
# Conv_X
|
272 |
+
conv_x = self.conv_block(x.permute(0, 3, 1, 2))
|
273 |
+
conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
|
274 |
+
|
275 |
+
# cyclic shift
|
276 |
+
if self.shift_size > 0:
|
277 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
278 |
+
attn_mask = attn_mask
|
279 |
+
else:
|
280 |
+
shifted_x = x
|
281 |
+
attn_mask = None
|
282 |
+
|
283 |
+
# partition windows
|
284 |
+
x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
|
285 |
+
x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
|
286 |
+
|
287 |
+
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
|
288 |
+
attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
|
289 |
+
|
290 |
+
# merge windows
|
291 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
|
292 |
+
shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
|
293 |
+
|
294 |
+
# reverse cyclic shift
|
295 |
+
if self.shift_size > 0:
|
296 |
+
attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
297 |
+
else:
|
298 |
+
attn_x = shifted_x
|
299 |
+
attn_x = attn_x.view(b, h * w, c)
|
300 |
+
|
301 |
+
# FFN
|
302 |
+
x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
|
303 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
304 |
+
|
305 |
+
return x
|
306 |
+
|
307 |
+
|
308 |
+
class PatchMerging(nn.Module):
|
309 |
+
r""" Patch Merging Layer.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
input_resolution (tuple[int]): Resolution of input feature.
|
313 |
+
dim (int): Number of input channels.
|
314 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
315 |
+
"""
|
316 |
+
|
317 |
+
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
|
318 |
+
super().__init__()
|
319 |
+
self.input_resolution = input_resolution
|
320 |
+
self.dim = dim
|
321 |
+
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
|
322 |
+
self.norm = norm_layer(4 * dim)
|
323 |
+
|
324 |
+
def forward(self, x):
|
325 |
+
"""
|
326 |
+
x: b, h*w, c
|
327 |
+
"""
|
328 |
+
h, w = self.input_resolution
|
329 |
+
b, seq_len, c = x.shape
|
330 |
+
assert seq_len == h * w, 'input feature has wrong size'
|
331 |
+
assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'
|
332 |
+
|
333 |
+
x = x.view(b, h, w, c)
|
334 |
+
|
335 |
+
x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
|
336 |
+
x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
|
337 |
+
x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
|
338 |
+
x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
|
339 |
+
x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
|
340 |
+
x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
|
341 |
+
|
342 |
+
x = self.norm(x)
|
343 |
+
x = self.reduction(x)
|
344 |
+
|
345 |
+
return x
|
346 |
+
|
347 |
+
|
348 |
+
class OCAB(nn.Module):
|
349 |
+
# overlapping cross-attention block
|
350 |
+
|
351 |
+
def __init__(self, dim,
|
352 |
+
input_resolution,
|
353 |
+
window_size,
|
354 |
+
overlap_ratio,
|
355 |
+
num_heads,
|
356 |
+
qkv_bias=True,
|
357 |
+
qk_scale=None,
|
358 |
+
mlp_ratio=2,
|
359 |
+
norm_layer=nn.LayerNorm
|
360 |
+
):
|
361 |
+
|
362 |
+
super().__init__()
|
363 |
+
self.dim = dim
|
364 |
+
self.input_resolution = input_resolution
|
365 |
+
self.window_size = window_size
|
366 |
+
self.num_heads = num_heads
|
367 |
+
head_dim = dim // num_heads
|
368 |
+
self.scale = qk_scale or head_dim**-0.5
|
369 |
+
self.overlap_win_size = int(window_size * overlap_ratio) + window_size
|
370 |
+
|
371 |
+
self.norm1 = norm_layer(dim)
|
372 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
373 |
+
self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, padding=(self.overlap_win_size-window_size)//2)
|
374 |
+
|
375 |
+
# define a parameter table of relative position bias
|
376 |
+
self.relative_position_bias_table = nn.Parameter(
|
377 |
+
torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
378 |
+
|
379 |
+
trunc_normal_(self.relative_position_bias_table, std=.02)
|
380 |
+
self.softmax = nn.Softmax(dim=-1)
|
381 |
+
|
382 |
+
self.proj = nn.Linear(dim,dim)
|
383 |
+
|
384 |
+
self.norm2 = norm_layer(dim)
|
385 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
386 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)
|
387 |
+
|
388 |
+
def forward(self, x, x_size, rpi):
|
389 |
+
h, w = x_size
|
390 |
+
b, _, c = x.shape
|
391 |
+
|
392 |
+
shortcut = x
|
393 |
+
x = self.norm1(x)
|
394 |
+
x = x.view(b, h, w, c)
|
395 |
+
|
396 |
+
qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2) # 3, b, c, h, w
|
397 |
+
q = qkv[0].permute(0, 2, 3, 1) # b, h, w, c
|
398 |
+
kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w
|
399 |
+
|
400 |
+
# partition windows
|
401 |
+
q_windows = window_partition(q, self.window_size) # nw*b, window_size, window_size, c
|
402 |
+
q_windows = q_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
|
403 |
+
|
404 |
+
kv_windows = self.unfold(kv) # b, c*w*w, nw
|
405 |
+
kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch', nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous() # 2, nw*b, ow*ow, c
|
406 |
+
k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c
|
407 |
+
|
408 |
+
b_, nq, _ = q_windows.shape
|
409 |
+
_, n, _ = k_windows.shape
|
410 |
+
d = self.dim // self.num_heads
|
411 |
+
q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, nq, d
|
412 |
+
k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d
|
413 |
+
v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d
|
414 |
+
|
415 |
+
q = q * self.scale
|
416 |
+
attn = (q @ k.transpose(-2, -1))
|
417 |
+
|
418 |
+
relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
|
419 |
+
self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, -1) # ws*ws, wse*wse, nH
|
420 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, ws*ws, wse*wse
|
421 |
+
attn = attn + relative_position_bias.unsqueeze(0)
|
422 |
+
|
423 |
+
attn = self.softmax(attn)
|
424 |
+
attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
|
425 |
+
|
426 |
+
# merge windows
|
427 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
|
428 |
+
x = window_reverse(attn_windows, self.window_size, h, w) # b h w c
|
429 |
+
x = x.view(b, h * w, self.dim)
|
430 |
+
|
431 |
+
x = self.proj(x) + shortcut
|
432 |
+
|
433 |
+
x = x + self.mlp(self.norm2(x))
|
434 |
+
return x
|
435 |
+
|
436 |
+
|
437 |
+
class AttenBlocks(nn.Module):
|
438 |
+
""" A series of attention blocks for one RHAG.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
dim (int): Number of input channels.
|
442 |
+
input_resolution (tuple[int]): Input resolution.
|
443 |
+
depth (int): Number of blocks.
|
444 |
+
num_heads (int): Number of attention heads.
|
445 |
+
window_size (int): Local window size.
|
446 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
447 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
448 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
449 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
450 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
451 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
452 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
453 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
454 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
455 |
+
"""
|
456 |
+
|
457 |
+
def __init__(self,
|
458 |
+
dim,
|
459 |
+
input_resolution,
|
460 |
+
depth,
|
461 |
+
num_heads,
|
462 |
+
window_size,
|
463 |
+
compress_ratio,
|
464 |
+
squeeze_factor,
|
465 |
+
conv_scale,
|
466 |
+
overlap_ratio,
|
467 |
+
mlp_ratio=4.,
|
468 |
+
qkv_bias=True,
|
469 |
+
qk_scale=None,
|
470 |
+
drop=0.,
|
471 |
+
attn_drop=0.,
|
472 |
+
drop_path=0.,
|
473 |
+
norm_layer=nn.LayerNorm,
|
474 |
+
downsample=None,
|
475 |
+
use_checkpoint=False):
|
476 |
+
|
477 |
+
super().__init__()
|
478 |
+
self.dim = dim
|
479 |
+
self.input_resolution = input_resolution
|
480 |
+
self.depth = depth
|
481 |
+
self.use_checkpoint = use_checkpoint
|
482 |
+
|
483 |
+
# build blocks
|
484 |
+
self.blocks = nn.ModuleList([
|
485 |
+
HAB(
|
486 |
+
dim=dim,
|
487 |
+
input_resolution=input_resolution,
|
488 |
+
num_heads=num_heads,
|
489 |
+
window_size=window_size,
|
490 |
+
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
491 |
+
compress_ratio=compress_ratio,
|
492 |
+
squeeze_factor=squeeze_factor,
|
493 |
+
conv_scale=conv_scale,
|
494 |
+
mlp_ratio=mlp_ratio,
|
495 |
+
qkv_bias=qkv_bias,
|
496 |
+
qk_scale=qk_scale,
|
497 |
+
drop=drop,
|
498 |
+
attn_drop=attn_drop,
|
499 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
500 |
+
norm_layer=norm_layer) for i in range(depth)
|
501 |
+
])
|
502 |
+
|
503 |
+
# OCAB
|
504 |
+
self.overlap_attn = OCAB(
|
505 |
+
dim=dim,
|
506 |
+
input_resolution=input_resolution,
|
507 |
+
window_size=window_size,
|
508 |
+
overlap_ratio=overlap_ratio,
|
509 |
+
num_heads=num_heads,
|
510 |
+
qkv_bias=qkv_bias,
|
511 |
+
qk_scale=qk_scale,
|
512 |
+
mlp_ratio=mlp_ratio,
|
513 |
+
norm_layer=norm_layer
|
514 |
+
)
|
515 |
+
|
516 |
+
# patch merging layer
|
517 |
+
if downsample is not None:
|
518 |
+
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
|
519 |
+
else:
|
520 |
+
self.downsample = None
|
521 |
+
|
522 |
+
def forward(self, x, x_size, params):
|
523 |
+
for blk in self.blocks:
|
524 |
+
x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
|
525 |
+
|
526 |
+
x = self.overlap_attn(x, x_size, params['rpi_oca'])
|
527 |
+
|
528 |
+
if self.downsample is not None:
|
529 |
+
x = self.downsample(x)
|
530 |
+
return x
|
531 |
+
|
532 |
+
|
533 |
+
class RHAG(nn.Module):
|
534 |
+
"""Residual Hybrid Attention Group (RHAG).
|
535 |
+
|
536 |
+
Args:
|
537 |
+
dim (int): Number of input channels.
|
538 |
+
input_resolution (tuple[int]): Input resolution.
|
539 |
+
depth (int): Number of blocks.
|
540 |
+
num_heads (int): Number of attention heads.
|
541 |
+
window_size (int): Local window size.
|
542 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
543 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
544 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
545 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
546 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
547 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
548 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
549 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
550 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
551 |
+
img_size: Input image size.
|
552 |
+
patch_size: Patch size.
|
553 |
+
resi_connection: The convolutional block before residual connection.
|
554 |
+
"""
|
555 |
+
|
556 |
+
def __init__(self,
|
557 |
+
dim,
|
558 |
+
input_resolution,
|
559 |
+
depth,
|
560 |
+
num_heads,
|
561 |
+
window_size,
|
562 |
+
compress_ratio,
|
563 |
+
squeeze_factor,
|
564 |
+
conv_scale,
|
565 |
+
overlap_ratio,
|
566 |
+
mlp_ratio=4.,
|
567 |
+
qkv_bias=True,
|
568 |
+
qk_scale=None,
|
569 |
+
drop=0.,
|
570 |
+
attn_drop=0.,
|
571 |
+
drop_path=0.,
|
572 |
+
norm_layer=nn.LayerNorm,
|
573 |
+
downsample=None,
|
574 |
+
use_checkpoint=False,
|
575 |
+
img_size=224,
|
576 |
+
patch_size=4,
|
577 |
+
resi_connection='1conv'):
|
578 |
+
super(RHAG, self).__init__()
|
579 |
+
|
580 |
+
self.dim = dim
|
581 |
+
self.input_resolution = input_resolution
|
582 |
+
|
583 |
+
self.residual_group = AttenBlocks(
|
584 |
+
dim=dim,
|
585 |
+
input_resolution=input_resolution,
|
586 |
+
depth=depth,
|
587 |
+
num_heads=num_heads,
|
588 |
+
window_size=window_size,
|
589 |
+
compress_ratio=compress_ratio,
|
590 |
+
squeeze_factor=squeeze_factor,
|
591 |
+
conv_scale=conv_scale,
|
592 |
+
overlap_ratio=overlap_ratio,
|
593 |
+
mlp_ratio=mlp_ratio,
|
594 |
+
qkv_bias=qkv_bias,
|
595 |
+
qk_scale=qk_scale,
|
596 |
+
drop=drop,
|
597 |
+
attn_drop=attn_drop,
|
598 |
+
drop_path=drop_path,
|
599 |
+
norm_layer=norm_layer,
|
600 |
+
downsample=downsample,
|
601 |
+
use_checkpoint=use_checkpoint)
|
602 |
+
|
603 |
+
if resi_connection == '1conv':
|
604 |
+
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
|
605 |
+
elif resi_connection == 'identity':
|
606 |
+
self.conv = nn.Identity()
|
607 |
+
|
608 |
+
self.patch_embed = PatchEmbed(
|
609 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
|
610 |
+
|
611 |
+
self.patch_unembed = PatchUnEmbed(
|
612 |
+
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
|
613 |
+
|
614 |
+
def forward(self, x, x_size, params):
|
615 |
+
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
|
616 |
+
|
617 |
+
|
618 |
+
class PatchEmbed(nn.Module):
|
619 |
+
r""" Image to Patch Embedding
|
620 |
+
|
621 |
+
Args:
|
622 |
+
img_size (int): Image size. Default: 224.
|
623 |
+
patch_size (int): Patch token size. Default: 4.
|
624 |
+
in_chans (int): Number of input image channels. Default: 3.
|
625 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
626 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
627 |
+
"""
|
628 |
+
|
629 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
630 |
+
super().__init__()
|
631 |
+
img_size = to_2tuple(img_size)
|
632 |
+
patch_size = to_2tuple(patch_size)
|
633 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
634 |
+
self.img_size = img_size
|
635 |
+
self.patch_size = patch_size
|
636 |
+
self.patches_resolution = patches_resolution
|
637 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
638 |
+
|
639 |
+
self.in_chans = in_chans
|
640 |
+
self.embed_dim = embed_dim
|
641 |
+
|
642 |
+
if norm_layer is not None:
|
643 |
+
self.norm = norm_layer(embed_dim)
|
644 |
+
else:
|
645 |
+
self.norm = None
|
646 |
+
|
647 |
+
def forward(self, x):
|
648 |
+
x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
|
649 |
+
if self.norm is not None:
|
650 |
+
x = self.norm(x)
|
651 |
+
return x
|
652 |
+
|
653 |
+
|
654 |
+
class PatchUnEmbed(nn.Module):
|
655 |
+
r""" Image to Patch Unembedding
|
656 |
+
|
657 |
+
Args:
|
658 |
+
img_size (int): Image size. Default: 224.
|
659 |
+
patch_size (int): Patch token size. Default: 4.
|
660 |
+
in_chans (int): Number of input image channels. Default: 3.
|
661 |
+
embed_dim (int): Number of linear projection output channels. Default: 96.
|
662 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
663 |
+
"""
|
664 |
+
|
665 |
+
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
|
666 |
+
super().__init__()
|
667 |
+
img_size = to_2tuple(img_size)
|
668 |
+
patch_size = to_2tuple(patch_size)
|
669 |
+
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
|
670 |
+
self.img_size = img_size
|
671 |
+
self.patch_size = patch_size
|
672 |
+
self.patches_resolution = patches_resolution
|
673 |
+
self.num_patches = patches_resolution[0] * patches_resolution[1]
|
674 |
+
|
675 |
+
self.in_chans = in_chans
|
676 |
+
self.embed_dim = embed_dim
|
677 |
+
|
678 |
+
def forward(self, x, x_size):
|
679 |
+
x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
|
680 |
+
return x
|
681 |
+
|
682 |
+
|
683 |
+
class Upsample(nn.Sequential):
|
684 |
+
"""Upsample module.
|
685 |
+
|
686 |
+
Args:
|
687 |
+
scale (int): Scale factor. Supported scales: 2^n and 3.
|
688 |
+
num_feat (int): Channel number of intermediate features.
|
689 |
+
"""
|
690 |
+
|
691 |
+
def __init__(self, scale, num_feat):
|
692 |
+
m = []
|
693 |
+
if (scale & (scale - 1)) == 0: # scale = 2^n
|
694 |
+
for _ in range(int(math.log(scale, 2))):
|
695 |
+
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
|
696 |
+
m.append(nn.PixelShuffle(2))
|
697 |
+
elif scale == 3:
|
698 |
+
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
|
699 |
+
m.append(nn.PixelShuffle(3))
|
700 |
+
else:
|
701 |
+
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
702 |
+
super(Upsample, self).__init__(*m)
|
703 |
+
|
704 |
+
|
705 |
+
class HATArch(nn.Module):
|
706 |
+
r""" Hybrid Attention Transformer
|
707 |
+
A PyTorch implementation of: `Activating More Pixels in Image Super-Resolution Transformer`.
|
708 |
+
Some codes are based on SwinIR.
|
709 |
+
Args:
|
710 |
+
img_size (int | tuple(int)): Input image size. Default 64
|
711 |
+
patch_size (int | tuple(int)): Patch size. Default: 1
|
712 |
+
in_chans (int): Number of input image channels. Default: 3
|
713 |
+
embed_dim (int): Patch embedding dimension. Default: 96
|
714 |
+
depths (tuple(int)): Depth of each Swin Transformer layer.
|
715 |
+
num_heads (tuple(int)): Number of attention heads in different layers.
|
716 |
+
window_size (int): Window size. Default: 7
|
717 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
|
718 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
|
719 |
+
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
|
720 |
+
drop_rate (float): Dropout rate. Default: 0
|
721 |
+
attn_drop_rate (float): Attention dropout rate. Default: 0
|
722 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.1
|
723 |
+
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
|
724 |
+
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
|
725 |
+
patch_norm (bool): If True, add normalization after patch embedding. Default: True
|
726 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
|
727 |
+
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
|
728 |
+
img_range: Image range. 1. or 255.
|
729 |
+
upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
|
730 |
+
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
|
731 |
+
"""
|
732 |
+
|
733 |
+
def __init__(self,
|
734 |
+
img_size=64,
|
735 |
+
patch_size=1,
|
736 |
+
in_chans=3,
|
737 |
+
embed_dim=96,
|
738 |
+
depths=(6, 6, 6, 6),
|
739 |
+
num_heads=(6, 6, 6, 6),
|
740 |
+
window_size=7,
|
741 |
+
compress_ratio=3,
|
742 |
+
squeeze_factor=30,
|
743 |
+
conv_scale=0.01,
|
744 |
+
overlap_ratio=0.5,
|
745 |
+
mlp_ratio=4.,
|
746 |
+
qkv_bias=True,
|
747 |
+
qk_scale=None,
|
748 |
+
drop_rate=0.,
|
749 |
+
attn_drop_rate=0.,
|
750 |
+
drop_path_rate=0.1,
|
751 |
+
norm_layer=nn.LayerNorm,
|
752 |
+
ape=False,
|
753 |
+
patch_norm=True,
|
754 |
+
use_checkpoint=False,
|
755 |
+
upscale=2,
|
756 |
+
img_range=1.,
|
757 |
+
upsampler='',
|
758 |
+
resi_connection='1conv',
|
759 |
+
**kwargs):
|
760 |
+
super(HATArch, self).__init__()
|
761 |
+
|
762 |
+
self.window_size = window_size
|
763 |
+
self.shift_size = window_size // 2
|
764 |
+
self.overlap_ratio = overlap_ratio
|
765 |
+
|
766 |
+
num_in_ch = in_chans
|
767 |
+
num_out_ch = in_chans
|
768 |
+
num_feat = 64
|
769 |
+
self.img_range = img_range
|
770 |
+
if in_chans == 3:
|
771 |
+
rgb_mean = (0.4488, 0.4371, 0.4040)
|
772 |
+
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
|
773 |
+
else:
|
774 |
+
self.mean = torch.zeros(1, 1, 1, 1)
|
775 |
+
self.upscale = upscale
|
776 |
+
self.upsampler = upsampler
|
777 |
+
|
778 |
+
# relative position index
|
779 |
+
relative_position_index_SA = self.calculate_rpi_sa()
|
780 |
+
relative_position_index_OCA = self.calculate_rpi_oca()
|
781 |
+
self.register_buffer('relative_position_index_SA', relative_position_index_SA)
|
782 |
+
self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
|
783 |
+
|
784 |
+
# ------------------------- 1, shallow feature extraction ------------------------- #
|
785 |
+
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
|
786 |
+
|
787 |
+
# ------------------------- 2, deep feature extraction ------------------------- #
|
788 |
+
self.num_layers = len(depths)
|
789 |
+
self.embed_dim = embed_dim
|
790 |
+
self.ape = ape
|
791 |
+
self.patch_norm = patch_norm
|
792 |
+
self.num_features = embed_dim
|
793 |
+
self.mlp_ratio = mlp_ratio
|
794 |
+
|
795 |
+
# split image into non-overlapping patches
|
796 |
+
self.patch_embed = PatchEmbed(
|
797 |
+
img_size=img_size,
|
798 |
+
patch_size=patch_size,
|
799 |
+
in_chans=embed_dim,
|
800 |
+
embed_dim=embed_dim,
|
801 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
802 |
+
num_patches = self.patch_embed.num_patches
|
803 |
+
patches_resolution = self.patch_embed.patches_resolution
|
804 |
+
self.patches_resolution = patches_resolution
|
805 |
+
|
806 |
+
# merge non-overlapping patches into image
|
807 |
+
self.patch_unembed = PatchUnEmbed(
|
808 |
+
img_size=img_size,
|
809 |
+
patch_size=patch_size,
|
810 |
+
in_chans=embed_dim,
|
811 |
+
embed_dim=embed_dim,
|
812 |
+
norm_layer=norm_layer if self.patch_norm else None)
|
813 |
+
|
814 |
+
# absolute position embedding
|
815 |
+
if self.ape:
|
816 |
+
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
817 |
+
trunc_normal_(self.absolute_pos_embed, std=.02)
|
818 |
+
|
819 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
820 |
+
|
821 |
+
# stochastic depth
|
822 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
823 |
+
|
824 |
+
# build Residual Hybrid Attention Groups (RHAG)
|
825 |
+
self.layers = nn.ModuleList()
|
826 |
+
for i_layer in range(self.num_layers):
|
827 |
+
layer = RHAG(
|
828 |
+
dim=embed_dim,
|
829 |
+
input_resolution=(patches_resolution[0], patches_resolution[1]),
|
830 |
+
depth=depths[i_layer],
|
831 |
+
num_heads=num_heads[i_layer],
|
832 |
+
window_size=window_size,
|
833 |
+
compress_ratio=compress_ratio,
|
834 |
+
squeeze_factor=squeeze_factor,
|
835 |
+
conv_scale=conv_scale,
|
836 |
+
overlap_ratio=overlap_ratio,
|
837 |
+
mlp_ratio=self.mlp_ratio,
|
838 |
+
qkv_bias=qkv_bias,
|
839 |
+
qk_scale=qk_scale,
|
840 |
+
drop=drop_rate,
|
841 |
+
attn_drop=attn_drop_rate,
|
842 |
+
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
843 |
+
norm_layer=norm_layer,
|
844 |
+
downsample=None,
|
845 |
+
use_checkpoint=use_checkpoint,
|
846 |
+
img_size=img_size,
|
847 |
+
patch_size=patch_size,
|
848 |
+
resi_connection=resi_connection)
|
849 |
+
self.layers.append(layer)
|
850 |
+
self.norm = norm_layer(self.num_features)
|
851 |
+
|
852 |
+
# build the last conv layer in deep feature extraction
|
853 |
+
if resi_connection == '1conv':
|
854 |
+
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
|
855 |
+
elif resi_connection == 'identity':
|
856 |
+
self.conv_after_body = nn.Identity()
|
857 |
+
|
858 |
+
# ------------------------- 3, high quality image reconstruction ------------------------- #
|
859 |
+
if self.upsampler == 'pixelshuffle':
|
860 |
+
# for classical SR
|
861 |
+
self.conv_before_upsample = nn.Sequential(
|
862 |
+
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
|
863 |
+
self.upsample = Upsample(upscale, num_feat)
|
864 |
+
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
865 |
+
|
866 |
+
self.apply(self._init_weights)
|
867 |
+
|
868 |
+
def _init_weights(self, m):
|
869 |
+
if isinstance(m, nn.Linear):
|
870 |
+
trunc_normal_(m.weight, std=.02)
|
871 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
872 |
+
nn.init.constant_(m.bias, 0)
|
873 |
+
elif isinstance(m, nn.LayerNorm):
|
874 |
+
nn.init.constant_(m.bias, 0)
|
875 |
+
nn.init.constant_(m.weight, 1.0)
|
876 |
+
|
877 |
+
def calculate_rpi_sa(self):
|
878 |
+
# calculate relative position index for SA
|
879 |
+
coords_h = torch.arange(self.window_size)
|
880 |
+
coords_w = torch.arange(self.window_size)
|
881 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
882 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
883 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
884 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
885 |
+
relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
|
886 |
+
relative_coords[:, :, 1] += self.window_size - 1
|
887 |
+
relative_coords[:, :, 0] *= 2 * self.window_size - 1
|
888 |
+
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
889 |
+
return relative_position_index
|
890 |
+
|
891 |
+
def calculate_rpi_oca(self):
|
892 |
+
# calculate relative position index for OCA
|
893 |
+
window_size_ori = self.window_size
|
894 |
+
window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
|
895 |
+
|
896 |
+
coords_h = torch.arange(window_size_ori)
|
897 |
+
coords_w = torch.arange(window_size_ori)
|
898 |
+
coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws
|
899 |
+
coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws
|
900 |
+
|
901 |
+
coords_h = torch.arange(window_size_ext)
|
902 |
+
coords_w = torch.arange(window_size_ext)
|
903 |
+
coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse
|
904 |
+
coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse
|
905 |
+
|
906 |
+
relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None] # 2, ws*ws, wse*wse
|
907 |
+
|
908 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # ws*ws, wse*wse, 2
|
909 |
+
relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1 # shift to start from 0
|
910 |
+
relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
|
911 |
+
|
912 |
+
relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
|
913 |
+
relative_position_index = relative_coords.sum(-1)
|
914 |
+
return relative_position_index
|
915 |
+
|
916 |
+
def calculate_mask(self, x_size):
|
917 |
+
# calculate attention mask for SW-MSA
|
918 |
+
h, w = x_size
|
919 |
+
img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
|
920 |
+
h_slices = (slice(0, -self.window_size), slice(-self.window_size,
|
921 |
+
-self.shift_size), slice(-self.shift_size, None))
|
922 |
+
w_slices = (slice(0, -self.window_size), slice(-self.window_size,
|
923 |
+
-self.shift_size), slice(-self.shift_size, None))
|
924 |
+
cnt = 0
|
925 |
+
for h in h_slices:
|
926 |
+
for w in w_slices:
|
927 |
+
img_mask[:, h, w, :] = cnt
|
928 |
+
cnt += 1
|
929 |
+
|
930 |
+
mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
|
931 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
932 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
933 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
934 |
+
|
935 |
+
return attn_mask
|
936 |
+
|
937 |
+
@torch.jit.ignore
|
938 |
+
def no_weight_decay(self):
|
939 |
+
return {'absolute_pos_embed'}
|
940 |
+
|
941 |
+
@torch.jit.ignore
|
942 |
+
def no_weight_decay_keywords(self):
|
943 |
+
return {'relative_position_bias_table'}
|
944 |
+
|
945 |
+
def forward_features(self, x):
|
946 |
+
x_size = (x.shape[2], x.shape[3])
|
947 |
+
|
948 |
+
# Calculate attention mask and relative position index in advance to speed up inference.
|
949 |
+
# The original code is very time-consuming for large window size.
|
950 |
+
attn_mask = self.calculate_mask(x_size).to(x.device)
|
951 |
+
params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
|
952 |
+
|
953 |
+
x = self.patch_embed(x)
|
954 |
+
if self.ape:
|
955 |
+
x = x + self.absolute_pos_embed
|
956 |
+
x = self.pos_drop(x)
|
957 |
+
|
958 |
+
for layer in self.layers:
|
959 |
+
x = layer(x, x_size, params)
|
960 |
+
|
961 |
+
x = self.norm(x) # b seq_len c
|
962 |
+
x = self.patch_unembed(x, x_size)
|
963 |
+
|
964 |
+
return x
|
965 |
+
|
966 |
+
def forward(self, x):
|
967 |
+
self.mean = self.mean.type_as(x)
|
968 |
+
x = (x - self.mean) * self.img_range
|
969 |
+
|
970 |
+
if self.upsampler == 'pixelshuffle':
|
971 |
+
# for classical SR
|
972 |
+
x = self.conv_first(x)
|
973 |
+
x = self.conv_after_body(self.forward_features(x)) + x
|
974 |
+
x = self.conv_before_upsample(x)
|
975 |
+
x = self.conv_last(self.upsample(x))
|
976 |
+
|
977 |
+
x = x / self.img_range + self.mean
|
978 |
+
|
979 |
+
return x
|
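Note: a minimal sketch of how the HATArch added above can be exercised on a dummy tensor. The hyperparameters mirror SR_Inference/inference_hat.py from this commit; the 64x64 input size is an assumption, chosen to be divisible by the window size.

import torch
from SR_Inference.hat.hat_arch import HATArch

# Build the x2 configuration used elsewhere in this commit (random weights).
net = HATArch(
    img_size=64, upscale=2, in_chans=3, window_size=16,
    compress_ratio=3, squeeze_factor=30, conv_scale=0.01, overlap_ratio=0.5,
    img_range=1.0, depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
    num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2,
    upsampler="pixelshuffle", resi_connection="1conv",
)
net.eval()

with torch.no_grad():
    lr = torch.rand(1, 3, 64, 64)   # dummy low-resolution input, h and w divisible by 16
    sr = net(lr)                    # -> torch.Size([1, 3, 128, 128]) for upscale=2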
SR_Inference/hat/weights/HAT-L_SRx2_ImageNet-pretrain.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2818c7ca8d72ec4cc5f31c93203d55252a662dd35cda34ce1a69661f97dcd38f
size 165182573
SR_Inference/hat/weights/HAT_SRx2_ImageNet-pretrain.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:82ebd911263bcc886fbef46b30cf97b92a932a27a3cba30163d4577afb09b9d7
size 84546053
SR_Inference/hat/weights/HAT_SRx4_ImageNet-pretrain.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ee053c42461187846dc0e93aa5abd34591c0725a8e044a59000e92ee215e833
size 85137601
SR_Inference/inference_codeformer.py
ADDED
@@ -0,0 +1,126 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
import os.path as osp
|
6 |
+
from basicsr.utils import img2tensor, tensor2img
|
7 |
+
from torchvision.transforms.functional import normalize
|
8 |
+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
9 |
+
|
10 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
11 |
+
|
12 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
13 |
+
sys.path.append(root_path)
|
14 |
+
from SR_Inference.codeformer.codeformer_arch import CodeFormerArch
|
15 |
+
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo
|
16 |
+
|
17 |
+
|
18 |
+
class CodeFormer:
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
upscale=2,
|
23 |
+
bg_upsampler_name="realesrgan",
|
24 |
+
prefered_net_in_upsampler="RRDBNet",
|
25 |
+
fidelity_weight=0.8,
|
26 |
+
):
|
27 |
+
|
28 |
+
self.upscale = int(upscale)
|
29 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
30 |
+
self.fidelity_weight = fidelity_weight
|
31 |
+
|
32 |
+
# ------------------------ set up background upsampler ------------------------
|
33 |
+
upsampler_zoo = RealEsrUpsamplerZoo(
|
34 |
+
upscale=self.upscale,
|
35 |
+
bg_upsampler_name=bg_upsampler_name,
|
36 |
+
prefered_net_in_upsampler=prefered_net_in_upsampler,
|
37 |
+
)
|
38 |
+
self.bg_upsampler = upsampler_zoo.bg_upsampler
|
39 |
+
|
40 |
+
# ------------------ set up FaceRestoreHelper -------------------
|
41 |
+
gfpgan_weights_path = os.path.join(
|
42 |
+
ROOT_DIR, "SR_Inference", "gfpgan", "weights"
|
43 |
+
)
|
44 |
+
self.face_restorer_helper = FaceRestoreHelper(
|
45 |
+
upscale_factor=self.upscale,
|
46 |
+
face_size=512,
|
47 |
+
crop_ratio=(1, 1),
|
48 |
+
det_model="retinaface_resnet50",
|
49 |
+
save_ext="png",
|
50 |
+
use_parse=True,
|
51 |
+
device=self.device,
|
52 |
+
# model_rootpath="gfpgan/weights",
|
53 |
+
model_rootpath=gfpgan_weights_path,
|
54 |
+
)
|
55 |
+
|
56 |
+
# ------------------ load model -------------------
|
57 |
+
self.sr_model = CodeFormerArch().to(self.device)
|
58 |
+
ckpt_path = os.path.join(
|
59 |
+
ROOT_DIR, "SR_Inference", "codeformer", "weights", "codeformer_v0.1.0.pth"
|
60 |
+
)
|
61 |
+
loadnet = torch.load(ckpt_path, map_location=self.device)
|
62 |
+
if "params_ema" in loadnet:
|
63 |
+
keyname = "params_ema"
|
64 |
+
else:
|
65 |
+
keyname = "params"
|
66 |
+
|
67 |
+
self.sr_model.load_state_dict(loadnet[keyname])
|
68 |
+
self.sr_model.eval()
|
69 |
+
|
70 |
+
@torch.no_grad()
|
71 |
+
def __call__(self, img):
|
72 |
+
|
73 |
+
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
|
74 |
+
|
75 |
+
self.face_restorer_helper.clean_all()
|
76 |
+
self.face_restorer_helper.read_image(img)
|
77 |
+
self.face_restorer_helper.get_face_landmarks_5(
|
78 |
+
only_keep_largest=True, only_center_face=False, eye_dist_threshold=5
|
79 |
+
)
|
80 |
+
self.face_restorer_helper.align_warp_face()
|
81 |
+
|
82 |
+
if len(self.face_restorer_helper.cropped_faces) > 0:
|
83 |
+
|
84 |
+
cropped_face = self.face_restorer_helper.cropped_faces[0]
|
85 |
+
|
86 |
+
cropped_face_t = img2tensor(
|
87 |
+
imgs=cropped_face / 255.0, bgr2rgb=True, float32=True
|
88 |
+
)
|
89 |
+
normalize(
|
90 |
+
tensor=cropped_face_t,
|
91 |
+
mean=(0.5, 0.5, 0.5),
|
92 |
+
std=(0.5, 0.5, 0.5),
|
93 |
+
inplace=True,
|
94 |
+
)
|
95 |
+
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
|
96 |
+
|
97 |
+
# ------------------- restore/enhance image using CodeFormerArch model -------------------
|
98 |
+
output = self.sr_model(cropped_face_t, w=self.fidelity_weight, adain=True)[
|
99 |
+
0
|
100 |
+
]
|
101 |
+
|
102 |
+
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
103 |
+
restored_face = restored_face.astype("uint8")
|
104 |
+
|
105 |
+
self.face_restorer_helper.add_restored_face(restored_face)
|
106 |
+
self.face_restorer_helper.get_inverse_affine(None)
|
107 |
+
|
108 |
+
sr_img = self.face_restorer_helper.paste_faces_to_input_image(
|
109 |
+
upsample_img=bg_img
|
110 |
+
)
|
111 |
+
else:
|
112 |
+
sr_img = bg_img
|
113 |
+
|
114 |
+
return sr_img
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
|
119 |
+
codeformer = CodeFormer(upscale=2, fidelity_weight=1.0)
|
120 |
+
|
121 |
+
img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
|
122 |
+
sr_img = codeformer(img=img)
|
123 |
+
|
124 |
+
saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
|
125 |
+
os.makedirs(saving_dir, exist_ok=True)
|
126 |
+
cv2.imwrite(f"{saving_dir}/sr_img.png", sr_img)
|
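Note: a hedged sketch of running the CodeFormer wrapper above over a folder of frames; the directory paths are placeholders, not paths from this repository.

import glob
import os
import cv2
from SR_Inference.inference_codeformer import CodeFormer

restorer = CodeFormer(upscale=2, fidelity_weight=0.8)

in_dir = "path/to/lr_frames"     # placeholder input folder
out_dir = "path/to/sr_frames"    # placeholder output folder
os.makedirs(out_dir, exist_ok=True)

for path in sorted(glob.glob(os.path.join(in_dir, "*.png"))):
    frame = cv2.imread(path)                 # BGR uint8, as the wrapper expects
    sr_frame = restorer(img=frame)
    cv2.imwrite(os.path.join(out_dir, os.path.basename(path)), sr_frame)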
SR_Inference/inference_gfpgan.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
import os.path as osp
|
6 |
+
from gfpgan import GFPGANer
|
7 |
+
from basicsr.utils.download_util import load_file_from_url
|
8 |
+
|
9 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
10 |
+
|
11 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
12 |
+
sys.path.append(root_path)
|
13 |
+
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo
|
14 |
+
|
15 |
+
|
16 |
+
class GFPGAN:
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
upscale=2,
|
21 |
+
bg_upsampler_name="realesrgan",
|
22 |
+
prefered_net_in_upsampler="RRDBNet",
|
23 |
+
):
|
24 |
+
|
25 |
+
upscale = int(upscale)
|
26 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
27 |
+
|
28 |
+
# ------------------------ set up background upsampler ------------------------
|
29 |
+
upsampler_zoo = RealEsrUpsamplerZoo(
|
30 |
+
upscale=upscale,
|
31 |
+
bg_upsampler_name=bg_upsampler_name,
|
32 |
+
prefered_net_in_upsampler=prefered_net_in_upsampler,
|
33 |
+
)
|
34 |
+
bg_upsampler = upsampler_zoo.bg_upsampler
|
35 |
+
|
36 |
+
# ------------------------ load model ------------------------
|
37 |
+
gfpgan_weights_path = os.path.join(
|
38 |
+
ROOT_DIR, "SR_Inference", "gfpgan", "weights"
|
39 |
+
)
|
40 |
+
gfpgan_model_path = os.path.join(gfpgan_weights_path, "GFPGANv1.3.pth")
|
41 |
+
|
42 |
+
if not os.path.isfile(gfpgan_model_path):
|
43 |
+
url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth"
|
44 |
+
gfpgan_model_path = load_file_from_url(
|
45 |
+
url=url,
|
46 |
+
model_dir=gfpgan_weights_path,
|
47 |
+
progress=True,
|
48 |
+
file_name="GFPGANv1.3.pth",
|
49 |
+
)
|
50 |
+
|
51 |
+
self.sr_model = GFPGANer(
|
52 |
+
upscale=upscale,
|
53 |
+
bg_upsampler=bg_upsampler,
|
54 |
+
model_path=gfpgan_model_path,
|
55 |
+
device=device,
|
56 |
+
)
|
57 |
+
|
58 |
+
def __call__(self, img):
|
59 |
+
# ------------------------ restore/enhance image using GFPGAN model ------------------------
|
60 |
+
cropped_faces, sr_faces, sr_img = self.sr_model.enhance(img)
|
61 |
+
|
62 |
+
return sr_img
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
|
67 |
+
gfpgan = GFPGAN(
|
68 |
+
upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
|
69 |
+
)
|
70 |
+
|
71 |
+
img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
|
72 |
+
sr_img = gfpgan(img=img)
|
73 |
+
|
74 |
+
saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
|
75 |
+
os.makedirs(saving_dir, exist_ok=True)
|
76 |
+
cv2.imwrite(f"{saving_dir}/sr_img_gfpgan.png", sr_img)
|
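Note: the GFPGAN wrapper above takes BGR uint8 arrays (as returned by cv2.imread). A hedged sketch for feeding it a PIL image instead; the file names are placeholders.

import cv2
import numpy as np
from PIL import Image
from SR_Inference.inference_gfpgan import GFPGAN

gfpgan = GFPGAN(upscale=2)

pil_img = Image.open("face.png").convert("RGB")           # placeholder path
bgr = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)  # PIL RGB -> OpenCV BGR
sr_bgr = gfpgan(img=bgr)
sr_rgb = cv2.cvtColor(sr_bgr, cv2.COLOR_BGR2RGB)
Image.fromarray(sr_rgb).save("face_sr.png")               # placeholder path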
SR_Inference/inference_hat.py
ADDED
@@ -0,0 +1,104 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import os.path as osp
|
7 |
+
from PIL import Image
|
8 |
+
from basicsr.utils import img2tensor
|
9 |
+
|
10 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
11 |
+
|
12 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
13 |
+
sys.path.append(root_path)
|
14 |
+
from SR_Inference.hat.hat_arch import HATArch
|
15 |
+
|
16 |
+
|
17 |
+
class HAT:
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
upscale=2,
|
22 |
+
in_chans=3,
|
23 |
+
img_size=(480, 640),
|
24 |
+
window_size=16,
|
25 |
+
compress_ratio=3,
|
26 |
+
squeeze_factor=30,
|
27 |
+
conv_scale=0.01,
|
28 |
+
overlap_ratio=0.5,
|
29 |
+
img_range=1.0,
|
30 |
+
depths=[6, 6, 6, 6, 6, 6],
|
31 |
+
embed_dim=180,
|
32 |
+
num_heads=[6, 6, 6, 6, 6, 6],
|
33 |
+
mlp_ratio=2,
|
34 |
+
upsampler="pixelshuffle",
|
35 |
+
resi_connection="1conv",
|
36 |
+
):
|
37 |
+
upscale = int(upscale)
|
38 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
39 |
+
|
40 |
+
# ------------------ load model for img enhancement -------------------
|
41 |
+
self.sr_model = HATArch(
|
42 |
+
img_size=img_size,
|
43 |
+
upscale=upscale,
|
44 |
+
in_chans=in_chans,
|
45 |
+
window_size=window_size,
|
46 |
+
compress_ratio=compress_ratio,
|
47 |
+
squeeze_factor=squeeze_factor,
|
48 |
+
conv_scale=conv_scale,
|
49 |
+
overlap_ratio=overlap_ratio,
|
50 |
+
img_range=img_range,
|
51 |
+
depths=depths,
|
52 |
+
embed_dim=embed_dim,
|
53 |
+
num_heads=num_heads,
|
54 |
+
mlp_ratio=mlp_ratio,
|
55 |
+
upsampler=upsampler,
|
56 |
+
resi_connection=resi_connection,
|
57 |
+
).to(self.device)
|
58 |
+
|
59 |
+
ckpt_path = os.path.join(
|
60 |
+
ROOT_DIR,
|
61 |
+
"SR_Inference",
|
62 |
+
"hat",
|
63 |
+
"weights",
|
64 |
+
f"HAT_SRx{str(upscale)}_ImageNet-pretrain.pth",
|
65 |
+
)
|
66 |
+
loadnet = torch.load(ckpt_path, map_location=self.device)
|
67 |
+
if "params_ema" in loadnet:
|
68 |
+
keyname = "params_ema"
|
69 |
+
else:
|
70 |
+
keyname = "params"
|
71 |
+
|
72 |
+
self.sr_model.load_state_dict(loadnet[keyname])
|
73 |
+
self.sr_model.eval()
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def __call__(self, img):
|
77 |
+
img_tensor = (
|
78 |
+
img2tensor(imgs=img / 255.0, bgr2rgb=True, float32=True)
|
79 |
+
.unsqueeze(0)
|
80 |
+
.to(self.device)
|
81 |
+
)
|
82 |
+
restored_img = self.sr_model(img_tensor)[0]
|
83 |
+
restored_img = restored_img.permute(1, 2, 0).cpu().numpy()
|
84 |
+
restored_img = (restored_img - restored_img.min()) / (
|
85 |
+
restored_img.max() - restored_img.min()
|
86 |
+
)
|
87 |
+
restored_img = (restored_img * 255).astype(np.uint8)
|
88 |
+
restored_img = Image.fromarray(restored_img)
|
89 |
+
restored_img = np.array(restored_img)
|
90 |
+
sr_img = cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
|
91 |
+
|
92 |
+
return sr_img
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
|
97 |
+
hat = HAT(upscale=2)
|
98 |
+
|
99 |
+
img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
|
100 |
+
sr_img = hat(img=img)
|
101 |
+
|
102 |
+
saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
|
103 |
+
os.makedirs(saving_dir, exist_ok=True)
|
104 |
+
cv2.imwrite(f"{saving_dir}/sr_img_hat.png", sr_img)
|
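Note: the HATArch window partitioning assumes the frame height and width are divisible by the window size (16 in this wrapper). A hedged sketch that pads the frame before calling the wrapper and crops afterwards; the file names are placeholders.

import cv2
from SR_Inference.inference_hat import HAT

hat = HAT(upscale=2, window_size=16)

img = cv2.imread("frame.png")                 # placeholder path
h, w = img.shape[:2]
pad_h = (16 - h % 16) % 16
pad_w = (16 - w % 16) % 16
padded = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT)

sr = hat(img=padded)[: 2 * h, : 2 * w]        # crop back to 2x the original size
cv2.imwrite("frame_sr.png", sr)               # placeholder path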
SR_Inference/inference_realesr.py
ADDED
@@ -0,0 +1,52 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
import os.path as osp
|
6 |
+
|
7 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
8 |
+
|
9 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
10 |
+
sys.path.append(root_path)
|
11 |
+
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo
|
12 |
+
|
13 |
+
|
14 |
+
class RealEsr:
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
upscale=2,
|
19 |
+
bg_upsampler_name="realesrgan",
|
20 |
+
prefered_net_in_upsampler="RRDBNet",
|
21 |
+
):
|
22 |
+
|
23 |
+
self.upscale = int(upscale)
|
24 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
25 |
+
|
26 |
+
# ------------------------ set up background upsampler ------------------------
|
27 |
+
self.upsampler_zoo = RealEsrUpsamplerZoo(
|
28 |
+
upscale=self.upscale,
|
29 |
+
bg_upsampler_name=bg_upsampler_name,
|
30 |
+
prefered_net_in_upsampler=prefered_net_in_upsampler,
|
31 |
+
)
|
32 |
+
self.bg_upsampler = self.upsampler_zoo.bg_upsampler
|
33 |
+
|
34 |
+
def __call__(self, img):
|
35 |
+
# ---------------- restore/enhance image using the selected RealESR model ----------------
|
36 |
+
sr_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
|
37 |
+
|
38 |
+
return sr_img
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == "__main__":
|
42 |
+
|
43 |
+
realesr = RealEsr(
|
44 |
+
upscale=2, bg_upsampler_name="realesrgan", prefered_net_in_upsampler="RRDBNet"
|
45 |
+
)
|
46 |
+
|
47 |
+
img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
|
48 |
+
sr_img = realesr(img=img)
|
49 |
+
|
50 |
+
saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
|
51 |
+
os.makedirs(saving_dir, exist_ok=True)
|
52 |
+
cv2.imwrite(f"{saving_dir}/sr_img.png", sr_img)
|
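Note: per RealEsrUpsamplerZoo below, the "realesrgan" backend only registers x2 and x4 weights in this commit; other scales raise an exception. A small hedged sketch of that behaviour.

from SR_Inference.inference_realesr import RealEsr

# Only x2 and x4 have registered RealESRGAN weights here; x3 is rejected.
try:
    RealEsr(upscale=3, bg_upsampler_name="realesrgan")
except Exception as err:
    print(f"unsupported scale: {err}")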
SR_Inference/inference_sr_utils.py
ADDED
@@ -0,0 +1,101 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
from realesrgan import RealESRGANer
|
5 |
+
from basicsr.archs.rrdbnet_arch import RRDBNet
|
6 |
+
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
7 |
+
from basicsr.utils.download_util import load_file_from_url
|
8 |
+
|
9 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
10 |
+
|
11 |
+
|
12 |
+
class RealEsrUpsamplerZoo:
|
13 |
+
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
upscale=2,
|
17 |
+
bg_upsampler_name="realesrgan",
|
18 |
+
prefered_net_in_upsampler="RRDBNet",
|
19 |
+
):
|
20 |
+
|
21 |
+
self.upscale = int(upscale)
|
22 |
+
|
23 |
+
# ------------------------ set up background upsampler ------------------------
|
24 |
+
weights_path = os.path.join(
|
25 |
+
ROOT_DIR, "SR_Inference", f"{bg_upsampler_name}", "weights"
|
26 |
+
)
|
27 |
+
|
28 |
+
if bg_upsampler_name == "realesrgan":
|
29 |
+
model = self.get_prefered_net(prefered_net_in_upsampler, upscale)
|
30 |
+
if self.upscale == 2:
|
31 |
+
model_path = os.path.join(weights_path, "RealESRGAN_x2plus.pth")
|
32 |
+
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
|
33 |
+
elif self.upscale == 4:
|
34 |
+
model_path = os.path.join(weights_path, "RealESRGAN_x4plus.pth")
|
35 |
+
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
|
36 |
+
else:
|
37 |
+
raise Exception(
|
38 |
+
f"{bg_upsampler_name} model not available for upscaling x{str(self.upscale)}"
|
39 |
+
)
|
40 |
+
elif bg_upsampler_name == "realesrnet":
|
41 |
+
model = self.get_prefered_net(prefered_net_in_upsampler, upscale)
|
42 |
+
if self.upscale == 4:
|
43 |
+
model_path = os.path.join(weights_path, "RealESRNet_x4plus.pth")
|
44 |
+
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth"
|
45 |
+
else:
|
46 |
+
raise Exception(
|
47 |
+
f"{bg_upsampler_name} model not available for upscaling x{str(self.upscale)}"
|
48 |
+
)
|
49 |
+
elif bg_upsampler_name == "anime":
|
50 |
+
model = self.get_prefered_net(prefered_net_in_upsampler, upscale)
|
51 |
+
if self.upscale == 4:
|
52 |
+
model_path = os.path.join(
|
53 |
+
weights_path, "RealESRGAN_x4plus_anime_6B.pth"
|
54 |
+
)
|
55 |
+
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
|
56 |
+
else:
|
57 |
+
raise Exception(
|
58 |
+
f"{bg_upsampler_name} model not available for upscaling x{str(self.upscale)}"
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
raise Exception(f"No model implemented for: {bg_upsampler_name}")
|
62 |
+
|
63 |
+
# ------------------------ load background upsampler model ------------------------
|
64 |
+
if not os.path.isfile(model_path):
|
65 |
+
model_path = load_file_from_url(
|
66 |
+
url=url, model_dir=weights_path, progress=True, file_name=None
|
67 |
+
)
|
68 |
+
|
69 |
+
self.bg_upsampler = RealESRGANer(
|
70 |
+
scale=int(upscale),
|
71 |
+
model_path=model_path,
|
72 |
+
model=model,
|
73 |
+
tile=0,
|
74 |
+
tile_pad=0,
|
75 |
+
pre_pad=0,
|
76 |
+
half=False,
|
77 |
+
)
|
78 |
+
|
79 |
+
@staticmethod
|
80 |
+
def get_prefered_net(prefered_net_in_upsampler, upscale=2):
|
81 |
+
if prefered_net_in_upsampler == "RRDBNet":
|
82 |
+
model = RRDBNet(
|
83 |
+
num_in_ch=3,
|
84 |
+
num_out_ch=3,
|
85 |
+
num_feat=64,
|
86 |
+
num_block=23,
|
87 |
+
num_grow_ch=32,
|
88 |
+
scale=int(upscale),
|
89 |
+
)
|
90 |
+
elif prefered_net_in_upsampler == "SRVGGNetCompact":
|
91 |
+
model = SRVGGNetCompact(
|
92 |
+
num_in_ch=3,
|
93 |
+
num_out_ch=3,
|
94 |
+
num_feat=64,
|
95 |
+
num_conv=16,
|
96 |
+
upscale=int(upscale),
|
97 |
+
act_type="prelu",
|
98 |
+
)
|
99 |
+
else:
|
100 |
+
raise Exception(f"No net named: {prefered_net_in_upsampler} implemented!")
|
101 |
+
return model
|
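Note: a hedged sketch using RealEsrUpsamplerZoo directly, without one of the wrapper classes above; RealESRGANer.enhance returns the upscaled image plus the detected image mode. The file names are placeholders.

import cv2
from SR_Inference.inference_sr_utils import RealEsrUpsamplerZoo

zoo = RealEsrUpsamplerZoo(
    upscale=2,
    bg_upsampler_name="realesrgan",
    prefered_net_in_upsampler="RRDBNet",
)

img = cv2.imread("frame.png")                        # placeholder path
sr_img, img_mode = zoo.bg_upsampler.enhance(img, outscale=2)
cv2.imwrite("frame_x2.png", sr_img)                  # placeholder path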
SR_Inference/inference_srresnet.py
ADDED
@@ -0,0 +1,78 @@
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import sys
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import os.path as osp
|
7 |
+
from PIL import Image
|
8 |
+
from basicsr.utils import img2tensor
|
9 |
+
from basicsr.archs.srresnet_arch import MSRResNet
|
10 |
+
|
11 |
+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
12 |
+
|
13 |
+
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
|
14 |
+
sys.path.append(root_path)
|
15 |
+
|
16 |
+
|
17 |
+
class SRResNet:
|
18 |
+
|
19 |
+
def __init__(self, upscale=2, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16):
|
20 |
+
|
21 |
+
self.upscale = int(upscale)
|
22 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
|
24 |
+
# ------------------ load model for img enhancement -------------------
|
25 |
+
self.sr_model = MSRResNet(
|
26 |
+
upscale=self.upscale,
|
27 |
+
num_in_ch=num_in_ch,
|
28 |
+
num_out_ch=num_out_ch,
|
29 |
+
num_feat=num_feat,
|
30 |
+
num_block=num_block,
|
31 |
+
).to(self.device)
|
32 |
+
|
33 |
+
ckpt_path = os.path.join(
|
34 |
+
ROOT_DIR,
|
35 |
+
"SR_Inference",
|
36 |
+
"srresnet",
|
37 |
+
"weights",
|
38 |
+
f"SRResNet_{str(self.upscale)}x.pth",
|
39 |
+
)
|
40 |
+
loadnet = torch.load(ckpt_path, map_location=self.device)
|
41 |
+
if "params_ema" in loadnet:
|
42 |
+
keyname = "params_ema"
|
43 |
+
else:
|
44 |
+
keyname = "params"
|
45 |
+
|
46 |
+
self.sr_model.load_state_dict(loadnet[keyname])
|
47 |
+
self.sr_model.eval()
|
48 |
+
|
49 |
+
@torch.no_grad()
|
50 |
+
def __call__(self, img):
|
51 |
+
img_tensor = (
|
52 |
+
img2tensor(imgs=img / 255.0, bgr2rgb=True, float32=True)
|
53 |
+
.unsqueeze(0)
|
54 |
+
.to(self.device)
|
55 |
+
)
|
56 |
+
restored_img = self.sr_model(img_tensor)[0]
|
57 |
+
restored_img = restored_img.permute(1, 2, 0).cpu().numpy()
|
58 |
+
restored_img = (restored_img - restored_img.min()) / (
|
59 |
+
restored_img.max() - restored_img.min()
|
60 |
+
)
|
61 |
+
restored_img = (restored_img * 255).astype(np.uint8)
|
62 |
+
restored_img = Image.fromarray(restored_img)
|
63 |
+
restored_img = np.array(restored_img)
|
64 |
+
sr_img = cv2.cvtColor(restored_img, cv2.COLOR_RGB2BGR)
|
65 |
+
|
66 |
+
return sr_img
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
|
71 |
+
srresnet = SRResNet(upscale=2)
|
72 |
+
|
73 |
+
img = cv2.imread(f"{ROOT_DIR}/data/EyeDentify/Wo_SR/original/1/1/frame_01.png")
|
74 |
+
sr_img = srresnet(img=img)
|
75 |
+
|
76 |
+
saving_dir = f"{ROOT_DIR}/rough_works/SR_imgs"
|
77 |
+
os.makedirs(saving_dir, exist_ok=True)
|
78 |
+
cv2.imwrite(f"{saving_dir}/sr_img_srresnet.png", sr_img)
|
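Note: a hedged sketch comparing the SRResNet wrapper above against plain bicubic upscaling of the same frame; the file names are placeholders.

import cv2
from SR_Inference.inference_srresnet import SRResNet

srresnet = SRResNet(upscale=2)

img = cv2.imread("frame.png")                        # placeholder path
sr_img = srresnet(img=img)
bicubic = cv2.resize(
    img, (img.shape[1] * 2, img.shape[0] * 2), interpolation=cv2.INTER_CUBIC
)
cv2.imwrite("frame_srresnet_x2.png", sr_img)         # placeholder path
cv2.imwrite("frame_bicubic_x2.png", bicubic)         # placeholder path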
SR_Inference/realesrgan/weights/RealESRGAN_x2plus.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:49fafd45f8fd7aa8d31ab2a22d14d91b536c34494a5cfe31eb5d89c2fa266abb
size 67061725
SR_Inference/realesrgan/weights/RealESRGAN_x4plus.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fa0d38905f75ac06eb49a7951b426670021be3018265fd191d2125df9d682f1
size 67040989
SR_Inference/srresnet/weights/SRResNet_2x.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4d2a531ecd6e8f15cc5eccf4bca58cdfea69f76c09baa1b208694977b0f6f5e
size 5492202
SR_Inference/srresnet/weights/SRResNet_4x.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:112f2ec947bb497b0350b149a1e06c3f73de77497cec64ce0c7ba268b8398023
size 6083374
app.py
CHANGED
@@ -4,8 +4,9 @@
|
|
4 |
from io import BytesIO
|
5 |
import os
|
6 |
import sys
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
-
import
|
9 |
import streamlit as st
|
10 |
import torch
|
11 |
from PIL import Image
|
@@ -21,6 +22,7 @@ import os.path as osp
|
|
21 |
root_path = osp.abspath(osp.join(__file__, osp.pardir))
|
22 |
sys.path.append(root_path)
|
23 |
|
|
|
24 |
from utils import get_model
|
25 |
from registry_utils import import_registered_modules
|
26 |
|
@@ -39,15 +41,13 @@ CAM_METHODS = [
|
|
39 |
# "LayerCAM",
|
40 |
]
|
41 |
TV_MODELS = [
|
42 |
-
"
|
43 |
-
|
44 |
-
]
|
45 |
-
SR_METHODS = ["GFPGAN", "RealESRGAN", "SRResNet", "CodeFormer", "HAT"]
|
46 |
-
UPSCALE = ["2", "3", "4"]
|
47 |
-
LABEL_MAP = [
|
48 |
-
"left_eye",
|
49 |
-
"right_eye",
|
50 |
]
|
|
|
|
|
|
|
|
|
51 |
|
52 |
|
53 |
@torch.no_grad()
|
@@ -79,150 +79,287 @@ def main():
|
|
79 |
|
80 |
# Sidebar
|
81 |
# File selection
|
82 |
-
st.sidebar.title("
|
83 |
# Disabling warning
|
84 |
st.set_option("deprecation.showfileUploaderEncoding", False)
|
85 |
# Choose your own image
|
86 |
uploaded_file = st.sidebar.file_uploader(
|
87 |
-
"Upload
|
88 |
)
|
89 |
if uploaded_file is not None:
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
# Model selection
|
95 |
-
st.sidebar.title("Setup")
|
96 |
tv_model = st.sidebar.selectbox(
|
97 |
"Classification model",
|
98 |
TV_MODELS,
|
99 |
-
help="Supported
|
100 |
)
|
101 |
|
102 |
-
|
103 |
-
#
|
104 |
-
#
|
105 |
-
#
|
106 |
-
# "
|
|
|
|
|
|
|
|
|
|
|
107 |
# )
|
108 |
|
109 |
-
img_configs = {"img_size": [32, 64], "means": None, "stds": None}
|
110 |
-
# For newline
|
111 |
st.sidebar.write("\n")
|
112 |
|
113 |
-
if st.sidebar.button("Compute CAM"):
|
114 |
if uploaded_file is None:
|
115 |
st.sidebar.error("Please upload an image first")
|
116 |
|
117 |
else:
|
118 |
with st.spinner("Analyzing..."):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
)
|
|
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
if means is not None and stds is not None:
|
135 |
-
preprocess_steps.append(transforms.Normalize(means, stds))
|
136 |
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
"model_path": root_path
|
143 |
-
+ "/pre_trained_models/ResNet18/left_eye.pt",
|
144 |
-
"registered_model_name": "ResNet18",
|
145 |
-
"num_classes": 1,
|
146 |
-
}
|
147 |
-
registered_model_name = model_configs["registered_model_name"]
|
148 |
-
# default_layer = ""
|
149 |
-
if tv_model is not None:
|
150 |
-
with st.spinner("Loading model..."):
|
151 |
-
model = _load_model(model_configs)
|
152 |
-
|
153 |
-
if torch.cuda.is_available():
|
154 |
-
model = model.cuda()
|
155 |
-
|
156 |
-
if registered_model_name == "ResNet18":
|
157 |
-
target_layer = model.resnet.layer4[-1].conv2
|
158 |
-
elif registered_model_name == "ResNet50":
|
159 |
-
target_layer = model.resnet.layer4[-1].conv3
|
160 |
-
else:
|
161 |
-
raise Exception(
|
162 |
-
f"No target layer available for selected model: {registered_model_name}"
|
163 |
)
|
164 |
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
#
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
#
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
190 |
)
|
191 |
-
# with torch.no_grad():
|
192 |
-
# if input_mask is not None:
|
193 |
-
# out = self.model(input_img, input_mask)
|
194 |
-
# else:
|
195 |
-
# out = self.model(input_img)
|
196 |
-
# activation_map = cam_extractor(class_idx=target_class)
|
197 |
-
|
198 |
-
# Forward the image to the model
|
199 |
-
out = model(input_img)
|
200 |
-
print("out = ", out)
|
201 |
-
|
202 |
-
# Select the target class
|
203 |
-
# if class_selection == "Predicted class (argmax)":
|
204 |
-
# class_idx = out.squeeze(0).argmax().item()
|
205 |
-
# else:
|
206 |
-
# class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
|
207 |
-
|
208 |
-
# Retrieve the CAM
|
209 |
-
# act_maps = cam_extractor(class_idx=target_class)
|
210 |
-
act_maps = cam_extractor(0, out)
|
211 |
-
# Fuse the CAMs if there are several
|
212 |
-
activation_map = (
|
213 |
-
act_maps[0]
|
214 |
-
if len(act_maps) == 1
|
215 |
-
else cam_extractor.fuse_cams(act_maps)
|
216 |
-
)
|
217 |
-
|
218 |
-
# Overlayed CAM
|
219 |
-
fig, ax = plt.subplots()
|
220 |
-
result = overlay_mask(
|
221 |
-
img, to_pil_image(activation_map, mode="F"), alpha=0.5
|
222 |
-
)
|
223 |
-
ax.imshow(result)
|
224 |
-
ax.axis("off")
|
225 |
-
cols[-1].pyplot(fig)
|
226 |
|
227 |
|
228 |
if __name__ == "__main__":
|
|
|
4 |
from io import BytesIO
|
5 |
import os
|
6 |
import sys
|
7 |
+
import cv2
|
8 |
import matplotlib.pyplot as plt
|
9 |
+
import numpy as np
|
10 |
import streamlit as st
|
11 |
import torch
|
12 |
from PIL import Image
|
|
|
22 |
root_path = osp.abspath(osp.join(__file__, osp.pardir))
|
23 |
sys.path.append(root_path)
|
24 |
|
25 |
+
from preprocessing.dataset_creation import EyeDentityDatasetCreation
|
26 |
from utils import get_model
|
27 |
from registry_utils import import_registered_modules
|
28 |
|
|
|
41 |
# "LayerCAM",
|
42 |
]
|
43 |
TV_MODELS = [
|
44 |
+
"ResNet18",
|
45 |
+
"ResNet50",
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
]
|
47 |
+
SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
|
48 |
+
UPSCALE = [2, 4]
|
49 |
+
UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
|
50 |
+
LABEL_MAP = ["left_pupil", "right_pupil"]
|
51 |
|
52 |
|
53 |
@torch.no_grad()
|
|
|
79 |
|
80 |
# Sidebar
|
81 |
# File selection
|
82 |
+
st.sidebar.title("Upload Face or Eye")
|
83 |
# Disabling warning
|
84 |
st.set_option("deprecation.showfileUploaderEncoding", False)
|
85 |
# Choose your own image
|
86 |
uploaded_file = st.sidebar.file_uploader(
|
87 |
+
"Upload Image", type=["png", "jpeg", "jpg"]
|
88 |
)
|
89 |
if uploaded_file is not None:
|
90 |
+
input_img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
|
91 |
+
# print("input_img before = ", input_img.size)
|
92 |
+
max_size = [input_img.size[0], input_img.size[1]]
|
93 |
+
cols[0].text(f"Input Image: {max_size[0]} x {max_size[1]}")
|
94 |
+
if input_img.size[0] == input_img.size[1] and input_img.size[0] >= 256:
|
95 |
+
max_size[0] = 256
|
96 |
+
max_size[1] = 256
|
97 |
+
else:
|
98 |
+
if input_img.size[0] >= 640:
|
99 |
+
max_size[0] = 640
|
100 |
+
elif input_img.size[0] < 64:
|
101 |
+
max_size[0] = 64
|
102 |
+
if input_img.size[1] >= 480:
|
103 |
+
max_size[1] = 480
|
104 |
+
elif input_img.size[1] < 32:
|
105 |
+
max_size[1] = 32
|
106 |
+
input_img.thumbnail((max_size[0], max_size[1])) # Bicubic resampling
|
107 |
+
# print("input_img after = ", input_img.size)
|
108 |
+
# cols[0].image(input_img)
|
109 |
+
fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
|
110 |
+
# Display the input image
|
111 |
+
axs0.imshow(input_img)
|
112 |
+
axs0.axis("off")
|
113 |
+
axs0.set_title("Input Image")
|
114 |
+
|
115 |
+
# Display the plot
|
116 |
+
cols[0].pyplot(fig0)
|
117 |
+
cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
|
118 |
+
|
119 |
+
st.sidebar.title("Setup")
|
120 |
|
121 |
+
# Upscale selection
|
122 |
+
upscale = "-"
|
123 |
+
# upscale = st.sidebar.selectbox(
|
124 |
+
# "Upscale",
|
125 |
+
# ["-"] + UPSCALE,
|
126 |
+
# help="Upscale the uploaded image 2 or 4 times. Keep blank for no upscaling",
|
127 |
+
# )
|
128 |
+
|
129 |
+
# Upscale method selection
|
130 |
+
if upscale != "-":
|
131 |
+
upscale_method_or_model = st.sidebar.selectbox(
|
132 |
+
"Upscale Method / Model",
|
133 |
+
UPSCALE_METHODS + SR_METHODS,
|
134 |
+
help="Select a method or model to upscale the uploaded image",
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
upscale_method_or_model = None
|
138 |
+
|
139 |
+
# Pupil selection
|
140 |
+
pupil_selection = st.sidebar.selectbox(
|
141 |
+
"Pupil Selection",
|
142 |
+
["-"] + LABEL_MAP,
|
143 |
+
help="Select left or right pupil OR keep blank for both pupil diameter estimation",
|
144 |
+
)
|
145 |
|
146 |
# Model selection
|
|
|
147 |
tv_model = st.sidebar.selectbox(
|
148 |
"Classification model",
|
149 |
TV_MODELS,
|
150 |
+
help="Supported Models for Pupil Diameter Estimation",
|
151 |
)
|
152 |
|
153 |
+
cam_method = "CAM"
|
154 |
+
# cam_method = st.sidebar.selectbox(
|
155 |
+
# "CAM method",
|
156 |
+
# CAM_METHODS,
|
157 |
+
# help="The way your class activation map will be computed",
|
158 |
+
# )
|
159 |
+
# target_layer = st.sidebar.text_input(
|
160 |
+
# "Target layer",
|
161 |
+
# default_layer,
|
162 |
+
# help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")',
|
163 |
# )
|
164 |
|
|
|
|
|
165 |
st.sidebar.write("\n")
|
166 |
|
167 |
+
if st.sidebar.button("Predict Diameter & Compute CAM"):
    if uploaded_file is None:
        st.sidebar.error("Please upload an image first")

    else:
        with st.spinner("Analyzing..."):
            if upscale == "-":
                sr_configs = None
            else:
                sr_configs = {
                    "method": upscale_method_or_model,
                    "params": {"upscale": upscale},
                }
            config_file = {
                "sr_configs": sr_configs,
                "feature_extraction_configs": {
                    "blink_detection": False,
                    "upscale": upscale,
                    "extraction_library": "mediapipe",
                },
            }

            img = np.array(input_img)
            # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            # if img.shape[0] > max_size or img.shape[1] > max_size:
            # img = cv2.resize(img, (max_size, max_size))

            ds_results = EyeDentityDatasetCreation(
                feature_extraction_configs=config_file[
                    "feature_extraction_configs"
                ],
                sr_configs=config_file["sr_configs"],
            )(img)
            # if ds_results is not None:
            # print("ds_results = ", ds_results.keys())

            preprocess_steps = [
                transforms.ToTensor(),
                transforms.Resize(
                    [32, 64],
                    # interpolation=transforms.InterpolationMode.BILINEAR,
                    interpolation=transforms.InterpolationMode.BICUBIC,
                    antialias=True,
                ),
            ]
            preprocess_function = transforms.Compose(preprocess_steps)

            left_eye = None
            right_eye = None

            if ds_results is None:
                # print("type of input_img = ", type(input_img))
                input_img = preprocess_function(input_img)
                input_img = input_img.unsqueeze(0)
                if pupil_selection == "left_pupil":
                    left_eye = input_img
                elif pupil_selection == "right_pupil":
                    right_eye = input_img
                else:
                    left_eye = input_img
                    right_eye = input_img
                # print("type of left_eye = ", type(left_eye))
                # print("type of right_eye = ", type(right_eye))
            elif "eyes" in ds_results.keys():
                if (
                    "left_eye" in ds_results["eyes"].keys()
                    and ds_results["eyes"]["left_eye"] is not None
                ):
                    left_eye = ds_results["eyes"]["left_eye"]
                    # print("type of left_eye = ", type(left_eye))
                    left_eye = to_pil_image(left_eye).convert("RGB")
                    # print("type of left_eye = ", type(left_eye))

                    left_eye = preprocess_function(left_eye)
                    # print("type of left_eye = ", type(left_eye))

                    left_eye = left_eye.unsqueeze(0)
                if (
                    "right_eye" in ds_results["eyes"].keys()
                    and ds_results["eyes"]["right_eye"] is not None
                ):
                    right_eye = ds_results["eyes"]["right_eye"]
                    # print("type of right_eye = ", type(right_eye))
                    right_eye = to_pil_image(right_eye).convert("RGB")
                    # print("type of right_eye = ", type(right_eye))

                    right_eye = preprocess_function(right_eye)
                    # print("type of right_eye = ", type(right_eye))

                    right_eye = right_eye.unsqueeze(0)
            else:
                # print("type of input_img = ", type(input_img))
                input_img = preprocess_function(input_img)
                input_img = input_img.unsqueeze(0)
                if pupil_selection == "left_pupil":
                    left_eye = input_img
                elif pupil_selection == "right_pupil":
                    right_eye = input_img
                else:
                    left_eye = input_img
                    right_eye = input_img
                # print("type of left_eye = ", type(left_eye))
                # print("type of right_eye = ", type(right_eye))

            # print("left_eye = ", left_eye.shape)
            # print("right_eye = ", right_eye.shape)

            if pupil_selection == "-":
                selected_eyes = ["left_eye", "right_eye"]
            elif pupil_selection == "left_pupil":
                selected_eyes = ["left_eye"]
            elif pupil_selection == "right_pupil":
                selected_eyes = ["right_eye"]

            for eye_type in selected_eyes:

                model_configs = {
                    "model_path": root_path
                    + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
                    "registered_model_name": tv_model,
                    "num_classes": 1,
                }
                registered_model_name = model_configs["registered_model_name"]
                model = _load_model(model_configs)

                if registered_model_name == "ResNet18":
                    target_layer = model.resnet.layer4[-1].conv2
                elif registered_model_name == "ResNet50":
                    target_layer = model.resnet.layer4[-1].conv3
                else:
                    raise Exception(
                        f"No target layer available for selected model: {registered_model_name}"
                    )

                if left_eye is not None and eye_type == "left_eye":
                    input_img = left_eye
                elif right_eye is not None and eye_type == "right_eye":
                    input_img = right_eye
                else:
                    raise Exception("Wrong Data")

                if cam_method is not None:
                    cam_extractor = torchcam_methods.__dict__[cam_method](
                        model,
                        target_layer=target_layer,
                        fc_layer=model.resnet.fc,
                        input_shape=input_img.shape,
                    )

                # with torch.no_grad():
                out = model(input_img)
                cols[-1].markdown(
                    f"<h3>Predicted Pupil Diameter: {out[0].item():.2f} mm</h3>",
                    unsafe_allow_html=True,
                )
                # cols[-1].text(f"Predicted Pupil Diameter: {out[0].item():.2f}")

                # Retrieve the CAM
                act_maps = cam_extractor(0, out)

                # Fuse the CAMs if there are several
                activation_map = (
                    act_maps[0]
                    if len(act_maps) == 1
                    else cam_extractor.fuse_cams(act_maps)
                )

                # Convert input image and activation map to PIL images
                input_image_pil = to_pil_image(input_img.squeeze(0))
                activation_map_pil = to_pil_image(activation_map, mode="F")

                # Create the overlayed CAM result
                result = overlay_mask(
                    input_image_pil,
                    activation_map_pil,
                    alpha=0.5,
                )

                # Create a subplot with 1 row and 2 columns
                fig, axs = plt.subplots(1, 2, figsize=(10, 5))

                # Display the input image
                axs[0].imshow(input_image_pil)
                axs[0].axis("off")
                axs[0].set_title("Input Image")

                # Display the overlayed CAM result
                axs[1].imshow(result)
                axs[1].axis("off")
                axs[1].set_title("Overlayed CAM")

                # Display the plot
                cols[-1].pyplot(fig)
                cols[-1].text(
                    f"eye image size: {input_img.shape[-1]} x {input_img.shape[-2]}"
                )


if __name__ == "__main__":
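The CAM block above follows torchcam's usual extract-then-overlay flow. A minimal standalone sketch of that flow is below, using torchcam's GradCAM with a plain torchvision ResNet18 regressor as a stand-in for the registered pupil models; the model, target layer, and input size here are illustrative assumptions, not taken from the commit.

# Minimal sketch of the CAM step, assuming torchcam and a stand-in ResNet18 regressor.
import torch
from torchvision.models import resnet18
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask

model = resnet18(num_classes=1).eval()                 # stand-in for the pupil-diameter model
cam_extractor = GradCAM(model, target_layer=model.layer4[-1].conv2)

x = torch.rand(1, 3, 32, 64)                           # eye crop resized to 32x64, as in app.py
out = model(x)                                         # predicted diameter, shape (1, 1)
act_map = cam_extractor(0, out)[0]                     # CAM for the single regression output
overlay = overlay_mask(to_pil_image(x.squeeze(0)),
                       to_pil_image(act_map, mode="F"), alpha=0.5)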
config.yml
CHANGED
@@ -19,9 +19,9 @@ xai_configs:
     "InputXGradient",
     "GuidedBackprop",
     "Deconvolution",
-    "GuidedGradCam",
-    "LayerGradCam",
-    "LayerGradientXActivation",
+    # "GuidedGradCam",
+    # "LayerGradCam",
+    # "LayerGradientXActivation",
   ]
   cam_methods: [
     "CAM",
feature_extraction/extractor_mediapipe.py
ADDED
@@ -0,0 +1,365 @@
import cv2
import torch
import warnings
import numpy as np
from PIL import Image
from math import sqrt
import mediapipe as mp
from transformers import pipeline

warnings.filterwarnings("ignore")


class ExtractorMediaPipe:

    def __init__(self, upscale=1):

        self.upscale = int(upscale)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # ========== Face Extraction ==========
        self.face_detector = mp.solutions.face_detection.FaceDetection(
            model_selection=0, min_detection_confidence=0.5
        )
        self.face_mesh = mp.solutions.face_mesh.FaceMesh(
            max_num_faces=1,
            static_image_mode=True,
            refine_landmarks=True,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

        # ========== Eyes Extraction ==========
        self.RIGHT_EYE = [
            362, 382, 381, 380, 374, 373, 390, 249,
            263, 466, 388, 387, 386, 385, 384, 398,
        ]
        self.LEFT_EYE = [
            33, 7, 163, 144, 145, 153, 154, 155,
            133, 173, 157, 158, 159, 160, 161, 246,
        ]
        # https://huggingface.co/dima806/closed_eyes_image_detection
        # https://www.kaggle.com/code/dima806/closed-eye-image-detection-vit
        self.pipe = pipeline(
            "image-classification",
            model="dima806/closed_eyes_image_detection",
            device=self.device,
        )
        self.blink_lower_thresh = 0.22
        self.blink_upper_thresh = 0.25
        self.blink_confidence = 0.50

        # ========== Iris Extraction ==========
        self.RIGHT_IRIS = [474, 475, 476, 477]
        self.LEFT_IRIS = [469, 470, 471, 472]

    def extract_face(self, image):

        tmp_image = image.copy()
        results = self.face_detector.process(tmp_image)

        if not results.detections:
            # print("No face detected")
            return None
        else:
            bboxC = results.detections[0].location_data.relative_bounding_box
            ih, iw, _ = image.shape

            # Get bounding box coordinates
            x, y, w, h = (
                int(bboxC.xmin * iw),
                int(bboxC.ymin * ih),
                int(bboxC.width * iw),
                int(bboxC.height * ih),
            )

            # Calculate the center of the bounding box
            center_x = x + w // 2
            center_y = y + h // 2

            # Calculate new bounds ensuring they fit within the image dimensions
            half_size = 128 * self.upscale
            x1 = max(center_x - half_size, 0)
            y1 = max(center_y - half_size, 0)
            x2 = min(center_x + half_size, iw)
            y2 = min(center_y + half_size, ih)

            # Adjust x1, x2, y1, and y2 to ensure the cropped region is exactly (256 * self.upscale) x (256 * self.upscale)
            if x2 - x1 < (256 * self.upscale):
                if x1 == 0:
                    x2 = min((256 * self.upscale), iw)
                elif x2 == iw:
                    x1 = max(iw - (256 * self.upscale), 0)

            if y2 - y1 < (256 * self.upscale):
                if y1 == 0:
                    y2 = min((256 * self.upscale), ih)
                elif y2 == ih:
                    y1 = max(ih - (256 * self.upscale), 0)

            cropped_face = image[y1:y2, x1:x2]

            # bicubic upsampling
            # if self.upscale != 1:
            #     cropped_face = cv2.resize(
            #         cropped_face,
            #         (256 * self.upscale, 256 * self.upscale),
            #         interpolation=cv2.INTER_CUBIC,
            #     )

            return cropped_face

    @staticmethod
    def landmarksDetection(image, results, draw=False):
        image_height, image_width = image.shape[:2]
        mesh_coordinates = [
            (int(point.x * image_width), int(point.y * image_height))
            for point in results.multi_face_landmarks[0].landmark
        ]
        if draw:
            [cv2.circle(image, i, 2, (0, 255, 0), -1) for i in mesh_coordinates]
        return mesh_coordinates

    @staticmethod
    def euclideanDistance(point, point1):
        x, y = point
        x1, y1 = point1
        distance = sqrt((x1 - x) ** 2 + (y1 - y) ** 2)
        return distance

    def blinkRatio(self, landmarks, right_indices, left_indices):

        right_eye_landmark1 = landmarks[right_indices[0]]
        right_eye_landmark2 = landmarks[right_indices[8]]

        right_eye_landmark3 = landmarks[right_indices[12]]
        right_eye_landmark4 = landmarks[right_indices[4]]

        left_eye_landmark1 = landmarks[left_indices[0]]
        left_eye_landmark2 = landmarks[left_indices[8]]

        left_eye_landmark3 = landmarks[left_indices[12]]
        left_eye_landmark4 = landmarks[left_indices[4]]

        right_eye_horizontal_distance = self.euclideanDistance(
            right_eye_landmark1, right_eye_landmark2
        )
        right_eye_vertical_distance = self.euclideanDistance(
            right_eye_landmark3, right_eye_landmark4
        )

        left_eye_vertical_distance = self.euclideanDistance(
            left_eye_landmark3, left_eye_landmark4
        )
        left_eye_horizontal_distance = self.euclideanDistance(
            left_eye_landmark1, left_eye_landmark2
        )

        right_eye_ratio = right_eye_vertical_distance / right_eye_horizontal_distance
        left_eye_ratio = left_eye_vertical_distance / left_eye_horizontal_distance

        eyes_ratio = (right_eye_ratio + left_eye_ratio) / 2

        return eyes_ratio

    def extract_eyes_regions(self, image, landmarks, eye_indices):
        h, w, _ = image.shape
        points = [
            (int(landmarks[idx].x * w), int(landmarks[idx].y * h))
            for idx in eye_indices
        ]

        x_min = min([p[0] for p in points])
        x_max = max([p[0] for p in points])
        y_min = min([p[1] for p in points])
        y_max = max([p[1] for p in points])

        center_x = (x_min + x_max) // 2
        center_y = (y_min + y_max) // 2

        target_width = 32 * self.upscale
        target_height = 16 * self.upscale

        x1 = max(center_x - target_width // 2, 0)
        y1 = max(center_y - target_height // 2, 0)
        x2 = x1 + target_width
        y2 = y1 + target_height

        if x2 > w:
            x1 = w - target_width
            x2 = w
        if y2 > h:
            y1 = h - target_height
            y2 = h

        return image[y1:y2, x1:x2]

    def blink_detection_model(self, left_eye, right_eye):

        left_eye = cv2.cvtColor(left_eye, cv2.COLOR_RGB2GRAY)
        left_eye = Image.fromarray(left_eye)
        preds_left = self.pipe(left_eye)
        if preds_left[0]["label"] == "closeEye":
            closed_left = preds_left[0]["score"] >= self.blink_confidence
        else:
            closed_left = preds_left[1]["score"] >= self.blink_confidence

        right_eye = cv2.cvtColor(right_eye, cv2.COLOR_RGB2GRAY)
        right_eye = Image.fromarray(right_eye)
        preds_right = self.pipe(right_eye)
        if preds_right[0]["label"] == "closeEye":
            closed_right = preds_right[0]["score"] >= self.blink_confidence
        else:
            closed_right = preds_right[1]["score"] >= self.blink_confidence

        # print("preds_left = ", preds_left)
        # print("preds_right = ", preds_right)

        return closed_left or closed_right

    def extract_eyes(self, image, blink_detection=False):

        tmp_face = image.copy()
        results = self.face_mesh.process(tmp_face)

        if results.multi_face_landmarks is None:
            return None

        face_landmarks = results.multi_face_landmarks[0].landmark

        left_eye = self.extract_eyes_regions(image, face_landmarks, self.LEFT_EYE)
        right_eye = self.extract_eyes_regions(image, face_landmarks, self.RIGHT_EYE)
        blinked = False

        if blink_detection:
            mesh_coordinates = self.landmarksDetection(image, results, False)
            eyes_ratio = self.blinkRatio(
                mesh_coordinates, self.RIGHT_EYE, self.LEFT_EYE
            )
            if (
                eyes_ratio > self.blink_lower_thresh
                and eyes_ratio <= self.blink_upper_thresh
            ):
                # print(
                #     "I think person blinked. eyes_ratio = ",
                #     eyes_ratio,
                #     "Confirming with ViT model...",
                # )
                blinked = self.blink_detection_model(
                    left_eye=left_eye, right_eye=right_eye
                )
                # if blinked:
                #     print("Yes, person blinked. Confirmed by model")
                # else:
                #     print("No, person didn't blinked. False Alarm")
            elif eyes_ratio <= self.blink_lower_thresh:
                blinked = True
                # print("Surely person blinked. eyes_ratio = ", eyes_ratio)
            else:
                blinked = False

        return {"left_eye": left_eye, "right_eye": right_eye, "blinked": blinked}

    @staticmethod
    def segment_iris(iris_img):

        # Convert RGB image to grayscale
        iris_img_gray = cv2.cvtColor(iris_img, cv2.COLOR_RGB2GRAY)

        # Apply Gaussian blur for denoising
        iris_img_blur = cv2.GaussianBlur(iris_img_gray, (5, 5), 0)

        # Perform adaptive thresholding
        _, iris_img_mask = cv2.threshold(
            iris_img_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
        )

        # Invert the mask
        segmented_mask = cv2.bitwise_not(iris_img_mask)
        segmented_mask = cv2.cvtColor(segmented_mask, cv2.COLOR_GRAY2RGB)
        segmented_iris = cv2.bitwise_and(iris_img, segmented_mask)

        return {
            "segmented_iris": segmented_iris,
            "segmented_mask": segmented_mask,
        }

    def extract_iris(self, image):

        ih, iw, _ = image.shape
        tmp_face = image.copy()
        results = self.face_mesh.process(tmp_face)

        if results.multi_face_landmarks is None:
            return None

        mesh_coordinates = self.landmarksDetection(image, results, False)
        mesh_points = np.array(mesh_coordinates)

        (l_cx, l_cy), l_radius = cv2.minEnclosingCircle(mesh_points[self.LEFT_IRIS])
        (r_cx, r_cy), r_radius = cv2.minEnclosingCircle(mesh_points[self.RIGHT_IRIS])

        # Crop the left iris to be exactly 16*upscaled x 16*upscaled
        l_x1 = max(int(l_cx) - (8 * self.upscale), 0)
        l_y1 = max(int(l_cy) - (8 * self.upscale), 0)
        l_x2 = min(int(l_cx) + (8 * self.upscale), iw)
        l_y2 = min(int(l_cy) + (8 * self.upscale), ih)

        cropped_left_iris = image[l_y1:l_y2, l_x1:l_x2]

        left_iris_segmented_data = self.segment_iris(
            cv2.cvtColor(cropped_left_iris, cv2.COLOR_BGR2RGB)
        )

        # Crop the right iris to be exactly 16*upscaled x 16*upscaled
        r_x1 = max(int(r_cx) - (8 * self.upscale), 0)
        r_y1 = max(int(r_cy) - (8 * self.upscale), 0)
        r_x2 = min(int(r_cx) + (8 * self.upscale), iw)
        r_y2 = min(int(r_cy) + (8 * self.upscale), ih)

        cropped_right_iris = image[r_y1:r_y2, r_x1:r_x2]

        right_iris_segmented_data = self.segment_iris(
            cv2.cvtColor(cropped_right_iris, cv2.COLOR_BGR2RGB)
        )

        return {
            "left_iris": {
                "img": cropped_left_iris,
                "segmented_iris": left_iris_segmented_data["segmented_iris"],
                "segmented_mask": left_iris_segmented_data["segmented_mask"],
            },
            "right_iris": {
                "img": cropped_right_iris,
                "segmented_iris": right_iris_segmented_data["segmented_iris"],
                "segmented_mask": right_iris_segmented_data["segmented_mask"],
            },
        }
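A rough usage sketch for this extractor, mirroring the order in which FeaturesExtractor (next file) calls it; the input path and the BGR-to-RGB conversion are assumptions for illustration, not part of the commit.

# Hypothetical usage of ExtractorMediaPipe on a single image.
import cv2
from feature_extraction.extractor_mediapipe import ExtractorMediaPipe

extractor = ExtractorMediaPipe(upscale=1)
frame = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)  # "face.jpg" is a placeholder

face = extractor.extract_face(frame)        # roughly 256x256 face crop, or None if no face
if face is not None:
    eyes = extractor.extract_eyes(frame, blink_detection=True)  # 32x16 eye crops + "blinked" flag
    iris = extractor.extract_iris(frame)                        # 16x16 iris crops + masks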
feature_extraction/features_extractor.py
ADDED
@@ -0,0 +1,50 @@
import os
import sys
import warnings
import os.path as osp

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
sys.path.append(root_path)

from feature_extraction.extractor_mediapipe import ExtractorMediaPipe

warnings.filterwarnings("ignore")


class FeaturesExtractor:

    def __init__(
        self, extraction_library="mediapipe", blink_detection=False, upscale=1
    ):
        self.upscale = upscale
        self.blink_detection = blink_detection
        self.extraction_library = extraction_library
        self.feature_extractor = ExtractorMediaPipe(self.upscale)

    def __call__(self, image):
        results = {}
        face = self.feature_extractor.extract_face(image)
        if face is None:
            # print("No face found. Skipped feature extraction!")
            return None
        else:
            results["img"] = image
            results["face"] = face
            eyes_data = self.feature_extractor.extract_eyes(image, self.blink_detection)
            if eyes_data is None:
                # print("No eyes found. Skipped feature extraction!")
                return results
            else:
                results["eyes"] = eyes_data
                if eyes_data["blinked"]:
                    # print("Found blinked eyes!")
                    return results
                else:
                    iris_data = self.feature_extractor.extract_iris(image)
                    if iris_data is None:
                        # print("No iris found. Skipped feature extraction!")
                        return results
                    else:
                        results["iris"] = iris_data
                        return results
packages.txt
DELETED
File without changes
preprocessing/dataset_creation.py
ADDED
@@ -0,0 +1,40 @@
import sys
import cv2
import os.path as osp

root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
sys.path.append(root_path)

from preprocessing.dataset_creation_utils import get_sr_method
from feature_extraction.features_extractor import FeaturesExtractor


class EyeDentityDatasetCreation:

    def __init__(self, feature_extraction_configs, sr_configs=None):
        self.extraction_library = feature_extraction_configs["extraction_library"]
        self.sr_configs = sr_configs
        if self.sr_configs:
            self.sr_method_name = sr_configs["method"]
            self.upscale = sr_configs["params"]["upscale"]
            if self.sr_method_name != "-":
                self.sr_method = get_sr_method(self, sr_configs)
        else:
            self.upscale = 1

        self.blink_detection = feature_extraction_configs["blink_detection"]
        self.features_extractor = FeaturesExtractor(
            extraction_library=self.extraction_library,
            blink_detection=self.blink_detection,
            upscale=self.upscale,
        )

    def __call__(self, img):
        # img = cv2.imread(img)
        if self.sr_configs is None or self.sr_configs != "-":
            img = img
        else:
            img = self.sr_method(img)

        result_dict = self.features_extractor(img)
        return result_dict
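A short sketch of how this pipeline is instantiated without super-resolution, matching the configuration app.py builds; the image path is a placeholder used only for illustration.

# Hypothetical call of EyeDentityDatasetCreation, mirroring app.py's configuration.
import numpy as np
from PIL import Image
from preprocessing.dataset_creation import EyeDentityDatasetCreation

ds_pipeline = EyeDentityDatasetCreation(
    feature_extraction_configs={
        "blink_detection": False,
        "upscale": 1,
        "extraction_library": "mediapipe",
    },
    sr_configs=None,
)
img = np.array(Image.open("face.jpg").convert("RGB"))  # "face.jpg" is a placeholder
results = ds_pipeline(img)  # None, or a dict with "img", "face", "eyes", "iris" keys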
preprocessing/dataset_creation_utils.py
ADDED
@@ -0,0 +1,29 @@
import os
import torch
import random
import numpy as np
from SR_Inference.inference_hat import HAT
from SR_Inference.inference_gfpgan import GFPGAN
from SR_Inference.inference_realesr import RealEsr
from SR_Inference.inference_srresnet import SRResNet
from SR_Inference.inference_codeformer import CodeFormer


def seed_everything(seed=42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True


def get_sr_method(self, sr_configs):
    sr_method_class = globals().get(self.sr_method_name)
    if sr_method_class is not None:
        return sr_method_class(**sr_configs["params"])
    else:
        raise Exception(
            f"No such SR method called '{self.sr_method_name}' implemented!"
        )
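get_sr_method resolves the configured name against this module's globals(), so the "method" string has to match one of the imported classes (HAT, GFPGAN, RealEsr, SRResNet, CodeFormer). A hedged sketch of that resolution follows; the "RealEsr" choice and the upscale value are assumptions, and the caller object is faked with a SimpleNamespace.

# Hypothetical resolution of an SR method from a config dict.
from types import SimpleNamespace
from preprocessing.dataset_creation_utils import get_sr_method

sr_configs = {"method": "RealEsr", "params": {"upscale": 2}}          # assumed config values
holder = SimpleNamespace(sr_method_name=sr_configs["method"])         # stand-in for the caller
sr_method = get_sr_method(holder, sr_configs)                         # instantiates RealEsr(upscale=2)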
registrations/models.py
CHANGED
@@ -121,4 +121,4 @@ class ResNet50(nn.Module):
         return x
 
 
-print("Registered models in MODEL_REGISTRY:", MODEL_REGISTRY.keys())
+# print("Registered models in MODEL_REGISTRY:", MODEL_REGISTRY.keys())
registry_utils.py
CHANGED
@@ -58,22 +58,22 @@ def import_registered_modules(registration_folder="registrations"):
         list: List of imported modules.
     """
 
-    print("\n")
+    # print("\n")
 
     registration_modules_folder = (
        osp.dirname(osp.abspath(__file__)) + f"/{registration_folder}"
     )
-    print("registration_modules_folder = ", registration_modules_folder)
+    # print("registration_modules_folder = ", registration_modules_folder)
 
     registration_modules_file_names = [
        osp.splitext(osp.basename(v))[0]
        for v in scandir(dir_path=registration_modules_folder)
     ]
-    print("registration_modules_file_names = ", registration_modules_file_names)
+    # print("registration_modules_file_names = ", registration_modules_file_names)
 
     imported_modules = [
        importlib.import_module(f"{registration_folder}.{file_name}")
        for file_name in registration_modules_file_names
     ]
-    print("imported_modules = ", imported_modules)
-    print("\n")
+    # print("imported_modules = ", imported_modules)
+    # print("\n")
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
+# https://huggingface.co/docs/hub/en/spaces-dependencies
 tqdm
 PyYAML
 numpy
@@ -10,7 +11,7 @@ scikit_learn
 torch
 captum
 evaluate
-
+basicsr
 facexlib
 realesrgan
 opencv_python
@@ -18,7 +19,7 @@ cmake
 dlib
 einops
 transformers
-
+gfpgan
 # streamlit
 mediapipe
 imutils
xx-packages.txt
ADDED
@@ -0,0 +1,28 @@
# https://huggingface.co/docs/hub/en/spaces-dependencies
# tqdm
# PyYAML
# numpy
# pandas
# matplotlib
# seaborn
# mlflow
# pillow
# scikit_learn
# torch
# captum
# evaluate
# basicsr
# facexlib
# realesrgan
# opencv_python
# cmake
# dlib
# einops
# transformers
# gfpgan
# streamlit
# mediapipe
# imutils
# scipy
# torchvision==0.16.0
# torchcam