mv-lab commited on
Commit
39417b0
·
1 Parent(s): 616408c

InstructIR x HF

Browse files
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ __pycache__/
2
+ *.pt
3
+ *.gif
4
+ *.pth
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import os
6
+ import torch
7
+ import numpy as np
8
+ import yaml
9
+
10
+ #from gradio_imageslider import ImageSlider
11
+
12
+ ## local code
13
+ from models import instructir
14
+ from text.models import LanguageModel, LMHead
15
+
16
+
17
def dict2namespace(config):
    """Recursively turn a (possibly nested) dict into an ``argparse.Namespace``.

    Each key of ``config`` becomes an attribute of the returned namespace.
    Values that are themselves dicts are converted the same way; every other
    value (lists included) is stored unchanged.
    """
    namespace = argparse.Namespace()
    for key, value in config.items():
        converted = dict2namespace(value) if isinstance(value, dict) else value
        setattr(namespace, key, converted)
    return namespace
26
+
27
+
28
CONFIG = "configs/eval5d.yml"
LM_MODEL = "models/lm_instructir-7d.pt"
MODEL_NAME = "models/im_instructir-7d.pt"

# Parse the YAML config file and expose it with attribute access
# (cfg.model.width, cfg.llm.model, ...).
with open(CONFIG, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

cfg = dict2namespace(config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image restoration network, built from the `model` section of the config.
model = instructir.create_model(
    input_channels=cfg.model.in_ch,
    width=cfg.model.width,
    enc_blks=cfg.model.enc_blks,
    middle_blk_num=cfg.model.middle_blk_num,
    dec_blks=cfg.model.dec_blks,
    txtdim=cfg.model.textdim,
)
model = model.to(device)
print("IMAGE MODEL CKPT:", MODEL_NAME)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load(MODEL_NAME, map_location="cpu"), strict=True)
# This app only runs inference, so make sure no layer is left in training mode.
model.eval()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
LMODEL = cfg.llm.model
# Sentence encoder for the user prompt; it is not moved to `device` here,
# its output is moved inside process_img instead.
language_model = LanguageModel(model=LMODEL)
# Head that maps the sentence embedding to the text embedding consumed by the
# image model, plus a degradation-class prediction.
lm_head = LMHead(embedding_dim=cfg.llm.model_dim, hidden_dim=cfg.llm.embd_dim, num_classes=cfg.llm.nclasses)
lm_head = lm_head.to(device)

print("LMHEAD MODEL CKPT:", LM_MODEL)
# NOTE(review): same trusted-checkpoint caveat as above.
lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)
lm_head.eval()
53
+
54
+
55
def load_img(filename, norm=True):
    """Read an image file as an RGB numpy array.

    Args:
        filename: path (or file object) accepted by ``PIL.Image.open``.
        norm: when true, scale pixel values by 1/255 and return ``float32``;
            otherwise the array is returned as produced by
            ``np.array`` on the RGB-converted image.

    Returns:
        The image as a numpy array (height x width x 3 after the RGB conversion).
    """
    # The context manager closes the underlying file once the pixels have
    # been copied into the array, instead of leaving the handle open.
    with Image.open(filename) as im:
        img = np.array(im.convert("RGB"))
    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    return img
61
+
62
+
63
def process_img(image, prompt):
    """Restore ``image`` according to the free-text instruction ``prompt``.

    Args:
        image: PIL image coming from the Gradio input component
            (declared with ``type="pil"``).
        prompt: instruction text; it is encoded by ``language_model`` and
            projected by ``lm_head`` into the text embedding fed to ``model``.

    Returns:
        The restored image as an 8-bit RGB ``PIL.Image``.

    Raises:
        gr.Error: if no input image was provided.
    """
    if image is None:
        raise gr.Error("Please provide an input image.")

    # Uploaded files may be RGBA, grayscale or palette images; force three
    # channels so the array can be permuted as H x W x 3 below.
    img = np.array(image.convert("RGB"))
    img = (img / 255.).astype(np.float32)
    y = torch.tensor(img).permute(2, 0, 1).unsqueeze(0).to(device)

    lm_embd = language_model(prompt)
    lm_embd = lm_embd.to(device)

    with torch.no_grad():
        # The second output (degradation prediction) is not used by the demo.
        text_embd, _ = lm_head(lm_embd)
        x_hat = model(y, text_embd)

    # squeeze(0) removes only the batch axis; a bare squeeze() would also drop
    # a height or width of size 1 and break the permute.
    restored_img = x_hat.squeeze(0).permute(1, 2, 0).clamp_(0, 1).cpu().numpy()

    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 to uint8
    return Image.fromarray(restored_img)
81
+
82
+
83
+
84
# Text shown by the Gradio interface: page title, markdown header and footer.
title = "InstructIR ✏️🖼️ 🤗"
description = ''' ## [High-Quality Image Restoration Following Human Instructions](https://github.com/mv-lab/InstructIR)

[Marcos V. Conde](https://scholar.google.com/citations?user=NtB1kjYAAAAJ&hl=en), [Gregor Geigle](https://scholar.google.com/citations?user=uIlyqRwAAAAJ&hl=en), [Radu Timofte](https://scholar.google.com/citations?user=u3MwH5kAAAAJ&hl=en)

Computer Vision Lab, University of Wuerzburg | Sony PlayStation, FTG

### TL;DR: quickstart
InstructIR takes as input an image and a human-written instruction for how to improve that image. The neural model performs all-in-one image restoration. InstructIR achieves state-of-the-art results on several restoration tasks including image denoising, deraining, deblurring, dehazing, and (low-light) image enhancement.

**🚀 You can start with the [demo tutorial](https://github.com/mv-lab/InstructIR/blob/main/demo.ipynb)**

<details>
<summary> <b> Abstract</b> (click me to read)</summary>
<p>
Image restoration is a fundamental problem that involves recovering a high-quality clean image from its degraded observation. All-In-One image restoration models can effectively restore images from various types and levels of degradation using degradation-specific information as prompts to guide the restoration model. In this work, we present the first approach that uses human-written instructions to guide the image restoration model. Given natural language prompts, our model can recover high-quality images from their degraded counterparts, considering multiple degradation types. Our method, InstructIR, achieves state-of-the-art results on several restoration tasks including image denoising, deraining, deblurring, dehazing, and (low-light) image enhancement. InstructIR improves +1dB over previous all-in-one restoration methods. Moreover, our dataset and results represent a novel benchmark for new research on text-guided image restoration and enhancement.
</p>
</details>

> Disclaimer: please remember this is not a product, thus, you will notice some limitations.

**This demo expects an image with some degradations (blur, noise, rain, low-light, haze) and a prompt requesting what should be done.**
Due to the GPU memory limitations, the app might crash if you feed a high-resolution image (2K, 4K).

<br>
'''

article = "<p style='text-align: center'><a href='https://github.com/mv-lab/InstructIR' target='_blank'>High-Quality Image Restoration Following Human Instructions</a></p>"

# (image path, prompt) pairs offered as clickable examples in the UI.
examples = [['images/rain-020.png', "I love this photo, could you remove the raindrops? please keep the content intact"],
            ['images/gradio_demo_images/city.jpg', "I took this photo during a foggy day, can you improve it?"],
            ['images/gradio_demo_images/frog.png', "can you remove the tiny dots in the image? it is very unpleasant"],
            ["images/lol_748.png", "my image is too dark, I cannot see anything, can you fix it?"],
            ["images/gopro.png", "I took this photo while I was running, can you stabilize the image? it is too blurry"],
            ["images/a0010.jpg", "please I want this image for my photo album, can you edit it as a photographer"]]

# Show images at their native size instead of letting the frame rescale them.
css = """
    .image-frame img, .image-container img {
        width: auto;
        height: auto;
        max-width: none;
    }
"""

# Single-function interface: (image, prompt) -> restored image.
demo = gr.Interface(
    fn=process_img,
    inputs=[
        gr.Image(type="pil", label="Input"),
        gr.Text(label="Prompt"),
    ],
    outputs=[gr.Image(type="pil", label="Output")],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)

if __name__ == "__main__":
    demo.launch()
145
+
146
+ # with gr.Blocks() as demo:
147
+ # with gr.Row(equal_height=True):
148
+ # with gr.Column(scale=1):
149
+ # input = gr.Image(type="pil", label="Input")
150
+ # with gr.Column(scale=1):
151
+ # prompt = gr.Text(label="Prompt")
152
+ # process_btn = gr.Button("Process")
153
+ # with gr.Row(equal_height=True):
154
+ # output = gr.Image(type="pil", label="Ouput")
155
+ # slider = ImageSlider(position=0.5, type="pil", label="SideBySide")
156
+ # process_btn.click(fn=process_img, inputs=[input, prompt], outputs=[output, slider])
157
+ # demo.launch(share=True)
configs/eval5d.yml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ llm:
2
+ model: 'TaylorAI/bge-micro-v2' # See Paper Sec. 3.2 and Appendix
3
+ model_dim: 384
4
+ embd_dim: 256
5
+ nclasses: 7 # noise, blur, rain, haze, lol, enhancement, upsampling (Paper Sec. 4.3)
6
+ weights: False
7
+
8
+ model:
9
+ arch: "instructir"
10
+ use_text: True
11
+ in_ch: 3
12
+ out_ch: 3
13
+ width : 32
14
+ enc_blks: [2, 2, 4, 8]
15
+ middle_blk_num: 4
16
+ dec_blks: [2, 2, 2, 2]
17
+ textdim: 256
18
+ weights: False
19
+
20
+ test:
21
+ batch_size: 1
22
+ num_workers: 3
23
+
24
+ dn_datapath: "data/denoising_testsets/"
25
+ dn_datasets: ["CBSD68", "urban100", "Kodak24", "McMaster"]
26
+ dn_sigmas: [15, 25, 50]
27
+
28
+ rain_targets: ["data/Rain/rain_test/Rain100L/target/"]
29
+ rain_inputs: ["data/Rain/rain_test/Rain100L/input/"]
30
+
31
+ haze_targets: "data/SOTS-OUT/GT/"
32
+ haze_inputs : "data/SOTS-OUT/IN/"
33
+
34
+ lol_targets: "data/LOL/eval15/high/"
35
+ lol_inputs : "data/LOL/eval15/low/"
36
+
37
+ gopro_targets: "data/gopro_test/GoPro/target/"
38
+ gopro_inputs: "data/gopro_test/GoPro/input/"
39
+
40
+
images/a0010.jpg ADDED
images/frog.png ADDED
images/gopro.png ADDED
images/gradio_demo_images/bear.png ADDED
images/gradio_demo_images/city.jpg ADDED
images/gradio_demo_images/frog.png ADDED
images/lol_1.png ADDED
images/lol_748.png ADDED
images/noise50.png ADDED
images/rain-020.png ADDED
models/instructir.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import init as init
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
+ from models.nafnet_utils import Local_Base, LayerNorm2d
9
+ from models.nafnet import SimpleGate, NAFBlock
10
+
11
+
12
class ICB(nn.Module):
    """
    Instruction Condition Block (ICB)
    Paper Section 3.3

    Modulates a feature map with a text embedding: a linear layer followed by
    a sigmoid produces one gating factor per feature channel, the gated
    features pass through a NAFBlock, and the result is added back to the
    input (residual connection).
    """

    def __init__(self, feature_dim, text_dim=768):
        super().__init__()
        self.fc = nn.Linear(text_dim, feature_dim)
        self.block = NAFBlock(feature_dim)
        # Per-channel shift and scale, both initialised to zero.
        per_channel = (1, feature_dim, 1, 1)
        self.beta = nn.Parameter(torch.zeros(per_channel), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros(per_channel), requires_grad=True)

    def forward(self, x, text_embedding):
        # Gating factors in (0, 1); two trailing singleton axes are appended
        # so they broadcast over the spatial dimensions of x.
        gate = torch.sigmoid(self.fc(text_embedding))[..., None, None]

        modulated = x * self.gamma + self.beta   # 1) learned feature scaling/modulation
        routed = modulated * gate                # 2) (soft) feature routing based on text
        enhanced = self.block(routed)            # 3) block feature enhancement
        return enhanced + x
33
+
34
+
35
class InstructIR(nn.Module):
    """
    InstructIR model using NAFNet (ECCV 2022) as backbone.
    The model takes as input an RGB image and a text embedding (encoded instruction).
    Described in Paper Section 3.3

    Args:
        img_channel: number of channels of the input (and output) image.
        width: number of feature channels after the first convolution; it is
            doubled at every encoder stage.
        middle_blk_num: number of NAFBlocks in the bottleneck.
        enc_blk_nums: number of NAFBlocks per encoder stage (one entry per stage).
        dec_blk_nums: number of NAFBlocks per decoder stage (one entry per stage).
        txtdim: size of the text embedding passed to every ICB.
    """

    # Immutable tuples are used as defaults so no list object is shared
    # between calls; any iterable of ints is still accepted.
    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=(), dec_blk_nums=(), txtdim=768):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Placeholder, replaced by an nn.Sequential once the bottleneck width
        # is known; kept so the attribute is declared in the original order.
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.enc_cond = nn.ModuleList()
        self.dec_cond = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

            # Text conditioning applied after this encoder stage.
            self.enc_cond.append(ICB(chan, txtdim))

            # Stride-2 convolution: halves the resolution, doubles the channels.
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan) for _ in range(middle_blk_num)]
        )

        for num in dec_blk_nums:
            # 1x1 conv to 2*chan followed by PixelShuffle(2): doubles the
            # resolution and leaves chan // 2 channels.
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            # Add text embedding as modulation
            self.dec_cond.append(ICB(chan, txtdim))

        # Spatial sizes must be divisible by this so every stride-2
        # downsampling divides evenly (see check_image_size).
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp, txtembd):
        """Restore ``inp`` conditioned on ``txtembd``.

        The input is padded to a multiple of ``padder_size``, processed by the
        encoder / bottleneck / decoder with skip connections, added back to
        the padded input, and cropped to the original height and width.
        """
        _, _, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)
        encs = []

        for encoder, enc_mod, down in zip(self.encoders, self.enc_cond, self.downs):
            x = encoder(x)
            x = enc_mod(x, txtembd)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # Skip connections are consumed from the deepest stage outwards.
        for decoder, up, enc_skip, dec_mod in zip(self.decoders, self.ups, encs[::-1], self.dec_cond):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)
            x = dec_mod(x, txtembd)

        x = self.ending(x)
        x = x + inp  # global residual on the (padded) input

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad ``x`` on the right/bottom to a multiple of ``padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
127
+
128
+
129
def create_model(input_channels=3, width=32, enc_blks=(2, 2, 4, 8), middle_blk_num=12, dec_blks=(2, 2, 2, 2), txtdim=768):
    """Build an ``InstructIR`` network.

    Args:
        input_channels: number of image channels.
        width: base number of feature channels.
        enc_blks: NAFBlocks per encoder stage.
        middle_blk_num: NAFBlocks in the bottleneck.
        dec_blks: NAFBlocks per decoder stage.
        txtdim: size of the text embedding.

    Returns:
        The constructed (untrained) ``InstructIR`` module.
    """
    # Defaults are tuples rather than lists so they cannot be mutated and
    # shared between calls; lists are still accepted.
    net = InstructIR(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                     enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, txtdim=txtdim)

    return net
models/nafnet.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2022 megvii-model. All Rights Reserved.
3
+ # ------------------------------------------------------------------------
4
+ # Source: https://github.com/megvii-research/NAFNet
5
+
6
+ '''
7
+ Simple Baselines for Image Restoration
8
+
9
+ @article{chen2022simple,
10
+ title={Simple Baselines for Image Restoration},
11
+ author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
12
+ journal={arXiv preprint arXiv:2204.04676},
13
+ year={2022}
14
+ }
15
+ '''
16
+
17
+ import math
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn import init as init
22
+ from torch.nn.modules.batchnorm import _BatchNorm
23
+ from models.nafnet_utils import Local_Base, LayerNorm2d
24
+
25
+
26
class SimpleGate(nn.Module):
    """Split the input in two halves along dim 1 and multiply them element-wise."""

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
30
+
31
class NAFBlock(nn.Module):
    """NAFNet building block.

    Two residual branches applied in sequence:
      1. LayerNorm -> 1x1 conv -> 3x3 depth-wise conv -> SimpleGate ->
         simplified channel attention -> 1x1 conv, scaled by ``beta``;
      2. LayerNorm -> 1x1 conv -> SimpleGate -> 1x1 conv, scaled by ``gamma``.

    Args:
        c: number of input/output channels.
        DW_Expand: channel expansion factor of the first branch.
        FFN_Expand: channel expansion factor of the second branch.
        drop_out_rate: dropout probability; ``nn.Identity`` is used when it is 0.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        gated_dw = dw_channel // 2  # SimpleGate halves the channel count
        self.conv1 = nn.Conv2d(c, dw_channel, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
        # groups == channels: depth-wise 3x3 convolution
        self.conv2 = nn.Conv2d(dw_channel, dw_channel, kernel_size=3, stride=1, padding=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(gated_dw, c, kernel_size=1, stride=1, padding=0, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(gated_dw, gated_dw, kernel_size=1, stride=1, padding=0, groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(c, ffn_channel, kernel_size=1, stride=1, padding=0, groups=1, bias=True)
        self.conv5 = nn.Conv2d(ffn_channel // 2, c, kernel_size=1, stride=1, padding=0, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        use_dropout = drop_out_rate > 0.
        self.dropout1 = nn.Dropout(drop_out_rate) if use_dropout else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if use_dropout else nn.Identity()

        # Per-channel residual scales, initialised to zero.
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # First branch: gated depth-wise convolution with channel attention.
        feat = self.conv2(self.conv1(self.norm1(inp)))
        feat = self.sg(feat)
        feat = self.conv3(feat * self.sca(feat))
        y = inp + self.dropout1(feat) * self.beta

        # Second branch: gated point-wise feed-forward.
        feat = self.conv5(self.sg(self.conv4(self.norm2(y))))
        return y + self.dropout2(feat) * self.gamma
85
+
86
+
87
class NAFNet(nn.Module):
    """Encoder/decoder image restoration network built from NAFBlocks.

    Args:
        img_channel: number of channels of the input (and output) image.
        width: number of feature channels after the first convolution; it is
            doubled at every encoder stage.
        middle_blk_num: number of NAFBlocks in the bottleneck.
        enc_blk_nums: number of NAFBlocks per encoder stage.
        dec_blk_nums: number of NAFBlocks per decoder stage.
    """

    # Immutable tuples are used as defaults so no list object is shared
    # between calls; any iterable of ints is still accepted.
    def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=(), dec_blk_nums=()):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
                               bias=True)
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Placeholder, replaced by an nn.Sequential once the bottleneck width
        # is known; kept so the attribute is declared in the original order.
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )
            # Stride-2 convolution: halves the resolution, doubles the channels.
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[NAFBlock(chan) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            # 1x1 conv to 2*chan followed by PixelShuffle(2): doubles the
            # resolution and leaves chan // 2 channels.
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan) for _ in range(num)]
                )
            )

        # Spatial sizes must be divisible by this so every stride-2
        # downsampling divides evenly (see check_image_size).
        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        """Restore ``inp``; the output has the same height and width as the input."""
        _, _, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # Skip connections are consumed from the deepest stage outwards.
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp  # global residual on the (padded) input

        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Zero-pad ``x`` on the right/bottom to a multiple of ``padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
167
+
168
class NAFNetLocal(Local_Base, NAFNet):
    """NAFNet whose pooling layers are converted through ``Local_Base.convert``.

    Args:
        train_size: 4-element size ``(N, C, H, W)`` forwarded to ``convert``.
        fast_imp: forwarded to ``convert``.
        *args, **kwargs: forwarded to ``NAFNet.__init__``.
    """

    def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        NAFNet.__init__(self, *args, **kwargs)

        # Only the spatial part of train_size is needed here; the base size
        # is 1.5x the training resolution.
        _, _, train_h, train_w = train_size
        base_size = (int(train_h * 1.5), int(train_w * 1.5))

        # Conversion runs a forward pass, so do it in eval mode without grads.
        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
179
+
180
+
181
def create_nafnet(input_channels=3, width=32, enc_blks=(2, 2, 4, 8), middle_blk_num=12, dec_blks=(2, 2, 2, 2)):
    """
    Create Nafnet model
    https://github.com/megvii-research/NAFNet/blob/main/options/test/SIDD/NAFNet-width32.yml

    Args:
        input_channels: number of image channels.
        width: base number of feature channels.
        enc_blks: NAFBlocks per encoder stage.
        middle_blk_num: NAFBlocks in the bottleneck.
        dec_blks: NAFBlocks per decoder stage.

    Returns:
        The constructed (untrained) ``NAFNet`` module.
    """
    # Defaults are tuples rather than lists so they cannot be mutated and
    # shared between calls; lists are still accepted.
    net = NAFNet(img_channel=input_channels, width=width, middle_blk_num=middle_blk_num,
                 enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)

    return net
models/nafnet_utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------
2
+ # Copyright (c) 2022 megvii-model. All Rights Reserved.
3
+ # ------------------------------------------------------------------------
4
+ # Source: https://github.com/megvii-research/NAFNet
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import math
11
+
12
class LayerNormFunction(torch.autograd.Function):
    """Layer normalisation over the channel axis (dim 1) of a 4-D tensor,
    with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalise ``x`` over dim 1, then apply the per-channel affine.

        Args:
            x: 4-D input; its size is unpacked as ``N, C, H, W``.
            weight, bias: per-channel scale and shift, each viewed as (1, C, 1, 1).
            eps: value added to the variance before the square root.
        """
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # Save the normalised activations (before the affine) for backward.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x, weight, bias, eps)``; ``eps`` gets ``None``."""
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        # `saved_tensors` is the supported accessor; `saved_variables` is deprecated.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        # weight/bias gradients are reduced over every axis except channels.
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        return gx, grad_weight, grad_bias, None
38
+
39
class LayerNorm2d(nn.Module):
    """Module wrapper around ``LayerNormFunction`` with learnable per-channel
    ``weight`` (initialised to ones) and ``bias`` (initialised to zeros)."""

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        # Assigning an nn.Parameter attribute registers it under that name.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
49
+
50
+
51
+
52
class AvgPool2d(nn.Module):
    """Average pooling with a window derived from ``base_size``.

    On the first forward pass (when ``kernel_size`` is ``None`` and
    ``base_size`` is set) the kernel is computed as
    ``input_size * base_size // train_size`` per spatial axis and cached on
    the module. If the kernel covers the whole input, global average pooling
    is returned; otherwise a stride-1 box average is computed from cumulative
    sums and, when ``auto_pad`` is true, replicate-padded back to the input's
    spatial size.
    """

    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return (
            f'kernel_size={self.kernel_size}, base_size={self.base_size}, '
            f'stride={self.kernel_size}, fast_imp={self.fast_imp}'
        )

    def forward(self, x):
        # Lazily derive (and cache) the kernel size from the first input seen.
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        # Window at least as large as the input: plain global average.
        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                # Largest candidate stride that divides each spatial size
                # (1 is always a candidate, so a match always exists).
                r1 = next(r for r in self.rs if h % r == 0)
                r2 = next(r for r in self.rs if w % r == 0)
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1 = min(h - 1, self.kernel_size[0] // r1)
                k2 = min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            # Summed-area table: cumulative sum along both spatial axes.
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1 = min(h, self.kernel_size[0])
            k2 = min(w, self.kernel_size[1])
            top_left = s[:, :, :-k1, :-k2]
            top_right = s[:, :, :-k1, k2:]
            bottom_left = s[:, :, k1:, :-k2]
            bottom_right = s[:, :, k1:, k2:]
            # Window sum from the four corners of the summed-area table.
            out = bottom_right + top_left - top_right - bottom_left
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # Split the missing rows/columns between both sides (extra one
            # goes to the right/bottom) and replicate the border values.
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out
119
+
120
def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively swap every ``nn.AdaptiveAvgPool2d`` child for an ``AvgPool2d``.

    The replacement is done in place on ``model`` and its sub-modules; the
    function returns ``None``. Each replaced layer must have ``output_size == 1``.
    """
    for child_name, child in model.named_children():
        if len(list(child.children())) > 0:
            ## compound module, go inside it
            replace_layers(child, base_size, train_size, fast_imp, **kwargs)

        if not isinstance(child, nn.AdaptiveAvgPool2d):
            continue

        pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
        assert child.output_size == 1
        setattr(model, child_name, pool)
131
+
132
+ '''
133
+ ref.
134
+ @article{chu2021tlsc,
135
+ title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
136
+ author={Chu, Xiaojie and Chen, Liangyu and Chen, Chengpeng and Lu, Xin},
137
+ journal={arXiv preprint arXiv:2112.04491},
138
+ year={2021}
139
+ }
140
+ '''
141
class Local_Base():
    """Mixin providing ``convert``; expects the host class to define ``forward``."""

    def convert(self, *args, train_size, **kwargs):
        """Replace pooling layers via ``replace_layers`` and run one forward pass.

        ``train_size`` is both forwarded to ``replace_layers`` and used as the
        shape of a random tensor pushed through ``self.forward`` without
        gradients.
        """
        replace_layers(self, *args, train_size=train_size, **kwargs)
        dummy_batch = torch.rand(train_size)
        with torch.no_grad():
            self.forward(dummy_batch)
requirements_gradio.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ Pillow>=6.2.2
4
+ sentence-transformers==2.3.0
5
+ gradio==4.16.0
6
+ #gradio_imageslider==0.0.18
text/models.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from transformers import DistilBertModel, DistilBertTokenizer, AutoModel, AutoTokenizer
5
+ import os
6
+
7
# Models that use mean pooling
POOL_MODELS = {"sentence-transformers/all-MiniLM-L6-v2", "TaylorAI/bge-micro-v2"}


def mean_pooling(model_output, attention_mask):
    """Average token embeddings over dim 1, weighting each token by the attention mask.

    Args:
        model_output: sequence whose first element holds the token embeddings.
        attention_mask: mask that is broadcast over the embedding axis;
            tokens with mask 0 do not contribute to the average.

    Returns:
        The masked sum of embeddings divided by the mask count (the count is
        clamped to at least 1e-9 to avoid division by zero).
    """
    token_embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = (token_embeddings * mask).sum(dim=1)
    token_count = mask.sum(dim=1).clamp(min=1e-9)
    return masked_sum / token_count
15
+
16
+
17
class LanguageModel(nn.Module):
    """Frozen pre-trained text encoder that maps prompts to sentence embeddings.

    Args:
        model: Hugging Face model name passed to ``AutoTokenizer`` and
            ``AutoModel``.
    """

    def __init__(self, model='distilbert-base-uncased'):
        super().__init__()

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModel.from_pretrained(model)
        self.model_name = model
        # Remove the CLIP vision tower
        if "clip" in self.model_name:
            self.model.vision_model = None
        # Freeze the pre-trained parameters (very important)
        for param in self.model.parameters():
            param.requires_grad = False

        # Make sure to set evaluation mode (also important)
        self.model.eval()

    def forward(self, text_batch):
        """Encode ``text_batch`` (a string or list of strings) into embeddings.

        Depending on the model name the embedding is the CLIP text feature,
        the L2-normalised mean-pooled token embedding (names in
        ``POOL_MODELS``), or the hidden state of the first token.
        """
        inputs = self.tokenizer(text_batch, padding=True, truncation=True, return_tensors="pt")
        # The tokenizer returns CPU tensors; move them to wherever the encoder
        # lives so the module also works after `.to("cuda")`.
        inputs = inputs.to(next(self.model.parameters()).device)

        with torch.no_grad():  # Ensure no gradients are computed for this forward pass

            if "clip" in self.model_name:
                sentence_embedding = self.model.get_text_features(**inputs)
                return sentence_embedding

            outputs = self.model(**inputs)

        if any(name in self.model_name for name in POOL_MODELS):
            sentence_embeddings = mean_pooling(outputs, inputs['attention_mask'])
            # Normalize embeddings
            sentence_embedding = F.normalize(sentence_embeddings, p=2, dim=1)
        else:
            sentence_embedding = outputs.last_hidden_state[:, 0, :]
        return sentence_embedding
51
+
52
+
53
class LMHead(nn.Module):
    """Two-layer head on top of a sentence embedding.

    ``fc1`` projects the embedding to ``hidden_dim``; the projection is
    L2-normalised along dim 1 and then classified by ``fc2`` into
    ``num_classes`` logits.
    """

    def __init__(self, embedding_dim=384, hidden_dim=256, num_classes=4):
        super().__init__()

        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return ``(normalised projection, class logits)`` for input ``x``."""
        embd = F.normalize(self.fc1(x), p=2, dim=1)
        return embd, self.fc2(embd)
text/sample_prompts.json ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "denoising": [
3
+ "Help me reduce the fuzziness in this image.",
4
+ "I need this image denoised ASAP.",
5
+ "Clean up this noisy image, it's an eyesore.",
6
+ "Can you clean the dots from my image?",
7
+ "Help me with my picture, it's full of tiny spots.",
8
+ "Clean up this image, it's all grainy."
9
+ ],
10
+ "deblurring": [
11
+ "Please, clean up this blurry photo.",
12
+ "My picture's not sharp, fix it.",
13
+ "Deblur my picture, it's too fuzzy.",
14
+ "Help, my photo is too blurry.",
15
+ "Please, make my image less smudgy."
16
+ ],
17
+ "dehazing": [
18
+ "Please, fix the haziness in my image.",
19
+ "I need to remove the haziness from this image.",
20
+ "Get rid of the fog in my image.",
21
+ "Fix my photo, it's too misty.",
22
+ "Help me, my photo is all hazy."
23
+ ],
24
+ "deraining": [
25
+ "I want to eliminate the water from this image.",
26
+ "Clear the rain from my picture.",
27
+ "I need to clear the rain from this image.",
28
+ "Can you get rid of the raindrops in my picture?"
29
+ ],
30
+ "sr": [
31
+ "I need to enhance the size and quality of this image.",
32
+ "My photo is lacking size and clarity; can you improve it?",
33
+ "I'd appreciate it if you could upscale this photo.",
34
+ "My picture is too little, enlarge it."
35
+ ],
36
+ "ambiguous": [
37
+ "Please, clear up the mess on this image.",
38
+ "I want this image to look good.",
39
+ "make it pop",
40
+ "Fix my photo, it's all messed up."
41
+ ],
42
+ "lol": [
43
+ "I took this photo during night, enhance it",
44
+ "The photo is too dark, improve exposure",
45
+ "my image has poor lighting conditions, can you fix it?",
46
+ "Can you make the image brighter?"
47
+ ],
48
+ "enhancement": [
49
+ "make my image look like DSLR",
50
+ "improve the colors of my image",
51
+ "enhance the colors of the image",
52
+ "Can you edit this to look like an award-winning photo?",
53
+ "I want the picture to be retouched for a professional portfolio."
54
+ ]
55
+ }