Esmail-AGumaan committed on
Commit 898fdaa
Parent: 8182f5b

Update decoder.py

Files changed (1)
  1. decoder.py (+99, -99)
decoder.py CHANGED
@@ -1,100 +1,100 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from nanograd.models.stable_diffusion.attention import SelfAttention
+from attention import SelfAttention

 class VAE_AttentionBlock(nn.Module):
     def __init__(self, channels):
         super().__init__()
         self.groupnorm = nn.GroupNorm(32, channels)
         self.attention = SelfAttention(1, channels)

     def forward(self, x):
         residue = x
         x = self.groupnorm(x)
         n, c, h, w = x.shape
         x = x.view((n, c, h * w))
         x = x.transpose(-1, -2)
         x = self.attention(x)
         x = x.transpose(-1, -2)
         x = x.view((n, c, h, w))
         x += residue

         return x

 class VAE_ResidualBlock(nn.Module):
     def __init__(self, in_channels, out_channels):
         super().__init__()
         self.groupnorm_1 = nn.GroupNorm(32, in_channels)
         self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

         self.groupnorm_2 = nn.GroupNorm(32, out_channels)
         self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

         if in_channels == out_channels:
             self.residual_layer = nn.Identity()
         else:
             self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

     def forward(self, x):
         residue = x
         x = self.groupnorm_1(x)
         x = F.silu(x)
         x = self.conv_1(x)
         x = self.groupnorm_2(x)
         x = F.silu(x)
         x = self.conv_2(x)

         return x + self.residual_layer(residue)

 class VAE_Decoder(nn.Sequential):
     def __init__(self):
         super().__init__(
             nn.Conv2d(4, 4, kernel_size=1, padding=0),
             nn.Conv2d(4, 512, kernel_size=3, padding=1),
             VAE_ResidualBlock(512, 512),
             VAE_AttentionBlock(512),
             VAE_ResidualBlock(512, 512),
             VAE_ResidualBlock(512, 512),
             VAE_ResidualBlock(512, 512),
             VAE_ResidualBlock(512, 512),

             # (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
             nn.Upsample(scale_factor=2),

             nn.Conv2d(512, 512, kernel_size=3, padding=1),

             VAE_ResidualBlock(512, 512),
             VAE_ResidualBlock(512, 512),
             VAE_ResidualBlock(512, 512),

             # (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
             nn.Upsample(scale_factor=2),

             nn.Conv2d(512, 512, kernel_size=3, padding=1),

             VAE_ResidualBlock(512, 256),
             VAE_ResidualBlock(256, 256),
             VAE_ResidualBlock(256, 256),

             nn.Upsample(scale_factor=2),

             nn.Conv2d(256, 256, kernel_size=3, padding=1),

             VAE_ResidualBlock(256, 128),
             VAE_ResidualBlock(128, 128),
             VAE_ResidualBlock(128, 128),

             nn.GroupNorm(32, 128),

             nn.SiLU(),

             nn.Conv2d(128, 3, kernel_size=3, padding=1),
         )

     def forward(self, x):
         x /= 0.18215

         for module in self:
             x = module(x)
         return x
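
For reference, a minimal usage sketch of the decoder defined above, assuming decoder.py and an attention.py providing SelfAttention are both importable from the working directory (which is what the updated import suggests), and using a dummy latent of shape (1, 4, 64, 64). With the three 2x upsampling stages, the output image comes out at (1, 3, 512, 512). This is only an illustration; it is not part of the commit and loads no pretrained weights.

import torch
from decoder import VAE_Decoder

# Build the decoder with randomly initialized weights (no checkpoint is loaded here).
decoder = VAE_Decoder()
decoder.eval()

# Dummy latent: (Batch_Size, 4, Height / 8, Width / 8)
latent = torch.randn(1, 4, 64, 64)

with torch.no_grad():
    image = decoder(latent)

print(image.shape)  # torch.Size([1, 3, 512, 512])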