csukuangfj committed on
Commit 5427eec · 1 Parent(s): 2b26c31

add export script

Files changed (5)
  1. __init__.py +0 -0
  2. convert_to_pb.py +91 -0
  3. convert_to_torch.py +240 -0
  4. run.sh +30 -0
  5. unet.py +150 -0
__init__.py ADDED
File without changes
convert_to_pb.py ADDED
@@ -0,0 +1,91 @@
+ #!/usr/bin/env python3
+ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+ # Please see ./run.sh for usage
+ import argparse
+ import os
+
+ import tensorflow as tf
+
+
+ # Code in the following function is modified from
+ # https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
+ def freeze_graph(model_dir, output_node_names, output_filename):
+     """Extract the subgraph defined by the output nodes and convert all its
+     variables into constants.
+
+     Args:
+       model_dir:
+         The root folder containing the checkpoint state file.
+       output_node_names:
+         A string containing all the output node names, comma separated.
+       output_filename:
+         Filename to save the frozen graph to.
+     """
+     if not tf.compat.v1.gfile.Exists(model_dir):
+         raise AssertionError(
+             "Export directory doesn't exist. Please specify an export "
+             "directory: %s" % model_dir
+         )
+
+     if not output_node_names:
+         print("You need to supply the name of a node to --output-node-names.")
+         return -1
+
+     # We retrieve our checkpoint fullpath
+     checkpoint = tf.train.get_checkpoint_state(model_dir)
+     input_checkpoint = checkpoint.model_checkpoint_path
+
+     output_graph = output_filename
+
+     # We clear devices to allow TensorFlow to control on which device it
+     # will load operations
+     clear_devices = True
+
+     # We start a session using a temporary fresh Graph
+     with tf.compat.v1.Session(graph=tf.Graph()) as sess:
+         # We import the meta graph in the current default Graph
+         saver = tf.compat.v1.train.import_meta_graph(
+             input_checkpoint + ".meta", clear_devices=clear_devices
+         )
+
+         # We restore the weights
+         saver.restore(sess, input_checkpoint)
+
+         # We use a built-in TF helper to export variables to constants
+         output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
+             sess,  # The session is used to retrieve the weights
+             # The graph_def is used to retrieve the nodes
+             tf.compat.v1.get_default_graph().as_graph_def(),
+             # The output node names are used to select the useful nodes
+             output_node_names.split(","),
+         )
+
+         # Finally we serialize and dump the output graph to the filesystem
+         with tf.compat.v1.gfile.GFile(output_graph, "wb") as f:
+             f.write(output_graph_def.SerializeToString())
+         print("%d ops in the final graph." % len(output_graph_def.node))
+
+     return output_graph_def
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model-dir", type=str, default="", help="Model folder to export"
+     )
+     parser.add_argument(
+         "--output-node-names",
+         type=str,
+         default="vocals_spectrogram/mul,accompaniment_spectrogram/mul",
+         help="The names of the output nodes, comma separated.",
+     )
+
+     parser.add_argument(
+         "--output-filename",
+         type=str,
+         help="Filename to save the frozen graph to.",
+     )
+     args = parser.parse_args()
+
+     freeze_graph(args.model_dir, args.output_node_names, args.output_filename)
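
If you are unsure which values to pass to --output-node-names, you can list every node in a frozen graph and pick the ones you need. Below is a minimal sketch (not part of this commit); it assumes the file ./2stems/frozen_vocals_model.pb produced by ./run.sh:

    import tensorflow as tf

    # Parse the frozen graph and print each node's name and op type.
    graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.gfile.GFile("./2stems/frozen_vocals_model.pb", "rb") as f:
        graph_def.ParseFromString(f.read())

    for node in graph_def.node:
        print(node.name, node.op)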
convert_to_torch.py ADDED
@@ -0,0 +1,240 @@
+ #!/usr/bin/env python3
+ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+ # Please see ./run.sh for usage
+
+ import argparse
+
+ import numpy as np
+ import tensorflow as tf
+ import torch
+ import torch.nn as nn
+ from unet import UNet
+
+
+ def load_graph(frozen_graph_filename):
+     # This function is modified from
+     # https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc
+
+     # We load the protobuf file from the disk and parse it to retrieve the
+     # unserialized graph_def
+     with tf.compat.v1.gfile.GFile(frozen_graph_filename, "rb") as f:
+         graph_def = tf.compat.v1.GraphDef()
+         graph_def.ParseFromString(f.read())
+
+     # Then, we import the graph_def into a new Graph and return it
+     with tf.Graph().as_default() as graph:
+         # The name argument would prefix every op/node in the graph.
+         # Since we load everything into a new graph, no prefix is needed.
+         tf.import_graph_def(graph_def, name="")
+     return graph
+
+
+ def generate_waveform():
+     np.random.seed(20230821)
+     waveform = np.random.rand(60 * 44100).astype(np.float32)
+
+     # (num_samples, num_channels)
+     waveform = waveform.reshape(-1, 2)
+     return waveform
+
+
+ def get_param(graph, name):
+     with tf.compat.v1.Session(graph=graph) as sess:
+         constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
+         for constant_op in constant_ops:
+             if constant_op.name != name:
+                 continue
+
+             value = sess.run(constant_op.outputs[0])
+             return torch.from_numpy(value)
+
+
+ @torch.no_grad()
+ def main(name):
+     graph = load_graph(f"./2stems/frozen_{name}_model.pb")
+     # for op in graph.get_operations():
+     #     print(op.name)
+     x = graph.get_tensor_by_name("waveform:0")
+     # y = graph.get_tensor_by_name("Reshape:0")
+     y0 = graph.get_tensor_by_name("strided_slice_3:0")
+     # Intermediate tensors that were used while debugging the conversion:
+     # y1 = graph.get_tensor_by_name("leaky_re_lu_5/LeakyRelu:0")
+     # y1 = graph.get_tensor_by_name("conv2d_5/BiasAdd:0")
+     # y1 = graph.get_tensor_by_name("conv2d_transpose/BiasAdd:0")
+     # y1 = graph.get_tensor_by_name("re_lu/Relu:0")
+     # y1 = graph.get_tensor_by_name("batch_normalization_6/cond/FusedBatchNorm_1:0")
+     # y1 = graph.get_tensor_by_name("concatenate/concat:0")
+     # y1 = graph.get_tensor_by_name("concatenate_1/concat:0")
+     # y1 = graph.get_tensor_by_name("concatenate_4/concat:0")
+     # y1 = graph.get_tensor_by_name("batch_normalization_11/cond/FusedBatchNorm_1:0")
+     # y1 = graph.get_tensor_by_name("conv2d_6/Sigmoid:0")
+     y1 = graph.get_tensor_by_name(f"{name}_spectrogram/mul:0")
+
+     unet = UNet()
+     unet.eval()
+
+     # For conv2d in TensorFlow, the weight shape is
+     # (kernel_h, kernel_w, in_channel, out_channel) and the default input
+     # layout is NHWC.
+     #
+     # For conv2d in torch, the weight shape is
+     # (out_channel, in_channel, kernel_h, kernel_w) and the default input
+     # layout is NCHW.
+     state_dict = unet.state_dict()
+     # print(list(state_dict.keys()))
+
+     if name == "vocals":
+         state_dict["conv.weight"] = get_param(graph, "conv2d/kernel").permute(
+             3, 2, 0, 1
+         )
+         state_dict["conv.bias"] = get_param(graph, "conv2d/bias")
+
+         state_dict["bn.weight"] = get_param(graph, "batch_normalization/gamma")
+         state_dict["bn.bias"] = get_param(graph, "batch_normalization/beta")
+         state_dict["bn.running_mean"] = get_param(
+             graph, "batch_normalization/moving_mean"
+         )
+         state_dict["bn.running_var"] = get_param(
+             graph, "batch_normalization/moving_variance"
+         )
+
+         conv_offset = 0
+         bn_offset = 0
+     else:
+         state_dict["conv.weight"] = get_param(graph, "conv2d_7/kernel").permute(
+             3, 2, 0, 1
+         )
+         state_dict["conv.bias"] = get_param(graph, "conv2d_7/bias")
+
+         state_dict["bn.weight"] = get_param(graph, "batch_normalization_12/gamma")
+         state_dict["bn.bias"] = get_param(graph, "batch_normalization_12/beta")
+         state_dict["bn.running_mean"] = get_param(
+             graph, "batch_normalization_12/moving_mean"
+         )
+         state_dict["bn.running_var"] = get_param(
+             graph, "batch_normalization_12/moving_variance"
+         )
+         conv_offset = 7
+         bn_offset = 12
+
+     for i in range(1, 6):
+         state_dict[f"conv{i}.weight"] = get_param(
+             graph, f"conv2d_{i+conv_offset}/kernel"
+         ).permute(3, 2, 0, 1)
+         state_dict[f"conv{i}.bias"] = get_param(graph, f"conv2d_{i+conv_offset}/bias")
+         if i >= 5:
+             # conv5 is not followed by a batchnorm
+             continue
+         state_dict[f"bn{i}.weight"] = get_param(
+             graph, f"batch_normalization_{i+bn_offset}/gamma"
+         )
+         state_dict[f"bn{i}.bias"] = get_param(
+             graph, f"batch_normalization_{i+bn_offset}/beta"
+         )
+         state_dict[f"bn{i}.running_mean"] = get_param(
+             graph, f"batch_normalization_{i+bn_offset}/moving_mean"
+         )
+         state_dict[f"bn{i}.running_var"] = get_param(
+             graph, f"batch_normalization_{i+bn_offset}/moving_variance"
+         )
+
+     if name == "vocals":
+         state_dict["up1.weight"] = get_param(graph, "conv2d_transpose/kernel").permute(
+             3, 2, 0, 1
+         )
+         state_dict["up1.bias"] = get_param(graph, "conv2d_transpose/bias")
+
+         state_dict["bn5.weight"] = get_param(graph, "batch_normalization_6/gamma")
+         state_dict["bn5.bias"] = get_param(graph, "batch_normalization_6/beta")
+         state_dict["bn5.running_mean"] = get_param(
+             graph, "batch_normalization_6/moving_mean"
+         )
+         state_dict["bn5.running_var"] = get_param(
+             graph, "batch_normalization_6/moving_variance"
+         )
+         conv_offset = 0
+         bn_offset = 0
+     else:
+         state_dict["up1.weight"] = get_param(
+             graph, "conv2d_transpose_6/kernel"
+         ).permute(3, 2, 0, 1)
+         state_dict["up1.bias"] = get_param(graph, "conv2d_transpose_6/bias")
+
+         state_dict["bn5.weight"] = get_param(graph, "batch_normalization_18/gamma")
+         state_dict["bn5.bias"] = get_param(graph, "batch_normalization_18/beta")
+         state_dict["bn5.running_mean"] = get_param(
+             graph, "batch_normalization_18/moving_mean"
+         )
+         state_dict["bn5.running_var"] = get_param(
+             graph, "batch_normalization_18/moving_variance"
+         )
+         conv_offset = 6
+         bn_offset = 12
+
+     for i in range(1, 6):
+         state_dict[f"up{i+1}.weight"] = get_param(
+             graph, f"conv2d_transpose_{i+conv_offset}/kernel"
+         ).permute(3, 2, 0, 1)
+
+         state_dict[f"up{i+1}.bias"] = get_param(
+             graph, f"conv2d_transpose_{i+conv_offset}/bias"
+         )
+
+         state_dict[f"bn{5+i}.weight"] = get_param(
+             graph, f"batch_normalization_{6+i+bn_offset}/gamma"
+         )
+         state_dict[f"bn{5+i}.bias"] = get_param(
+             graph, f"batch_normalization_{6+i+bn_offset}/beta"
+         )
+         state_dict[f"bn{5+i}.running_mean"] = get_param(
+             graph, f"batch_normalization_{6+i+bn_offset}/moving_mean"
+         )
+         state_dict[f"bn{5+i}.running_var"] = get_param(
+             graph, f"batch_normalization_{6+i+bn_offset}/moving_variance"
+         )
+
+     if name == "vocals":
+         state_dict["up7.weight"] = get_param(graph, "conv2d_6/kernel").permute(
+             3, 2, 0, 1
+         )
+         state_dict["up7.bias"] = get_param(graph, "conv2d_6/bias")
+     else:
+         state_dict["up7.weight"] = get_param(graph, "conv2d_13/kernel").permute(
+             3, 2, 0, 1
+         )
+         state_dict["up7.bias"] = get_param(graph, "conv2d_13/bias")
+
+     unet.load_state_dict(state_dict)
+
+     with tf.compat.v1.Session(graph=graph) as sess:
+         y0_out, y1_out = sess.run([y0, y1], feed_dict={x: generate_waveform()})
+         # y0_out = sess.run(y0, feed_dict={x: generate_waveform()})
+         # y1_out = sess.run(y1, feed_dict={x: generate_waveform()})
+         # print(y0_out.shape)
+         # print(y1_out.shape)
+
+     # For batchnormalization in TensorFlow, the default input layout is NHWC.
+     # For batchnormalization in torch, the default input layout is NCHW.
+
+     # NHWC to NCHW
+     torch_y1_out = unet(torch.from_numpy(y0_out).permute(0, 3, 1, 2))
+
+     # print(torch_y1_out.shape, torch.from_numpy(y1_out).permute(0, 3, 1, 2).shape)
+     assert torch.allclose(
+         torch_y1_out, torch.from_numpy(y1_out).permute(0, 3, 1, 2), atol=1e-1
+     ), ((torch_y1_out - torch.from_numpy(y1_out).permute(0, 3, 1, 2)).abs().max())
+     torch.save(unet.state_dict(), f"2stems/{name}.pt")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--name",
+         type=str,
+         required=True,
+         choices=["vocals", "accompaniment"],
+     )
+     args = parser.parse_args()
+     print(vars(args))
+     main(args.name)
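
The permute(3, 2, 0, 1) calls above convert TensorFlow's HWIO kernel layout to torch's OIHW layout. A minimal sketch (not part of this commit) that checks this mapping on random data, assuming stride 1 and no padding:

    import numpy as np
    import tensorflow as tf
    import torch

    x = np.random.rand(1, 8, 8, 2).astype(np.float32)   # NHWC input
    w = np.random.rand(5, 5, 2, 16).astype(np.float32)  # HWIO kernel

    tf_out = tf.nn.conv2d(x, w, strides=1, padding="VALID").numpy()

    torch_out = torch.nn.functional.conv2d(
        torch.from_numpy(x).permute(0, 3, 1, 2),  # NHWC -> NCHW
        torch.from_numpy(w).permute(3, 2, 0, 1),  # HWIO -> OIHW
    ).numpy()

    # Compare in NHWC layout
    assert np.allclose(tf_out, torch_out.transpose(0, 2, 3, 1), atol=1e-5)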
run.sh ADDED
@@ -0,0 +1,30 @@
+ #!/usr/bin/env bash
+ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+ if [ ! -f 2stems.tar.gz ]; then
+   wget https://github.com/deezer/spleeter/releases/download/v1.4.0/2stems.tar.gz
+ fi
+
+ if [ ! -d ./2stems ]; then
+   mkdir -p 2stems
+   cd 2stems
+   tar xvf ../2stems.tar.gz
+   cd ..
+ fi
+
+ if [ ! -f 2stems/frozen_vocals_model.pb ]; then
+   python3 ./convert_to_pb.py \
+     --model-dir ./2stems \
+     --output-node-names vocals_spectrogram/mul \
+     --output-filename ./2stems/frozen_vocals_model.pb
+ fi
+
+ if [ ! -f 2stems/frozen_accompaniment_model.pb ]; then
+   python3 ./convert_to_pb.py \
+     --model-dir ./2stems \
+     --output-node-names accompaniment_spectrogram/mul \
+     --output-filename ./2stems/frozen_accompaniment_model.pb
+ fi
+
+ python3 ./convert_to_torch.py --name vocals
+ python3 ./convert_to_torch.py --name accompaniment
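
After ./run.sh finishes, 2stems/vocals.pt and 2stems/accompaniment.pt contain the converted weights. A minimal sketch (not part of this commit) showing how to load one of them:

    import torch
    from unet import UNet

    unet = UNet()
    unet.load_state_dict(torch.load("2stems/vocals.pt", map_location="cpu"))
    unet.eval()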
unet.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+ import torch
+
+
+ class UNet(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.conv = torch.nn.Conv2d(2, 16, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn = torch.nn.BatchNorm2d(
+             16, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn1 = torch.nn.BatchNorm2d(
+             32, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn2 = torch.nn.BatchNorm2d(
+             64, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn3 = torch.nn.BatchNorm2d(
+             128, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=5, stride=(2, 2), padding=0)
+         self.bn4 = torch.nn.BatchNorm2d(
+             256, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.conv5 = torch.nn.Conv2d(256, 512, kernel_size=5, stride=(2, 2), padding=0)
+
+         self.up1 = torch.nn.ConvTranspose2d(512, 256, kernel_size=5, stride=2)
+         self.bn5 = torch.nn.BatchNorm2d(
+             256, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up2 = torch.nn.ConvTranspose2d(512, 128, kernel_size=5, stride=2)
+         self.bn6 = torch.nn.BatchNorm2d(
+             128, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up3 = torch.nn.ConvTranspose2d(256, 64, kernel_size=5, stride=2)
+         self.bn7 = torch.nn.BatchNorm2d(
+             64, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up4 = torch.nn.ConvTranspose2d(128, 32, kernel_size=5, stride=2)
+         self.bn8 = torch.nn.BatchNorm2d(
+             32, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up5 = torch.nn.ConvTranspose2d(64, 16, kernel_size=5, stride=2)
+         self.bn9 = torch.nn.BatchNorm2d(
+             16, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         self.up6 = torch.nn.ConvTranspose2d(32, 1, kernel_size=5, stride=2)
+         self.bn10 = torch.nn.BatchNorm2d(
+             1, track_running_stats=True, eps=1e-3, momentum=0.01
+         )
+
+         # The output logit is False in the original model, so we need
+         # self.up7 and the sigmoid in forward()
+         self.up7 = torch.nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3)
+
+     def forward(self, x):
+         in_x = x
+         # in_x is (3, 2, 512, 1024) = (T, 2, 512, 1024)
+
+         # The explicit (1, 2, 1, 2) padding reproduces TensorFlow's "SAME"
+         # padding for kernel_size=5, stride=2
+         x = torch.nn.functional.pad(x, (1, 2, 1, 2), "constant", 0)
+         conv1 = self.conv(x)
+         batch1 = self.bn(conv1)
+         rel1 = torch.nn.functional.leaky_relu(batch1, negative_slope=0.2)
+
+         x = torch.nn.functional.pad(rel1, (1, 2, 1, 2), "constant", 0)
+         conv2 = self.conv1(x)  # (3, 32, 128, 256)
+         batch2 = self.bn1(conv2)
+         rel2 = torch.nn.functional.leaky_relu(
+             batch2, negative_slope=0.2
+         )  # (3, 32, 128, 256)
+
+         x = torch.nn.functional.pad(rel2, (1, 2, 1, 2), "constant", 0)
+         conv3 = self.conv2(x)  # (3, 64, 64, 128)
+         batch3 = self.bn2(conv3)
+         rel3 = torch.nn.functional.leaky_relu(
+             batch3, negative_slope=0.2
+         )  # (3, 64, 64, 128)
+
+         x = torch.nn.functional.pad(rel3, (1, 2, 1, 2), "constant", 0)
+         conv4 = self.conv3(x)  # (3, 128, 32, 64)
+         batch4 = self.bn3(conv4)
+         rel4 = torch.nn.functional.leaky_relu(
+             batch4, negative_slope=0.2
+         )  # (3, 128, 32, 64)
+
+         x = torch.nn.functional.pad(rel4, (1, 2, 1, 2), "constant", 0)
+         conv5 = self.conv4(x)  # (3, 256, 16, 32)
+         batch5 = self.bn4(conv5)
+         rel6 = torch.nn.functional.leaky_relu(
+             batch5, negative_slope=0.2
+         )  # (3, 256, 16, 32)
+
+         x = torch.nn.functional.pad(rel6, (1, 2, 1, 2), "constant", 0)
+         conv6 = self.conv5(x)  # (3, 512, 8, 16)
+
+         up1 = self.up1(conv6)
+         up1 = up1[:, :, 1:-2, 1:-2]  # (3, 256, 16, 32)
+         up1 = torch.nn.functional.relu(up1)
+         batch7 = self.bn5(up1)
+         merge1 = torch.cat([conv5, batch7], dim=1)  # (3, 512, 16, 32)
+
+         up2 = self.up2(merge1)
+         up2 = up2[:, :, 1:-2, 1:-2]
+         up2 = torch.nn.functional.relu(up2)
+         batch8 = self.bn6(up2)
+
+         merge2 = torch.cat([conv4, batch8], dim=1)  # (3, 256, 32, 64)
+
+         up3 = self.up3(merge2)
+         up3 = up3[:, :, 1:-2, 1:-2]
+         up3 = torch.nn.functional.relu(up3)
+         batch9 = self.bn7(up3)
+
+         merge3 = torch.cat([conv3, batch9], dim=1)  # (3, 128, 64, 128)
+
+         up4 = self.up4(merge3)
+         up4 = up4[:, :, 1:-2, 1:-2]
+         up4 = torch.nn.functional.relu(up4)
+         batch10 = self.bn8(up4)
+
+         merge4 = torch.cat([conv2, batch10], dim=1)  # (3, 64, 128, 256)
+
+         up5 = self.up5(merge4)
+         up5 = up5[:, :, 1:-2, 1:-2]
+         up5 = torch.nn.functional.relu(up5)
+         batch11 = self.bn9(up5)
+
+         merge5 = torch.cat([conv1, batch11], dim=1)  # (3, 32, 256, 512)
+
+         up6 = self.up6(merge5)
+         up6 = up6[:, :, 1:-2, 1:-2]
+         up6 = torch.nn.functional.relu(up6)
+         batch12 = self.bn10(up6)  # (3, 1, 512, 1024) = (T, 1, 512, 1024)
+
+         up7 = self.up7(batch12)
+         up7 = torch.sigmoid(up7)  # (3, 2, 512, 1024)
+
+         # Apply the predicted mask to the input spectrogram
+         return up7 * in_x
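
A minimal sketch (not part of this commit) confirming the shapes noted in the forward() comments: the network maps a batch of (2, 512, 1024) spectrogram patches to a mask of the same shape, which is multiplied with the input:

    import torch
    from unet import UNet

    unet = UNet()
    unet.eval()

    x = torch.rand(3, 2, 512, 1024)  # (T, 2, 512, 1024), NCHW
    with torch.no_grad():
        y = unet(x)
    assert y.shape == x.shape, y.shape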